Generate functions

This commit is contained in:
Rafał Grodziński
2025-06-02 10:45:44 +09:00
parent 2cecb456bb
commit 2ef888e374
7 changed files with 92 additions and 31 deletions

View File

@@ -1,10 +1,10 @@
#ifndef LEXER_H
#define LEXER_H
#include "Token.h"
#include <vector>
#include "Token.h"
using namespace std;
class Lexer {

View File

@@ -3,7 +3,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/Support/raw_ostream.h"
ModuleBuilder::ModuleBuilder(shared_ptr<Expression> expression): expression(expression) {
ModuleBuilder::ModuleBuilder(vector<shared_ptr<Statement>> statements): statements(statements) {
context = make_shared<llvm::LLVMContext>();
module = make_shared<llvm::Module>("dummy", *context);
builder = make_shared<llvm::IRBuilder<>>(*context);
@@ -12,6 +12,52 @@ ModuleBuilder::ModuleBuilder(shared_ptr<Expression> expression): expression(expr
int32Type = llvm::Type::getInt32Ty(*context);
}
void ModuleBuilder::buildCodeForStatement(shared_ptr<Statement> statement) {
switch (statement->getKind()) {
case Statement::Kind::FUNCTION_DECLARATION:
buildFunction(statement);
break;
case Statement::Kind::BLOCK:
buildBlock(statement);
break;
case Statement::Kind::RETURN:
buildReturn(statement);
break;
case Statement::Kind::EXPRESSION:
buildExpression(statement);
break;
default:
exit(1);
}
}
void ModuleBuilder::buildFunction(shared_ptr<Statement> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(int32Type, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
buildCodeForStatement(statement->getBlockStatement());
}
void ModuleBuilder::buildBlock(shared_ptr<Statement> statement) {
for (shared_ptr<Statement> &innerStatement : statement->getStatements()) {
buildCodeForStatement(innerStatement);
}
}
void ModuleBuilder::buildReturn(shared_ptr<Statement> statement) {
if (statement->getExpression() != nullptr) {
llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateRet(value);
} else {
builder->CreateRetVoid();
}
}
void ModuleBuilder::buildExpression(shared_ptr<Statement> statement) {
}
llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression) {
switch (expression->getKind()) {
case Expression::Kind::LITERAL:
@@ -38,18 +84,8 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression
}
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
//llvm::Value *value = valueForExpression(expression);
llvm::FunctionType *fType = llvm::FunctionType::get(int32Type, false);
llvm::Function *f = llvm::Function::Create(fType, llvm::GlobalValue::InternalLinkage, "dummyFunc", module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, "entry", f);
builder->SetInsertPoint(block);
llvm::Value *value = valueForExpression(expression);
//builder->CreateRetVoid();
builder->CreateRet(value);
//value->print(llvm::outs(), false);
//cout << endl;
for (shared_ptr<Statement> &statement : statements) {
buildCodeForStatement(statement);
}
return module;
}

View File

@@ -3,7 +3,9 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "Expression.h"
#include "Statement.h"
using namespace std;
@@ -16,11 +18,17 @@ private:
llvm::Type *voidType;
llvm::IntegerType *int32Type;
shared_ptr<Expression> expression;
vector<shared_ptr<Statement>> statements;
void buildCodeForStatement(shared_ptr<Statement> statement);
void buildFunction(shared_ptr<Statement> statement);
void buildBlock(shared_ptr<Statement> statement);
void buildReturn(shared_ptr<Statement> statement);
void buildExpression(shared_ptr<Statement> statement);
llvm::Value *valueForExpression(shared_ptr<Expression> expression);
public:
ModuleBuilder(shared_ptr<Expression> expression);
ModuleBuilder(vector<shared_ptr<Statement>> statements);
shared_ptr<llvm::Module> getModule();
};

View File

@@ -2,6 +2,7 @@
#define PARSER_H
#include <vector>
#include "Token.h"
#include "Expression.h"
#include "Statement.h"

View File

@@ -4,14 +4,30 @@ Statement::Statement(Kind kind, shared_ptr<Token> token, shared_ptr<Expression>
kind(kind), token(token), expression(expression), blockStatement(blockStatement), statements(statements), name(name) {
}
shared_ptr<Expression> Statement::getExpression() {
return expression;
Statement::Kind Statement::getKind() {
return kind;
}
shared_ptr<Token> Statement::getToken() {
return token;
}
shared_ptr<Expression> Statement::getExpression() {
return expression;
}
shared_ptr<Statement> Statement::getBlockStatement() {
return blockStatement;
}
vector<shared_ptr<Statement>> Statement::getStatements() {
return statements;
}
string Statement::getName() {
return name;
}
bool Statement::isValid() {
return kind != Statement::Kind::INVALID;
}

View File

@@ -11,10 +11,10 @@ using namespace std;
class Statement {
public:
enum Kind {
EXPRESSION,
BLOCK,
FUNCTION_DECLARATION,
BLOCK,
RETURN,
EXPRESSION,
INVALID
};
@@ -28,8 +28,12 @@ private:
public:
Statement(Kind kind, shared_ptr<Token> token, shared_ptr<Expression> expression, shared_ptr<Statement> blockStatement, vector<shared_ptr<Statement>> statements, string name);
Kind getKind();
shared_ptr<Token> getToken();
shared_ptr<Expression> getExpression();
shared_ptr<Statement> getBlockStatement();
vector<shared_ptr<Statement>> getStatements();
string getName();
bool isValid();
string toString();
};

View File

@@ -52,18 +52,14 @@ int main(int argc, char **argv) {
cout << statement->toString();
cout << endl;
}
//shared_ptr<Expression> expression = parser.getExpression();
//if (!expression) {
// exit(1);
//}
//cout << expression->toString() << endl;
//ModuleBuilder moduleBuilder(expression);
//shared_ptr<llvm::Module> module = moduleBuilder.getModule();
//module->print(llvm::outs(), nullptr);
ModuleBuilder moduleBuilder(statements);
shared_ptr<llvm::Module> module = moduleBuilder.getModule();
module->print(llvm::outs(), nullptr);
//CodeGenerator codeGenerator(module);;
//codeGenerator.generateObjectFile("dummy.s");
CodeGenerator codeGenerator(module);
codeGenerator.generateObjectFile("dummy.s");
return 0;
}