diff --git a/src/Expression.cpp b/src/Expression.cpp index 7caf860..0c71700 100644 --- a/src/Expression.cpp +++ b/src/Expression.cpp @@ -124,7 +124,7 @@ shared_ptr ExpressionGrouping::getExpression() { } string ExpressionGrouping::toString() { - return "<( " + expression->toString() + " )>"; + return "( " + expression->toString() + " )"; } // diff --git a/src/ModuleBuilder.cpp b/src/ModuleBuilder.cpp index 4a9f9de..012bd7f 100644 --- a/src/ModuleBuilder.cpp +++ b/src/ModuleBuilder.cpp @@ -1,9 +1,6 @@ #include "ModuleBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Constants.h" -#include "llvm/Support/raw_ostream.h" -/*ModuleBuilder::ModuleBuilder(vector> statements): statements(statements) { +ModuleBuilder::ModuleBuilder(vector> statements): statements(statements) { context = make_shared(); module = make_shared("dummy", *context); builder = make_shared>(*context); @@ -12,40 +9,47 @@ int32Type = llvm::Type::getInt32Ty(*context); } -void ModuleBuilder::buildCodeForStatement(shared_ptr statement) { +shared_ptr ModuleBuilder::getModule() { + for (shared_ptr &statement : statements) { + buildStatement(statement); + } + return module; +} + +void ModuleBuilder::buildStatement(shared_ptr statement) { switch (statement->getKind()) { case Statement::Kind::FUNCTION_DECLARATION: - buildFunction(statement); + buildFunctionDeclaration(dynamic_pointer_cast(statement)); break; case Statement::Kind::BLOCK: - buildBlock(statement); + buildBlock(dynamic_pointer_cast(statement)); break; case Statement::Kind::RETURN: - buildReturn(statement); + buildReturn(dynamic_pointer_cast(statement)); break; case Statement::Kind::EXPRESSION: - buildExpression(statement); - break; + buildExpression(dynamic_pointer_cast(statement)); + return; default: exit(1); } } -void ModuleBuilder::buildFunction(shared_ptr statement) { +void ModuleBuilder::buildFunctionDeclaration(shared_ptr statement) { llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get()); llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); builder->SetInsertPoint(block); - buildCodeForStatement(statement->getBlockStatement()); + buildStatement(statement->getStatementBlock()); } -void ModuleBuilder::buildBlock(shared_ptr statement) { +void ModuleBuilder::buildBlock(shared_ptr statement) { for (shared_ptr &innerStatement : statement->getStatements()) { - buildCodeForStatement(innerStatement); + buildStatement(innerStatement); } } -void ModuleBuilder::buildReturn(shared_ptr statement) { +void ModuleBuilder::buildReturn(shared_ptr statement) { if (statement->getExpression() != nullptr) { llvm::Value *value = valueForExpression(statement->getExpression()); builder->CreateRet(value); @@ -54,38 +58,57 @@ void ModuleBuilder::buildReturn(shared_ptr statement) { } } -void ModuleBuilder::buildExpression(shared_ptr statement) { - +void ModuleBuilder::buildExpression(shared_ptr statement) { + valueForExpression(statement->getExpression()); } llvm::Value *ModuleBuilder::valueForExpression(shared_ptr expression) { switch (expression->getKind()) { case Expression::Kind::LITERAL: - return llvm::ConstantInt::get(int32Type, expression->getInteger(), true); + return valueForLiteral(dynamic_pointer_cast(expression)); case Expression::Kind::GROUPING: - return valueForExpression(expression->getLeft()); + return valueForExpression(dynamic_pointer_cast(expression)->getExpression()); case Expression::Kind::BINARY: - llvm::Value *leftValue = valueForExpression(expression->getLeft()); - llvm::Value *rightValue = valueForExpression(expression->getRight()); - switch (expression->getOperator()) { - case Expression::Operator::ADD: - return builder->CreateNSWAdd(leftValue, rightValue); - case Expression::Operator::SUB: - return builder->CreateNSWSub(leftValue, rightValue); - case Expression::Operator::MUL: - return builder->CreateNSWMul(leftValue, rightValue); - case Expression::Operator::DIV: - return builder->CreateSDiv(leftValue, rightValue); - case Expression::Operator::MOD: - return builder->CreateSRem(leftValue, rightValue); - } - break; + return valueForBinary(dynamic_pointer_cast(expression)); + default: + exit(1); } } -shared_ptr ModuleBuilder::getModule() { - for (shared_ptr &statement : statements) { - buildCodeForStatement(statement); +llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr expression) { + return llvm::ConstantInt::get(int32Type, expression->getInteger(), true); +} + +llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr expression) { + return valueForExpression(expression->getExpression()); +} + +llvm::Value *ModuleBuilder::valueForBinary(shared_ptr expression) { + llvm::Value *leftValue = valueForExpression(expression->getLeft()); + llvm::Value *rightValue = valueForExpression(expression->getRight()); + + switch (expression->getOperation()) { + case ExpressionBinary::Operation::EQUAL: + return builder->CreateICmpEQ(leftValue, rightValue); + case ExpressionBinary::Operation::NOT_EQUAL: + return builder->CreateICmpNE(leftValue, rightValue); + case ExpressionBinary::Operation::LESS: + return builder->CreateICmpSLT(leftValue, rightValue); + case ExpressionBinary::Operation::LESS_EQUAL: + return builder->CreateICmpSLE(leftValue, rightValue); + case ExpressionBinary::Operation::GREATER: + return builder->CreateICmpSGT(leftValue, rightValue); + case ExpressionBinary::Operation::GREATER_EQUAL: + return builder->CreateICmpSGE(leftValue, rightValue); + case ExpressionBinary::Operation::ADD: + return builder->CreateNSWAdd(leftValue, rightValue); + case ExpressionBinary::Operation::SUB: + return builder->CreateNSWSub(leftValue, rightValue); + case ExpressionBinary::Operation::MUL: + return builder->CreateNSWMul(leftValue, rightValue); + case ExpressionBinary::Operation::DIV: + return builder->CreateSDiv(leftValue, rightValue); + case ExpressionBinary::Operation::MOD: + return builder->CreateSRem(leftValue, rightValue); } - return module; -}*/ \ No newline at end of file +} diff --git a/src/ModuleBuilder.h b/src/ModuleBuilder.h index 75e6b30..d6c6a9f 100644 --- a/src/ModuleBuilder.h +++ b/src/ModuleBuilder.h @@ -3,6 +3,9 @@ #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/raw_ostream.h" #include "Expression.h" #include "Statement.h" @@ -10,7 +13,7 @@ using namespace std; class ModuleBuilder { -/*private: +private: shared_ptr context; shared_ptr module; shared_ptr> builder; @@ -20,16 +23,20 @@ class ModuleBuilder { vector> statements; - void buildCodeForStatement(shared_ptr statement); - void buildFunction(shared_ptr statement); - void buildBlock(shared_ptr statement); - void buildReturn(shared_ptr statement); - void buildExpression(shared_ptr statement); + void buildStatement(shared_ptr statement); + void buildFunctionDeclaration(shared_ptr statement); + void buildBlock(shared_ptr statement); + void buildReturn(shared_ptr statement); + void buildExpression(shared_ptr statement); + llvm::Value *valueForExpression(shared_ptr expression); + llvm::Value *valueForLiteral(shared_ptr expression); + llvm::Value *valueForGrouping(shared_ptr expression); + llvm::Value *valueForBinary(shared_ptr expression); public: ModuleBuilder(vector> statements); - shared_ptr getModule();*/ + shared_ptr getModule(); }; #endif \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 371bfa5..f2583bc 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -50,9 +50,9 @@ int main(int argc, char **argv) { cout << endl; } - //ModuleBuilder moduleBuilder(statements); - //shared_ptr module = moduleBuilder.getModule(); - //module->print(llvm::outs(), nullptr); + ModuleBuilder moduleBuilder(statements); + shared_ptr module = moduleBuilder.getModule(); + module->print(llvm::outs(), nullptr); //CodeGenerator codeGenerator(module); //codeGenerator.generateObjectFile("dummy.s");