Comparisions in modle builder

This commit is contained in:
Rafał Grodziński
2025-06-05 15:26:19 +09:00
parent 7888b94b6a
commit 1591c5927c
4 changed files with 80 additions and 50 deletions

View File

@@ -124,7 +124,7 @@ shared_ptr<Expression> ExpressionGrouping::getExpression() {
}
string ExpressionGrouping::toString() {
return "<( " + expression->toString() + " )>";
return "( " + expression->toString() + " )";
}
//

View File

@@ -1,9 +1,6 @@
#include "ModuleBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
/*ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) {
ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) {
context = make_shared<llvm::LLVMContext>();
module = make_shared<llvm::Module>("dummy", *context);
builder = make_shared<llvm::IRBuilder<>>(*context);
@@ -12,40 +9,47 @@
int32Type = llvm::Type::getInt32Ty(*context);
}
void ModuleBuilder::buildCodeForStatement(shared_ptr<Statement> statement) {
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) {
buildStatement(statement);
}
return module;
}
void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
switch (statement->getKind()) {
case Statement::Kind::FUNCTION_DECLARATION:
buildFunction(statement);
buildFunctionDeclaration(dynamic_pointer_cast<StatementFunctionDeclaration>(statement));
break;
case Statement::Kind::BLOCK:
buildBlock(statement);
buildBlock(dynamic_pointer_cast<StatementBlock>(statement));
break;
case Statement::Kind::RETURN:
buildReturn(statement);
buildReturn(dynamic_pointer_cast<StatementReturn>(statement));
break;
case Statement::Kind::EXPRESSION:
buildExpression(statement);
break;
buildExpression(dynamic_pointer_cast<StatementExpression>(statement));
return;
default:
exit(1);
}
}
void ModuleBuilder::buildFunction(shared_ptr<Statement> statement) {
void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
buildCodeForStatement(statement->getBlockStatement());
buildStatement(statement->getStatementBlock());
}
void ModuleBuilder::buildBlock(shared_ptr<Statement> statement) {
void ModuleBuilder::buildBlock(shared_ptr<StatementBlock> statement) {
for (shared_ptr<Statement> &innerStatement : statement->getStatements()) {
buildCodeForStatement(innerStatement);
buildStatement(innerStatement);
}
}
void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) {
void ModuleBuilder::buildReturn(shared_ptr<StatementReturn> statement) {
if (statement->getExpression() != nullptr) {
llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateRet(value);
@@ -54,38 +58,57 @@ void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) {
}
}
void ModuleBuilder::buildExpression(shared_ptr<Statement> statement) {
void ModuleBuilder::buildExpression(shared_ptr<StatementExpression> statement) {
valueForExpression(statement->getExpression());
}
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
switch (expression->getKind()) {
case Expression::Kind::LITERAL:
return llvm::ConstantInt::get(int32Type, expression->getInteger(), true);
return valueForLiteral(dynamic_pointer_cast<ExpressionLiteral>(expression));
case Expression::Kind::GROUPING:
return valueForExpression(expression->getLeft());
return valueForExpression(dynamic_pointer_cast<ExpressionGrouping>(expression)->getExpression());
case Expression::Kind::BINARY:
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperator()) {
case Expression::Operator::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case Expression::Operator::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case Expression::Operator::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case Expression::Operator::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case Expression::Operator::MOD:
return builder->CreateSRem(leftValue, rightValue);
}
break;
return valueForBinary(dynamic_pointer_cast<ExpressionBinary>(expression));
default:
exit(1);
}
}
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) {
buildCodeForStatement(statement);
llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
return llvm::ConstantInt::get(int32Type, expression->getInteger(), true);
}
llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr<ExpressionGrouping> expression) {
return valueForExpression(expression->getExpression());
}
llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperation()) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
return builder->CreateICmpNE(leftValue, rightValue);
case ExpressionBinary::Operation::LESS:
return builder->CreateICmpSLT(leftValue, rightValue);
case ExpressionBinary::Operation::LESS_EQUAL:
return builder->CreateICmpSLE(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER:
return builder->CreateICmpSGT(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER_EQUAL:
return builder->CreateICmpSGE(leftValue, rightValue);
case ExpressionBinary::Operation::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case ExpressionBinary::Operation::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case ExpressionBinary::Operation::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case ExpressionBinary::Operation::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case ExpressionBinary::Operation::MOD:
return builder->CreateSRem(leftValue, rightValue);
}
return module;
}*/
}

View File

@@ -3,6 +3,9 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
#include "Expression.h"
#include "Statement.h"
@@ -10,7 +13,7 @@
using namespace std;
class ModuleBuilder {
/*private:
private:
shared_ptr<llvm::LLVMContext> context;
shared_ptr<llvm::Module> module;
shared_ptr<llvm::IRBuilder<>> builder;
@@ -20,16 +23,20 @@ class ModuleBuilder {
vector<shared_ptr<Statement>> statements;
void buildCodeForStatement(shared_ptr<Statement> statement);
void buildFunction(shared_ptr<Statement> statement);
void buildBlock(shared_ptr<Statement> statement);
void buildReturn(shared_ptr<Statement> statement);
void buildExpression(shared_ptr<Statement> statement);
void buildStatement(shared_ptr<Statement> statement);
void buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement);
void buildBlock(shared_ptr<StatementBlock> statement);
void buildReturn(shared_ptr<StatementReturn> statement);
void buildExpression(shared_ptr<StatementExpression> statement);
llvm::Value *valueForExpression(shared_ptr<Expression> expression);
llvm::Value *valueForLiteral(shared_ptr<ExpressionLiteral> expression);
llvm::Value *valueForGrouping(shared_ptr<ExpressionGrouping> expression);
llvm::Value *valueForBinary(shared_ptr<ExpressionBinary> expression);
public:
ModuleBuilder(vector<shared_ptr<Statement>> statements);
shared_ptr<llvm::Module> getModule();*/
shared_ptr<llvm::Module> getModule();
};
#endif

View File

@@ -50,9 +50,9 @@ int main(int argc, char **argv) {
cout << endl;
}
//ModuleBuilder moduleBuilder(statements);
//shared_ptr<llvm::Module> module = moduleBuilder.getModule();
//module->print(llvm::outs(), nullptr);
ModuleBuilder moduleBuilder(statements);
shared_ptr<llvm::Module> module = moduleBuilder.getModule();
module->print(llvm::outs(), nullptr);
//CodeGenerator codeGenerator(module);
//codeGenerator.generateObjectFile("dummy.s");