Generate functions

This commit is contained in:
Rafał Grodziński
2025-06-02 10:45:44 +09:00
parent 2cecb456bb
commit 2ef888e374
7 changed files with 92 additions and 31 deletions

View File

@@ -3,7 +3,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
ModuleBuilder::ModuleBuilder(shared_ptr<Expression> expression): expression(expression) {
ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) {
context = make_shared<llvm::LLVMContext>();
module = make_shared<llvm::Module>("dummy", *context);
builder = make_shared<llvm::IRBuilder<>>(*context);
@@ -12,6 +12,52 @@ ModuleBuilder::ModuleBuilder(shared_ptr<Expression> expression): expression(expr
int32Type = llvm::Type::getInt32Ty(*context);
}
void ModuleBuilder::buildCodeForStatement(shared_ptr<Statement> statement) {
switch (statement->getKind()) {
case Statement::Kind::FUNCTION_DECLARATION:
buildFunction(statement);
break;
case Statement::Kind::BLOCK:
buildBlock(statement);
break;
case Statement::Kind::RETURN:
buildReturn(statement);
break;
case Statement::Kind::EXPRESSION:
buildExpression(statement);
break;
default:
exit(1);
}
}
void ModuleBuilder::buildFunction(shared_ptr<Statement> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
buildCodeForStatement(statement->getBlockStatement());
}
void ModuleBuilder::buildBlock(shared_ptr<Statement> statement) {
for (shared_ptr<Statement> &innerStatement : statement->getStatements()) {
buildCodeForStatement(innerStatement);
}
}
void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) {
if (statement->getExpression() != nullptr) {
llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateRet(value);
} else {
builder->CreateRetVoid();
}
}
void ModuleBuilder::buildExpression(shared_ptr<Statement> statement) {
}
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
switch (expression->getKind()) {
case Expression::Kind::LITERAL:
@@ -38,18 +84,8 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression
}
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
//llvm::Value *value = valueForExpression(expression);
llvm::FunctionType *fType = llvm::FunctionType::get(int32Type, false);
llvm::Function *f = llvm::Function::Create(fType, llvm::GlobalValue::InternalLinkage, "dummyFunc", module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, "entry", f);
builder->SetInsertPoint(block);
llvm::Value *value = valueForExpression(expression);
//builder->CreateRetVoid();
builder->CreateRet(value);
//value->print(llvm::outs(), false);
//cout << endl;
for (shared_ptr<Statement> &statement : statements) {
buildCodeForStatement(statement);
}
return module;
}