Comparisions in modle builder

This commit is contained in:
Rafał Grodziński
2025-06-05 15:26:19 +09:00
parent 7888b94b6a
commit 1591c5927c
4 changed files with 80 additions and 50 deletions

View File

@@ -124,7 +124,7 @@ shared_ptr<Expression> ExpressionGrouping::getExpression() {
} }
string ExpressionGrouping::toString() { string ExpressionGrouping::toString() {
return "<( " + expression->toString() + " )>"; return "( " + expression->toString() + " )";
} }
// //

View File

@@ -1,9 +1,6 @@
#include "ModuleBuilder.h" #include "ModuleBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
/*ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) { ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) {
context = make_shared<llvm::LLVMContext>(); context = make_shared<llvm::LLVMContext>();
module = make_shared<llvm::Module>("dummy", *context); module = make_shared<llvm::Module>("dummy", *context);
builder = make_shared<llvm::IRBuilder<>>(*context); builder = make_shared<llvm::IRBuilder<>>(*context);
@@ -12,40 +9,47 @@
int32Type = llvm::Type::getInt32Ty(*context); int32Type = llvm::Type::getInt32Ty(*context);
} }
void ModuleBuilder::buildCodeForStatement(shared_ptr<Statement> statement) { shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) {
buildStatement(statement);
}
return module;
}
void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
switch (statement->getKind()) { switch (statement->getKind()) {
case Statement::Kind::FUNCTION_DECLARATION: case Statement::Kind::FUNCTION_DECLARATION:
buildFunction(statement); buildFunctionDeclaration(dynamic_pointer_cast<StatementFunctionDeclaration>(statement));
break; break;
case Statement::Kind::BLOCK: case Statement::Kind::BLOCK:
buildBlock(statement); buildBlock(dynamic_pointer_cast<StatementBlock>(statement));
break; break;
case Statement::Kind::RETURN: case Statement::Kind::RETURN:
buildReturn(statement); buildReturn(dynamic_pointer_cast<StatementReturn>(statement));
break; break;
case Statement::Kind::EXPRESSION: case Statement::Kind::EXPRESSION:
buildExpression(statement); buildExpression(dynamic_pointer_cast<StatementExpression>(statement));
break; return;
default: default:
exit(1); exit(1);
} }
} }
void ModuleBuilder::buildFunction(shared_ptr<Statement> statement) { void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false); llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get()); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block); builder->SetInsertPoint(block);
buildCodeForStatement(statement->getBlockStatement()); buildStatement(statement->getStatementBlock());
} }
void ModuleBuilder::buildBlock(shared_ptr<Statement> statement) { void ModuleBuilder::buildBlock(shared_ptr<StatementBlock> statement) {
for (shared_ptr<Statement> &innerStatement : statement->getStatements()) { for (shared_ptr<Statement> &innerStatement : statement->getStatements()) {
buildCodeForStatement(innerStatement); buildStatement(innerStatement);
} }
} }
void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) { void ModuleBuilder::buildReturn(shared_ptr<StatementReturn> statement) {
if (statement->getExpression() != nullptr) { if (statement->getExpression() != nullptr) {
llvm::Value *value = valueForExpression(statement->getExpression()); llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateRet(value); builder->CreateRet(value);
@@ -54,38 +58,57 @@ void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) {
} }
} }
void ModuleBuilder::buildExpression(shared_ptr<Statement> statement) { void ModuleBuilder::buildExpression(shared_ptr<StatementExpression> statement) {
valueForExpression(statement->getExpression());
} }
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) { llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
switch (expression->getKind()) { switch (expression->getKind()) {
case Expression::Kind::LITERAL: case Expression::Kind::LITERAL:
return llvm::ConstantInt::get(int32Type, expression->getInteger(), true); return valueForLiteral(dynamic_pointer_cast<ExpressionLiteral>(expression));
case Expression::Kind::GROUPING: case Expression::Kind::GROUPING:
return valueForExpression(expression->getLeft()); return valueForExpression(dynamic_pointer_cast<ExpressionGrouping>(expression)->getExpression());
case Expression::Kind::BINARY: case Expression::Kind::BINARY:
llvm::Value *leftValue = valueForExpression(expression->getLeft()); return valueForBinary(dynamic_pointer_cast<ExpressionBinary>(expression));
llvm::Value *rightValue = valueForExpression(expression->getRight()); default:
switch (expression->getOperator()) { exit(1);
case Expression::Operator::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case Expression::Operator::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case Expression::Operator::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case Expression::Operator::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case Expression::Operator::MOD:
return builder->CreateSRem(leftValue, rightValue);
}
break;
} }
} }
shared_ptr<llvm::Module> ModuleBuilder::getModule() { llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
for (shared_ptr<Statement> &statement : statements) { return llvm::ConstantInt::get(int32Type, expression->getInteger(), true);
buildCodeForStatement(statement); }
llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr<ExpressionGrouping> expression) {
return valueForExpression(expression->getExpression());
}
llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperation()) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
return builder->CreateICmpNE(leftValue, rightValue);
case ExpressionBinary::Operation::LESS:
return builder->CreateICmpSLT(leftValue, rightValue);
case ExpressionBinary::Operation::LESS_EQUAL:
return builder->CreateICmpSLE(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER:
return builder->CreateICmpSGT(leftValue, rightValue);
case ExpressionBinary::Operation::GREATER_EQUAL:
return builder->CreateICmpSGE(leftValue, rightValue);
case ExpressionBinary::Operation::ADD:
return builder->CreateNSWAdd(leftValue, rightValue);
case ExpressionBinary::Operation::SUB:
return builder->CreateNSWSub(leftValue, rightValue);
case ExpressionBinary::Operation::MUL:
return builder->CreateNSWMul(leftValue, rightValue);
case ExpressionBinary::Operation::DIV:
return builder->CreateSDiv(leftValue, rightValue);
case ExpressionBinary::Operation::MOD:
return builder->CreateSRem(leftValue, rightValue);
} }
return module; }
}*/

View File

@@ -3,6 +3,9 @@
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
#include "Expression.h" #include "Expression.h"
#include "Statement.h" #include "Statement.h"
@@ -10,7 +13,7 @@
using namespace std; using namespace std;
class ModuleBuilder { class ModuleBuilder {
/*private: private:
shared_ptr<llvm::LLVMContext> context; shared_ptr<llvm::LLVMContext> context;
shared_ptr<llvm::Module> module; shared_ptr<llvm::Module> module;
shared_ptr<llvm::IRBuilder<>> builder; shared_ptr<llvm::IRBuilder<>> builder;
@@ -20,16 +23,20 @@ class ModuleBuilder {
vector<shared_ptr<Statement>> statements; vector<shared_ptr<Statement>> statements;
void buildCodeForStatement(shared_ptr<Statement> statement); void buildStatement(shared_ptr<Statement> statement);
void buildFunction(shared_ptr<Statement> statement); void buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement);
void buildBlock(shared_ptr<Statement> statement); void buildBlock(shared_ptr<StatementBlock> statement);
void buildReturn(shared_ptr<Statement> statement); void buildReturn(shared_ptr<StatementReturn> statement);
void buildExpression(shared_ptr<Statement> statement); void buildExpression(shared_ptr<StatementExpression> statement);
llvm::Value *valueForExpression(shared_ptr<Expression> expression); llvm::Value *valueForExpression(shared_ptr<Expression> expression);
llvm::Value *valueForLiteral(shared_ptr<ExpressionLiteral> expression);
llvm::Value *valueForGrouping(shared_ptr<ExpressionGrouping> expression);
llvm::Value *valueForBinary(shared_ptr<ExpressionBinary> expression);
public: public:
ModuleBuilder(vector<shared_ptr<Statement>> statements); ModuleBuilder(vector<shared_ptr<Statement>> statements);
shared_ptr<llvm::Module> getModule();*/ shared_ptr<llvm::Module> getModule();
}; };
#endif #endif

View File

@@ -50,9 +50,9 @@ int main(int argc, char **argv) {
cout << endl; cout << endl;
} }
//ModuleBuilder moduleBuilder(statements); ModuleBuilder moduleBuilder(statements);
//shared_ptr<llvm::Module> module = moduleBuilder.getModule(); shared_ptr<llvm::Module> module = moduleBuilder.getModule();
//module->print(llvm::outs(), nullptr); module->print(llvm::outs(), nullptr);
//CodeGenerator codeGenerator(module); //CodeGenerator codeGenerator(module);
//codeGenerator.generateObjectFile("dummy.s"); //codeGenerator.generateObjectFile("dummy.s");