Cleaned up structure

This commit is contained in:
Rafał Grodziński
2025-06-21 13:33:50 +09:00
parent a1bb97a597
commit 7a76bb7e33
16 changed files with 29 additions and 27 deletions

View File

@@ -0,0 +1,57 @@
#include "CodeGenerator.h"
using namespace std;
CodeGenerator::CodeGenerator(shared_ptr<llvm::Module> module): module(module) {
}
void CodeGenerator::generateObjectFile(OutputKind outputKind) {
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
string targetTriple = llvm::sys::getDefaultTargetTriple();
string errorString;
const llvm::Target *target = llvm::TargetRegistry::lookupTarget(targetTriple, errorString);
if (!target) {
cerr << errorString << endl;
exit(1);
}
llvm::TargetOptions options;
llvm::TargetMachine *targetMachine = target->createTargetMachine(targetTriple, "generic", "", options, llvm::Reloc::PIC_);
module->setDataLayout(targetMachine->createDataLayout());
module->setTargetTriple(targetTriple);
string fileName;
llvm::CodeGenFileType codeGenFileType;
switch (outputKind) {
case OutputKind::ASSEMBLY:
fileName = string(module->getName()) + ".asm";
codeGenFileType = llvm::CodeGenFileType::AssemblyFile;
break;
case OutputKind::OBJECT:
fileName = string(module->getName()) + ".o";
codeGenFileType = llvm::CodeGenFileType::ObjectFile;
break;
}
error_code errorCode;
llvm::raw_fd_ostream outputFile(fileName, errorCode, llvm::sys::fs::OF_None);
if (errorCode) {
cerr << errorCode.message();
exit(1);
}
llvm::legacy::PassManager passManager;
if (targetMachine->addPassesToEmitFile(passManager, outputFile, nullptr, codeGenFileType)) {
cerr << "Failed to generate file " << fileName << endl;
exit(1);
}
passManager.run(*module);
outputFile.flush();
}

View File

@@ -0,0 +1,31 @@
#ifndef CODE_GENERATOR_H
#define CODE_GENERATOR_H
#include <iostream>
#include <llvm/IR/Module.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/TargetParser/Host.h>
using namespace std;
class CodeGenerator {
public:
enum class OutputKind {
ASSEMBLY,
OBJECT
};
private:
shared_ptr<llvm::Module> module;
public:
CodeGenerator(shared_ptr<llvm::Module> module);
void generateObjectFile(OutputKind outputKind);
};
#endif

View File

@@ -0,0 +1,335 @@
#include "ModuleBuilder.h"
ModuleBuilder::ModuleBuilder(string moduleName, string sourceFileName, vector<shared_ptr<Statement>> statements):
moduleName(moduleName), sourceFileName(sourceFileName), statements(statements) {
context = make_shared<llvm::LLVMContext>();
module = make_shared<llvm::Module>(moduleName, *context);
module->setSourceFileName(sourceFileName);
builder = make_shared<llvm::IRBuilder<>>(*context);
typeVoid = llvm::Type::getVoidTy(*context);
typeBool = llvm::Type::getInt1Ty(*context);
typeSint32 = llvm::Type::getInt32Ty(*context);
typeReal32 = llvm::Type::getFloatTy(*context);
}
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) {
buildStatement(statement);
}
return module;
}
void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
switch (statement->getKind()) {
case StatementKind::FUNCTION_DECLARATION:
buildFunctionDeclaration(dynamic_pointer_cast<StatementFunctionDeclaration>(statement));
break;
case StatementKind::VAR_DECLARATION:
buildVarDeclaration(dynamic_pointer_cast<StatementVarDeclaration>(statement));
break;
case StatementKind::BLOCK:
buildBlock(dynamic_pointer_cast<StatementBlock>(statement));
break;
case StatementKind::RETURN:
buildReturn(dynamic_pointer_cast<StatementReturn>(statement));
break;
case StatementKind::META_EXTERN_FUNCTION:
buildMetaExternFunction(dynamic_pointer_cast<StatementMetaExternFunction>(statement));
break;
case StatementKind::EXPRESSION:
buildExpression(dynamic_pointer_cast<StatementExpression>(statement));
return;
default:
failWithMessage("Unexpected statement");
}
}
void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
// get argument types
vector<llvm::Type *> types;
for (pair<string, ValueType> &arg : statement->getArguments()) {
types.push_back(typeForValueType(arg.second));
}
// build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun;
// define function body
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
// build arguments
int i=0;
for (auto &arg : fun->args()) {
string name = statement->getArguments()[i].first;
llvm::Type *type = types[i];
arg.setName(name);
llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name);
allocaMap[name] = alloca;
builder->CreateStore(&arg, alloca);
i++;
}
// build function body
buildStatement(statement->getStatementBlock());
// verify
string errorMessage;
llvm::raw_string_ostream llvmErrorMessage(errorMessage);
if (llvm::verifyFunction(*fun, &llvmErrorMessage))
failWithMessage(errorMessage);
}
void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVarDeclaration> statement) {
llvm::Value *value = valueForExpression(statement->getExpression());
llvm::AllocaInst *alloca = builder->CreateAlloca(typeForValueType(statement->getValueType()), nullptr, statement->getName());
allocaMap[statement->getName()] = alloca;
builder->CreateStore(value, alloca);
}
void ModuleBuilder::buildBlock(shared_ptr<StatementBlock> statement) {
for (shared_ptr<Statement> &innerStatement : statement->getStatements()) {
buildStatement(innerStatement);
}
if (statement->getStatementExpression() != nullptr)
buildStatement(statement->getStatementExpression());
}
void ModuleBuilder::buildReturn(shared_ptr<StatementReturn> statement) {
if (statement->getExpression() != nullptr) {
llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateRet(value);
} else {
builder->CreateRetVoid();
}
}
void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFunction> statement) {
// get argument types
vector<llvm::Type *> types;
for (pair<string, ValueType> &arg : statement->getArguments()) {
types.push_back(typeForValueType(arg.second));
}
// build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun;
// build arguments
int i=0;
for (auto &arg : fun->args()) {
string name = statement->getArguments()[i].first;
arg.setName(name);
i++;
}
}
void ModuleBuilder::buildExpression(shared_ptr<StatementExpression> statement) {
valueForExpression(statement->getExpression());
}
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
switch (expression->getKind()) {
case ExpressionKind::LITERAL:
return valueForLiteral(dynamic_pointer_cast<ExpressionLiteral>(expression));
case ExpressionKind::GROUPING:
return valueForExpression(dynamic_pointer_cast<ExpressionGrouping>(expression)->getExpression());
case ExpressionKind::BINARY:
return valueForBinary(dynamic_pointer_cast<ExpressionBinary>(expression));
case ExpressionKind::IF_ELSE:
return valueForIfElse(dynamic_pointer_cast<ExpressionIfElse>(expression));
case ExpressionKind::VAR:
return valueForVar(dynamic_pointer_cast<ExpressionVar>(expression));
case ExpressionKind::CALL:
return valueForCall(dynamic_pointer_cast<ExpressionCall>(expression));
default:
failWithMessage("Unexpected expression");
}
}
llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
switch (expression->getValueType()) {
case ValueType::NONE:
return llvm::UndefValue::get(typeVoid);
case ValueType::BOOL:
return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true);
case ValueType::SINT32:
return llvm::ConstantInt::get(typeSint32, expression->getSint32Value(), true);
case ValueType::REAL32:
return llvm::ConstantInt::get(typeReal32, expression->getReal32Value(), true);
}
}
llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr<ExpressionGrouping> expression) {
return valueForExpression(expression->getExpression());
}
llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
llvm::Type *type = leftValue->getType();
if (type == typeBool) {
return valueForBinaryBool(expression->getOperation(), leftValue, rightValue);
} else if (type == typeSint32 || type == typeVoid) {
return valueForBinaryInteger(expression->getOperation(), leftValue, rightValue);
} else if (type == typeReal32) {
return valueForBinaryReal(expression->getOperation(), leftValue, rightValue);
}
failWithMessage("Unexpected operation");
}
llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
return builder->CreateICmpNE(leftValue, rightValue);
default:
failWithMessage("Undefined operation for boolean operands");
}
}
llvm::Value *ModuleBuilder::valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
return builder->CreateICmpNE(leftValue, rightValue);
case ExpressionBinary::Operation::LESS:
return builder->CreateICmpSLT(leftValue, rightValue);
case ExpressionBinary::Operation::LESS_EQUAL:
return builder->CreateICmpSLE(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER:
return builder->CreateICmpSGT(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER_EQUAL:
return builder->CreateICmpSGE(leftValue, rightValue);
case ExpressionBinary::Operation::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case ExpressionBinary::Operation::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case ExpressionBinary::Operation::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case ExpressionBinary::Operation::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case ExpressionBinary::Operation::MOD:
return builder->CreateSRem(leftValue, rightValue);
}
}
llvm::Value *ModuleBuilder::valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateFCmpOEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
return builder->CreateFCmpONE(leftValue, rightValue);
case ExpressionBinary::Operation::LESS:
return builder->CreateFCmpOLT(leftValue, rightValue);
case ExpressionBinary::Operation::LESS_EQUAL:
return builder->CreateFCmpOLE(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER:
return builder->CreateFCmpOGT(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER_EQUAL:
return builder->CreateFCmpOGE(leftValue, rightValue);
case ExpressionBinary::Operation::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case ExpressionBinary::Operation::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case ExpressionBinary::Operation::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case ExpressionBinary::Operation::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case ExpressionBinary::Operation::MOD:
return builder->CreateSRem(leftValue, rightValue);
}
}
llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expression) {
shared_ptr<Expression> conditionExpression = expression->getCondition();
llvm::Function *fun = builder->GetInsertBlock()->getParent();
llvm::Value *conditionValue = valueForExpression(conditionExpression);
llvm::BasicBlock *thenBlock = llvm::BasicBlock::Create(*context, "thenBlock", fun);
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(*context, "elseBlock");
llvm::BasicBlock *mergeBlock = llvm::BasicBlock::Create(*context, "mergeBlock");
int valuesCount = 1;
builder->CreateCondBr(conditionValue, thenBlock, elseBlock);
// Then
builder->SetInsertPoint(thenBlock);
llvm::Value *thenValue = valueForExpression(expression->getThenBlock()->getStatementExpression()->getExpression());
buildStatement(expression->getThenBlock());
builder->CreateBr(mergeBlock);
thenBlock = builder->GetInsertBlock();
// Else
fun->insert(fun->end(), elseBlock);
builder->SetInsertPoint(elseBlock);
llvm::Value *elseValue = nullptr;
if (expression->getElseBlock() != nullptr) {
valuesCount++;
elseValue = valueForExpression(expression->getElseBlock()->getStatementExpression()->getExpression());
buildStatement(expression->getElseBlock());
}
builder->CreateBr(mergeBlock);
elseBlock = builder->GetInsertBlock();
// Merge
fun->insert(fun->end(), mergeBlock);
builder->SetInsertPoint(mergeBlock);
llvm::PHINode *phi = builder->CreatePHI(typeForValueType(expression->getValueType()), valuesCount, "phii");
phi->addIncoming(thenValue, thenBlock);
if (elseValue != nullptr)
phi->addIncoming(elseValue, elseBlock);
return phi;
}
llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVar> expression) {
llvm::AllocaInst *alloca = allocaMap[expression->getName()];
if (alloca == nullptr)
failWithMessage("Variable " + expression->getName() + " not defined");
return builder->CreateLoad(alloca->getAllocatedType(), alloca, expression->getName());
}
llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression) {
llvm::Function *fun = funMap[expression->getName()];
if (fun == nullptr)
failWithMessage("Function " + expression->getName() + " not defined");
llvm::FunctionType *funType = fun->getFunctionType();
vector<llvm::Value*> argValues;
for (shared_ptr<Expression> &argumentExpression : expression->getArgumentExpressions()) {
llvm::Value *argValue = valueForExpression(argumentExpression);
argValues.push_back(argValue);
}
return builder->CreateCall(funType, fun, llvm::ArrayRef(argValues));
}
llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) {
switch (valueType) {
case ValueType::NONE:
return typeVoid;
case ValueType::BOOL:
return typeBool;
case ValueType::SINT32:
return typeSint32;
case ValueType::REAL32:
return typeReal32;
}
}
void ModuleBuilder::failWithMessage(string message) {
cerr << "Error! Building module \"" << moduleName << "\" from \"" + sourceFileName + "\" failed:" << endl << message << endl;
exit(1);
}

View File

@@ -0,0 +1,63 @@
#ifndef MODULE_BUILDER_H
#define MODULE_BUILDER_H
#include <map>
#include <llvm/IR/Module.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Constants.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/IR/Verifier.h>
#include "Parser/Expression.h"
#include "Parser/Statement.h"
using namespace std;
class ModuleBuilder {
private:
string moduleName;
string sourceFileName;
shared_ptr<llvm::LLVMContext> context;
shared_ptr<llvm::Module> module;
shared_ptr<llvm::IRBuilder<>> builder;
llvm::Type *typeVoid;
llvm::Type *typeBool;
llvm::IntegerType *typeSint32;
llvm::Type *typeReal32;
vector<shared_ptr<Statement>> statements;
map<string, llvm::AllocaInst*> allocaMap;
map<string, llvm::Function*> funMap;
void buildStatement(shared_ptr<Statement> statement);
void buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement);
void buildVarDeclaration(shared_ptr<StatementVarDeclaration> statement);
void buildBlock(shared_ptr<StatementBlock> statement);
void buildReturn(shared_ptr<StatementReturn> statement);
void buildMetaExternFunction(shared_ptr<StatementMetaExternFunction> statement);
void buildExpression(shared_ptr<StatementExpression> statement);
llvm::Value *valueForExpression(shared_ptr<Expression> expression);
llvm::Value *valueForLiteral(shared_ptr<ExpressionLiteral> expression);
llvm::Value *valueForGrouping(shared_ptr<ExpressionGrouping> expression);
llvm::Value *valueForBinary(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForIfElse(shared_ptr<ExpressionIfElse> expression);
llvm::Value *valueForVar(shared_ptr<ExpressionVar> expression);
llvm::Value *valueForCall(shared_ptr<ExpressionCall> expression);
llvm::Type *typeForValueType(ValueType valueType);
void failWithMessage(string message);
public:
ModuleBuilder(string moduleName, string sourceFileName, vector<shared_ptr<Statement>> statements);
shared_ptr<llvm::Module> getModule();
};
#endif