From 2ef888e3747eaa25938ad9f1da21db76cec6803f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Grodzi=C5=84ski?= Date: Mon, 2 Jun 2025 10:45:44 +0900 Subject: [PATCH] Generate functions --- src/Lexer.h | 4 +-- src/ModuleBuilder.cpp | 64 +++++++++++++++++++++++++++++++++---------- src/ModuleBuilder.h | 12 ++++++-- src/Parser.h | 1 + src/Statement.cpp | 20 ++++++++++++-- src/Statement.h | 8 ++++-- src/main.cpp | 14 ++++------ 7 files changed, 92 insertions(+), 31 deletions(-) diff --git a/src/Lexer.h b/src/Lexer.h index e35dbad..b18dbb6 100644 --- a/src/Lexer.h +++ b/src/Lexer.h @@ -1,10 +1,10 @@ #ifndef LEXER_H #define LEXER_H -#include "Token.h" - #include +#include "Token.h" + using namespace std; class Lexer { diff --git a/src/ModuleBuilder.cpp b/src/ModuleBuilder.cpp index 09f7919..56ea3ab 100644 --- a/src/ModuleBuilder.cpp +++ b/src/ModuleBuilder.cpp @@ -3,7 +3,7 @@ #include "llvm/IR/Constants.h" #include "llvm/Support/raw_ostream.h" -ModuleBuilder::ModuleBuilder(shared_ptr expression): expression(expression) { +ModuleBuilder::ModuleBuilder(vector> statements): statements(statements) { context = make_shared(); module = make_shared("dummy", *context); builder = make_shared>(*context); @@ -12,6 +12,52 @@ ModuleBuilder::ModuleBuilder(shared_ptr expression): expression(expr int32Type = llvm::Type::getInt32Ty(*context); } +void ModuleBuilder::buildCodeForStatement(shared_ptr statement) { + switch (statement->getKind()) { + case Statement::Kind::FUNCTION_DECLARATION: + buildFunction(statement); + break; + case Statement::Kind::BLOCK: + buildBlock(statement); + break; + case Statement::Kind::RETURN: + buildReturn(statement); + break; + case Statement::Kind::EXPRESSION: + buildExpression(statement); + break; + default: + exit(1); + } +} + +void ModuleBuilder::buildFunction(shared_ptr statement) { + llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false); + llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get()); + llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); + builder->SetInsertPoint(block); + buildCodeForStatement(statement->getBlockStatement()); +} + +void ModuleBuilder::buildBlock(shared_ptr statement) { + for (shared_ptr &innerStatement : statement->getStatements()) { + buildCodeForStatement(innerStatement); + } +} + +void ModuleBuilder::buildReturn(shared_ptr statement) { + if (statement->getExpression() != nullptr) { + llvm::Value *value = valueForExpression(statement->getExpression()); + builder->CreateRet(value); + } else { + builder->CreateRetVoid(); + } +} + +void ModuleBuilder::buildExpression(shared_ptr statement) { + +} + llvm::Value *ModuleBuilder::valueForExpression(shared_ptr expression) { switch (expression->getKind()) { case Expression::Kind::LITERAL: @@ -38,18 +84,8 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr expression } shared_ptr ModuleBuilder::getModule() { - //llvm::Value *value = valueForExpression(expression); - - llvm::FunctionType *fType = llvm::FunctionType::get(int32Type, false); - llvm::Function *f = llvm::Function::Create(fType, llvm::GlobalValue::InternalLinkage, "dummyFunc", module.get()); - - llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, "entry", f); - builder->SetInsertPoint(block); - llvm::Value *value = valueForExpression(expression); - //builder->CreateRetVoid(); - builder->CreateRet(value); - - //value->print(llvm::outs(), false); - //cout << endl; + for (shared_ptr &statement : statements) { + buildCodeForStatement(statement); + } return module; } \ No newline at end of file diff --git a/src/ModuleBuilder.h b/src/ModuleBuilder.h index 68e078f..fb11686 100644 --- a/src/ModuleBuilder.h +++ b/src/ModuleBuilder.h @@ -3,7 +3,9 @@ #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" + #include "Expression.h" +#include "Statement.h" using namespace std; @@ -16,11 +18,17 @@ private: llvm::Type *voidType; llvm::IntegerType *int32Type; - shared_ptr expression; + vector> statements; + + void buildCodeForStatement(shared_ptr statement); + void buildFunction(shared_ptr statement); + void buildBlock(shared_ptr statement); + void buildReturn(shared_ptr statement); + void buildExpression(shared_ptr statement); llvm::Value *valueForExpression(shared_ptr expression); public: - ModuleBuilder(shared_ptr expression); + ModuleBuilder(vector> statements); shared_ptr getModule(); }; diff --git a/src/Parser.h b/src/Parser.h index fe39db2..4281b63 100644 --- a/src/Parser.h +++ b/src/Parser.h @@ -2,6 +2,7 @@ #define PARSER_H #include + #include "Token.h" #include "Expression.h" #include "Statement.h" diff --git a/src/Statement.cpp b/src/Statement.cpp index 2049e75..38a6519 100644 --- a/src/Statement.cpp +++ b/src/Statement.cpp @@ -4,14 +4,30 @@ Statement::Statement(Kind kind, shared_ptr token, shared_ptr kind(kind), token(token), expression(expression), blockStatement(blockStatement), statements(statements), name(name) { } -shared_ptr Statement::getExpression() { - return expression; +Statement::Kind Statement::getKind() { + return kind; } shared_ptr Statement::getToken() { return token; } +shared_ptr Statement::getExpression() { + return expression; +} + +shared_ptr Statement::getBlockStatement() { + return blockStatement; +} + +vector> Statement::getStatements() { + return statements; +} + +string Statement::getName() { + return name; +} + bool Statement::isValid() { return kind != Statement::Kind::INVALID; } diff --git a/src/Statement.h b/src/Statement.h index 362521c..fb13598 100644 --- a/src/Statement.h +++ b/src/Statement.h @@ -11,10 +11,10 @@ using namespace std; class Statement { public: enum Kind { - EXPRESSION, - BLOCK, FUNCTION_DECLARATION, + BLOCK, RETURN, + EXPRESSION, INVALID }; @@ -28,8 +28,12 @@ private: public: Statement(Kind kind, shared_ptr token, shared_ptr expression, shared_ptr blockStatement, vector> statements, string name); + Kind getKind(); shared_ptr getToken(); shared_ptr getExpression(); + shared_ptr getBlockStatement(); + vector> getStatements(); + string getName(); bool isValid(); string toString(); }; diff --git a/src/main.cpp b/src/main.cpp index 95e243d..f3475bb 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -52,18 +52,14 @@ int main(int argc, char **argv) { cout << statement->toString(); cout << endl; } - //shared_ptr expression = parser.getExpression(); - //if (!expression) { - // exit(1); - //} - //cout << expression->toString() << endl; //ModuleBuilder moduleBuilder(expression); - //shared_ptr module = moduleBuilder.getModule(); - //module->print(llvm::outs(), nullptr); + ModuleBuilder moduleBuilder(statements); + shared_ptr module = moduleBuilder.getModule(); + module->print(llvm::outs(), nullptr); - //CodeGenerator codeGenerator(module);; - //codeGenerator.generateObjectFile("dummy.s"); + CodeGenerator codeGenerator(module); + codeGenerator.generateObjectFile("dummy.s"); return 0; } \ No newline at end of file