#include "ModuleBuilder.h"

#include <cstdlib>
#include <iostream>
#include <utility>

namespace {
    /// Reports an internal code-generation error on stderr and terminates the process
    /// with exit status 1 (the same status the builder has always used for unsupported input).
    /// @param message short description of what could not be lowered
    /// @param kind    numeric value of the offending enum, to aid debugging
    [[noreturn]] void failWithError(const char *message, int kind) {
        std::cerr << "ModuleBuilder: " << message << " (kind " << kind << ")\n";
        std::exit(1);
    }
}

/// Creates a builder that will lower `statements` into a fresh LLVM module named "dummy".
/// The LLVM context, module and IR builder are created here and held by this object.
ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(std::move(statements)) {
    context = make_shared<llvm::LLVMContext>();
    module = make_shared<llvm::Module>("dummy", *context);
    builder = make_shared<llvm::IRBuilder<>>(*context);
    voidType = llvm::Type::getVoidTy(*context);
    int32Type = llvm::Type::getInt32Ty(*context);
}

/// Lowers every top-level statement into the module and returns it.
/// NOTE(review): each call re-emits all statements into the same module, so calling this
/// more than once would emit every function a second time — intended to be called once.
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
    for (shared_ptr<Statement> &statement : statements) {
        buildStatement(statement);
    }
    return module;
}

/// Dispatches a statement to the matching build* method based on its kind.
/// Unknown kinds are reported on stderr and terminate the process with status 1.
void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
    switch (statement->getKind()) {
        case Statement::Kind::FUNCTION_DECLARATION:
            buildFunctionDeclaration(dynamic_pointer_cast<StatementFunctionDeclaration>(statement));
            break;
        case Statement::Kind::BLOCK:
            buildBlock(dynamic_pointer_cast<StatementBlock>(statement));
            break;
        case Statement::Kind::RETURN:
            buildReturn(dynamic_pointer_cast<StatementReturn>(statement));
            break;
        case Statement::Kind::EXPRESSION:
            buildExpression(dynamic_pointer_cast<StatementExpression>(statement));
            break;
        default:
            failWithError("unsupported statement kind", static_cast<int>(statement->getKind()));
    }
}

/// Emits an internal-linkage function taking no arguments and returning i32, creates a single
/// basic block (named after the function) as the insertion point, and lowers the body into it.
/// NOTE(review): no terminator is added if the body does not end in a return statement — confirm
/// the parser guarantees one, otherwise the function will fail LLVM verification.
void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
    llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false);
    llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
    llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
    builder->SetInsertPoint(block);

    buildStatement(statement->getStatementBlock());
}

/// Lowers each statement of a block, in order, at the current insertion point.
void ModuleBuilder::buildBlock(shared_ptr<StatementBlock> statement) {
    for (const shared_ptr<Statement> &innerStatement : statement->getStatements()) {
        buildStatement(innerStatement);
    }
}

/// Emits a `ret` instruction for a return statement.
/// A 1-bit result (produced by the comparison operators) is zero-extended to the enclosing
/// function's integer return type so the `ret` operand type matches the function signature.
void ModuleBuilder::buildReturn(shared_ptr<StatementReturn> statement) {
    if (statement->getExpression() == nullptr) {
        // NOTE(review): functions are currently always declared as returning i32
        // (see buildFunctionDeclaration), so a bare `ret void` yields IR that fails verification.
        builder->CreateRetVoid();
        return;
    }

    llvm::Value *value = valueForExpression(statement->getExpression());
    llvm::Type *returnType = builder->getCurrentFunctionReturnType();
    if (value->getType()->isIntegerTy(1) && returnType->isIntegerTy() && !returnType->isIntegerTy(1)) {
        value = builder->CreateZExt(value, returnType);
    }
    builder->CreateRet(value);
}

/// Lowers an expression statement; the produced value is discarded.
void ModuleBuilder::buildExpression(shared_ptr<StatementExpression> statement) {
    valueForExpression(statement->getExpression());
}

/// Dispatches an expression to the matching valueFor* method and returns the resulting value.
/// Unknown kinds are reported on stderr and terminate the process with status 1.
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
    switch (expression->getKind()) {
        case Expression::Kind::LITERAL:
            return valueForLiteral(dynamic_pointer_cast<ExpressionLiteral>(expression));
        case Expression::Kind::GROUPING:
            return valueForGrouping(dynamic_pointer_cast<ExpressionGrouping>(expression));
        case Expression::Kind::BINARY:
            return valueForBinary(dynamic_pointer_cast<ExpressionBinary>(expression));
        default:
            failWithError("unsupported expression kind", static_cast<int>(expression->getKind()));
    }
}

/// Returns the literal's integer as an i32 constant, interpreted as signed.
llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
    return llvm::ConstantInt::get(int32Type, expression->getInteger(), true);
}

/// A parenthesised expression lowers to the value of its inner expression.
llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr<ExpressionGrouping> expression) {
    return valueForExpression(expression->getExpression());
}

/// Lowers both operands (left first), then emits the matching signed integer instruction.
/// Comparisons yield i1. Add/sub/mul carry the `nsw` flag, so signed overflow produces poison.
/// Unknown operations are reported on stderr and terminate the process with status 1
/// (previously control fell off the end of this non-void function, which is undefined behaviour).
llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expression) {
    llvm::Value *leftValue = valueForExpression(expression->getLeft());
    llvm::Value *rightValue = valueForExpression(expression->getRight());

    switch (expression->getOperation()) {
        case ExpressionBinary::Operation::EQUAL:
            return builder->CreateICmpEQ(leftValue, rightValue);
        case ExpressionBinary::Operation::NOT_EQUAL:
            return builder->CreateICmpNE(leftValue, rightValue);
        case ExpressionBinary::Operation::LESS:
            return builder->CreateICmpSLT(leftValue, rightValue);
        case ExpressionBinary::Operation::LESS_EQUAL:
            return builder->CreateICmpSLE(leftValue, rightValue);
        case ExpressionBinary::Operation::GREATER:
            return builder->CreateICmpSGT(leftValue, rightValue);
        case ExpressionBinary::Operation::GREATER_EQUAL:
            return builder->CreateICmpSGE(leftValue, rightValue);
        case ExpressionBinary::Operation::ADD:
            return builder->CreateNSWAdd(leftValue, rightValue);
        case ExpressionBinary::Operation::SUB:
            return builder->CreateNSWSub(leftValue, rightValue);
        case ExpressionBinary::Operation::MUL:
            return builder->CreateNSWMul(leftValue, rightValue);
        case ExpressionBinary::Operation::DIV:
            return builder->CreateSDiv(leftValue, rightValue);
        case ExpressionBinary::Operation::MOD:
            return builder->CreateSRem(leftValue, rightValue);
        default:
            failWithError("unsupported binary operation", static_cast<int>(expression->getOperation()));
    }
}