Added scopes

This commit is contained in:
Rafał Grodziński
2025-07-05 18:42:58 +09:00
parent 2e0015c9be
commit da1b5852ff
3 changed files with 85 additions and 15 deletions

View File

@@ -21,7 +21,7 @@ printNum fun: number sint32
; ;
; ;
// Print 20 first fibonacci numbers // Print first 20 fibonaci numbers
main fun -> sint32 main fun -> sint32
rep i sint32 <- 0, i < 20: rep i sint32 <- 0, i < 20:
res sint32 <- fib(i) res sint32 <- fib(i)

View File

@@ -33,9 +33,9 @@ moduleName(moduleName), sourceFileName(sourceFileName), statements(statements) {
} }
shared_ptr<llvm::Module> ModuleBuilder::getModule() { shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) { scopes.push(Scope());
for (shared_ptr<Statement> &statement : statements)
buildStatement(statement); buildStatement(statement);
}
return module; return module;
} }
@@ -80,12 +80,15 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
// build function declaration // build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false); llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get()); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun; if (!setFun(statement->getName(), fun))
return;
// define function body // define function body
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block); builder->SetInsertPoint(block);
scopes.push(Scope());
// build arguments // build arguments
int i=0; int i=0;
for (auto &arg : fun->args()) { for (auto &arg : fun->args()) {
@@ -94,7 +97,8 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
arg.setName(name); arg.setName(name);
llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name); llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name);
allocaMap[name] = alloca; if (!setAlloca(name, alloca))
return;
builder->CreateStore(&arg, alloca); builder->CreateStore(&arg, alloca);
i++; i++;
@@ -103,6 +107,8 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
// build function body // build function body
buildStatement(statement->getStatementBlock()); buildStatement(statement->getStatementBlock());
scopes.pop();
// verify // verify
string errorMessage; string errorMessage;
llvm::raw_string_ostream llvmErrorMessage(errorMessage); llvm::raw_string_ostream llvmErrorMessage(errorMessage);
@@ -113,14 +119,16 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVariable> statement) { void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVariable> statement) {
llvm::Value *value = valueForExpression(statement->getExpression()); llvm::Value *value = valueForExpression(statement->getExpression());
llvm::AllocaInst *alloca = builder->CreateAlloca(typeForValueType(statement->getValueType()), nullptr, statement->getName()); llvm::AllocaInst *alloca = builder->CreateAlloca(typeForValueType(statement->getValueType()), nullptr, statement->getName());
allocaMap[statement->getName()] = alloca;
if (!setAlloca(statement->getName(), alloca))
return;
builder->CreateStore(value, alloca); builder->CreateStore(value, alloca);
} }
void ModuleBuilder::buildAssignment(shared_ptr<StatementAssignment> statement) { void ModuleBuilder::buildAssignment(shared_ptr<StatementAssignment> statement) {
llvm::AllocaInst *alloca = allocaMap[statement->getName()]; llvm::AllocaInst *alloca = getAlloca(statement->getName());
if (alloca == nullptr) if (alloca == nullptr)
failWithMessage("Variable " + statement->getName() + " not defined"); return;
llvm::Value *value = valueForExpression(statement->getExpression()); llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateStore(value, alloca); builder->CreateStore(value, alloca);
@@ -151,6 +159,8 @@ void ModuleBuilder::buildLoop(shared_ptr<StatementRepeat> statement) {
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(*context, "loopBody"); llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(*context, "loopBody");
llvm::BasicBlock *afterBlock = llvm::BasicBlock::Create(*context, "loopPost"); llvm::BasicBlock *afterBlock = llvm::BasicBlock::Create(*context, "loopPost");
scopes.push(Scope());
// loop init // loop init
if (initStatement != nullptr) if (initStatement != nullptr)
buildStatement(statement->getInitStatement()); buildStatement(statement->getInitStatement());
@@ -181,6 +191,8 @@ void ModuleBuilder::buildLoop(shared_ptr<StatementRepeat> statement) {
// loop post // loop post
fun->insert(fun->end(), afterBlock); fun->insert(fun->end(), afterBlock);
builder->SetInsertPoint(afterBlock); builder->SetInsertPoint(afterBlock);
scopes.pop();
} }
void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFunction> statement) { void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFunction> statement) {
@@ -193,7 +205,8 @@ void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFuncti
// build function declaration // build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false); llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get()); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun; if (!setFun(statement->getName(), fun))
return;
// build arguments // build arguments
int i=0; int i=0;
@@ -343,13 +356,16 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
builder->CreateCondBr(conditionValue, thenBlock, elseBlock); builder->CreateCondBr(conditionValue, thenBlock, elseBlock);
// Then // Then
scopes.push(Scope());
builder->SetInsertPoint(thenBlock); builder->SetInsertPoint(thenBlock);
buildStatement(expression->getThenBlock()->getStatementBlock()); buildStatement(expression->getThenBlock()->getStatementBlock());
llvm::Value *thenValue = valueForExpression(expression->getThenBlock()->getResultStatementExpression()->getExpression()); llvm::Value *thenValue = valueForExpression(expression->getThenBlock()->getResultStatementExpression()->getExpression());
builder->CreateBr(mergeBlock); builder->CreateBr(mergeBlock);
thenBlock = builder->GetInsertBlock(); thenBlock = builder->GetInsertBlock();
scopes.pop();
// Else // Else
scopes.push(Scope());
fun->insert(fun->end(), elseBlock); fun->insert(fun->end(), elseBlock);
builder->SetInsertPoint(elseBlock); builder->SetInsertPoint(elseBlock);
llvm::Value *elseValue = nullptr; llvm::Value *elseValue = nullptr;
@@ -360,6 +376,7 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
} }
builder->CreateBr(mergeBlock); builder->CreateBr(mergeBlock);
elseBlock = builder->GetInsertBlock(); elseBlock = builder->GetInsertBlock();
scopes.pop();
// Merge // Merge
fun->insert(fun->end(), mergeBlock); fun->insert(fun->end(), mergeBlock);
@@ -378,17 +395,17 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
} }
llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVariable> expression) { llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVariable> expression) {
llvm::AllocaInst *alloca = allocaMap[expression->getName()]; llvm::AllocaInst *alloca = getAlloca(expression->getName());
if (alloca == nullptr) if (alloca == nullptr)
failWithMessage("Variable " + expression->getName() + " not defined"); return nullptr;
return builder->CreateLoad(alloca->getAllocatedType(), alloca, expression->getName()); return builder->CreateLoad(alloca->getAllocatedType(), alloca, expression->getName());
} }
llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression) { llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression) {
llvm::Function *fun = funMap[expression->getName()]; llvm::Function *fun = getFun(expression->getName());
if (fun == nullptr) if (fun == nullptr)
failWithMessage("Function " + expression->getName() + " not defined"); return nullptr;
llvm::FunctionType *funType = fun->getFunctionType(); llvm::FunctionType *funType = fun->getFunctionType();
vector<llvm::Value*> argValues; vector<llvm::Value*> argValues;
for (shared_ptr<Expression> &argumentExpression : expression->getArgumentExpressions()) { for (shared_ptr<Expression> &argumentExpression : expression->getArgumentExpressions()) {
@@ -398,6 +415,48 @@ llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression)
return builder->CreateCall(funType, fun, llvm::ArrayRef(argValues)); return builder->CreateCall(funType, fun, llvm::ArrayRef(argValues));
} }
bool ModuleBuilder::setAlloca(string name, llvm::AllocaInst *alloca) {
if (scopes.top().allocaMap[name] != nullptr)
return false;
scopes.top().allocaMap[name] = alloca;
return true;
}
llvm::AllocaInst* ModuleBuilder::getAlloca(string name) {
stack<Scope> scopes = this->scopes;
while (!scopes.empty()) {
llvm::AllocaInst *alloca = scopes.top().allocaMap[name];
if (alloca != nullptr)
return alloca;
scopes.pop();
}
return nullptr;
}
bool ModuleBuilder::setFun(string name, llvm::Function *fun) {
if (scopes.top().funMap[name] != nullptr)
return false;
scopes.top().funMap[name] = fun;
return true;
}
llvm::Function* ModuleBuilder::getFun(string name) {
stack<Scope> scopes = this->scopes;
while (!scopes.empty()) {
llvm::Function *fun = scopes.top().funMap[name];
if (fun != nullptr)
return fun;
scopes.pop();
}
return nullptr;
}
llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType) { llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType) {
switch (valueType->getKind()) { switch (valueType->getKind()) {
case ValueTypeKind::NONE: case ValueTypeKind::NONE:

View File

@@ -2,6 +2,7 @@
#define MODULE_BUILDER_H #define MODULE_BUILDER_H
#include <map> #include <map>
#include <stack>
#include <llvm/IR/Module.h> #include <llvm/IR/Module.h>
#include <llvm/IR/IRBuilder.h> #include <llvm/IR/IRBuilder.h>
@@ -33,6 +34,11 @@ class StatementBlock;
using namespace std; using namespace std;
typedef struct {
map<string, llvm::AllocaInst*> allocaMap;
map<string, llvm::Function*> funMap;
} Scope;
class ModuleBuilder { class ModuleBuilder {
private: private:
string moduleName; string moduleName;
@@ -48,8 +54,7 @@ private:
llvm::Type *typeReal32; llvm::Type *typeReal32;
vector<shared_ptr<Statement>> statements; vector<shared_ptr<Statement>> statements;
map<string, llvm::AllocaInst*> allocaMap; stack<Scope> scopes;
map<string, llvm::Function*> funMap;
void buildStatement(shared_ptr<Statement> statement); void buildStatement(shared_ptr<Statement> statement);
void buildFunctionDeclaration(shared_ptr<StatementFunction> statement); void buildFunctionDeclaration(shared_ptr<StatementFunction> statement);
@@ -72,6 +77,12 @@ private:
llvm::Value *valueForVar(shared_ptr<ExpressionVariable> expression); llvm::Value *valueForVar(shared_ptr<ExpressionVariable> expression);
llvm::Value *valueForCall(shared_ptr<ExpressionCall> expression); llvm::Value *valueForCall(shared_ptr<ExpressionCall> expression);
bool setAlloca(string name, llvm::AllocaInst *alloca);
llvm::AllocaInst *getAlloca(string name);
bool setFun(string name, llvm::Function *fun);
llvm::Function *getFun(string name);
llvm::Type *typeForValueType(shared_ptr<ValueType> valueType); llvm::Type *typeForValueType(shared_ptr<ValueType> valueType);
void failWithMessage(string message); void failWithMessage(string message);