Generate code for custom type

This commit is contained in:
Rafał Grodziński
2025-08-05 11:16:50 +09:00
parent dc8d10c81c
commit b6c2ff3983
8 changed files with 113 additions and 22 deletions

View File

@@ -16,6 +16,7 @@
#include "Parser/Statement/StatementFunction.h"
#include "Parser/Statement/StatementRawFunction.h"
#include "Parser/Statement/StatementType.h"
#include "Parser/Statement/StatementVariable.h"
#include "Parser/Statement/StatementAssignment.h"
#include "Parser/Statement/StatementReturn.h"
@@ -68,6 +69,9 @@ void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
case StatementKind::RAW_FUNCTION:
buildRawFunction(dynamic_pointer_cast<StatementRawFunction>(statement));
break;
case StatementKind::TYPE:
buildType(dynamic_pointer_cast<StatementType>(statement));
break;
case StatementKind::VARIABLE:
buildVarDeclaration(dynamic_pointer_cast<StatementVariable>(statement));
break;
@@ -160,6 +164,14 @@ void ModuleBuilder::buildRawFunction(shared_ptr<StatementRawFunction> statement)
return;
}
void ModuleBuilder::buildType(shared_ptr<StatementType> statement) {
llvm::StructType *structType = llvm::StructType::create(*context, statement->getIdentifier());
vector<llvm::Type *> elements;
structType->setBody(elements, false);
if (!setStruct(statement->getIdentifier(), structType))
return;
}
void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVariable> statement) {
if (statement->getValueType()->getKind() == ValueTypeKind::DATA) {
vector<llvm::Value *> values = valuesForExpression(statement->getExpression());
@@ -177,6 +189,11 @@ void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVariable> statement)
builder->CreateStore(values[i], elementPtr);
}
} else if (statement->getValueType()->getKind() == ValueTypeKind::TYPE) {
llvm::StructType *type = (llvm::StructType *)typeForValueType(statement->getValueType(), 0);
llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, statement->getName());
if (!setAlloca(statement->getName(), alloca))
return;
} else {
llvm::Value *value = valueForExpression(statement->getExpression());
if (value == nullptr)
@@ -671,6 +688,29 @@ llvm::InlineAsm *ModuleBuilder::getRawFun(string name) {
return nullptr;
}
bool ModuleBuilder::setStruct(string name, llvm::StructType *structType) {
if (scopes.top().structTypeMap[name] != nullptr) {
markError(0, 0, format("Type \"{}\" already defined in scope", name));
return false;
}
scopes.top().structTypeMap[name] = structType;
return true;
}
llvm::StructType *ModuleBuilder::getStructType(string name) {
stack<Scope> scopes = this->scopes;
while (!scopes.empty()) {
llvm::StructType *structType = scopes.top().structTypeMap[name];
if (structType != nullptr)
return structType;
scopes.pop();
}
return nullptr;
}
llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType, int count) {
if (valueType == nullptr) {
markError(0, 0, "Missing type");
@@ -699,6 +739,8 @@ llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType, int
count = valueType->getValueArg();
return llvm::ArrayType::get(typeForValueType(valueType->getSubType(), count), count);
}
case ValueTypeKind::TYPE:
return getStructType(valueType->getTypeName());
}
}

View File

@@ -29,6 +29,7 @@ enum class ExpressionBinaryOperation;
class Statement;
class StatementFunction;
class StatementRawFunction;
class StatementType;
class StatementVariable;
class StatementAssignment;
class StatementReturn;
@@ -43,6 +44,7 @@ typedef struct {
map<string, llvm::AllocaInst*> allocaMap;
map<string, llvm::Function*> funMap;
map<string, llvm::InlineAsm*> rawFunMap;
map<string, llvm::StructType*> structTypeMap;
} Scope;
class ModuleBuilder {
@@ -69,6 +71,7 @@ private:
void buildStatement(shared_ptr<Statement> statement);
void buildFunction(shared_ptr<StatementFunction> statement);
void buildRawFunction(shared_ptr<StatementRawFunction> statement);
void buildType(shared_ptr<StatementType> statement);
void buildVarDeclaration(shared_ptr<StatementVariable> statement);
void buildAssignment(shared_ptr<StatementAssignment> statement);
void buildBlock(shared_ptr<StatementBlock> statement);
@@ -101,6 +104,9 @@ private:
bool setRawFun(string name, llvm::InlineAsm *rawFun);
llvm::InlineAsm *getRawFun(string name);
bool setStruct(string name, llvm::StructType *structType);
llvm::StructType *getStructType(string name);
llvm::Type *typeForValueType(shared_ptr<ValueType> valueType, int count = 0);
void markError(int line, int column, string message);