Read return types

This commit is contained in:
Rafał Grodziński
2025-06-16 18:18:45 +09:00
parent 7397183c34
commit 5da89c2e23
5 changed files with 58 additions and 40 deletions

View File

@@ -147,7 +147,7 @@ bool ExpressionLiteral::getBoolValue() {
return boolValue;
}
int32_t ExpressionLiteral::getSInt32Value() {
int32_t ExpressionLiteral::getSint32Value() {
return sint32Value;
}

View File

@@ -38,7 +38,7 @@ private:
public:
ExpressionLiteral(shared_ptr<Token> token);
bool getBoolValue();
int32_t getSInt32Value();
int32_t getSint32Value();
float getReal32Value();
string toString(int indent) override;
};

View File

@@ -9,7 +9,7 @@ moduleName(moduleName), sourceFileName(sourceFileName), statements(statements) {
typeVoid = llvm::Type::getVoidTy(*context);
typeBool = llvm::Type::getInt1Ty(*context);
typeSInt32 = llvm::Type::getInt32Ty(*context);
typeSint32 = llvm::Type::getInt32Ty(*context);
typeReal32 = llvm::Type::getFloatTy(*context);
}
@@ -43,11 +43,36 @@ void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
}
void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), false);
// get argument types
vector<llvm::Type *> types;
for (pair<string, ValueType> &arg : statement->getArguments()) {
types.push_back(typeForValueType(arg.second));
}
// build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun;
// define function body
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
// build arguments
int i=0;
for (auto &arg : fun->args()) {
string name = statement->getArguments()[i].first;
llvm::Type *type = types[i];
arg.setName(name);
llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name);
allocaMap[name] = alloca;
builder->CreateStore(&arg, alloca);
i++;
}
// build function body
buildStatement(statement->getStatementBlock());
}
@@ -105,7 +130,7 @@ llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expres
case ValueType::BOOL:
return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true);
case ValueType::SINT32:
return llvm::ConstantInt::get(typeSInt32, expression->getSInt32Value(), true);
return llvm::ConstantInt::get(typeSint32, expression->getSint32Value(), true);
case ValueType::REAL32:
return llvm::ConstantInt::get(typeReal32, expression->getReal32Value(), true);
}
@@ -116,25 +141,24 @@ llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr<ExpressionGrouping> expr
}
llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expression) {
switch (expression->getLeft()->getValueType()) {
case ValueType::BOOL:
return valueForBinaryBool(expression);
case ValueType::SINT32:
return valueForBinaryInteger(expression);
case ValueType::REAL32:
return valueForBinaryReal(expression);
case ValueType::NONE:
return valueForBinaryInteger(expression);
default:
failed("Unexpected operation");
}
}
llvm::Value *ModuleBuilder::valueForBinaryBool(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperation()) {
llvm::Type *type = leftValue->getType();
if (type == typeBool) {
return valueForBinaryBool(expression->getOperation(), leftValue, rightValue);
} else if (type == typeSint32 || type == typeVoid) {
return valueForBinaryInteger(expression->getOperation(), leftValue, rightValue);
} else if (type == typeReal32) {
return valueForBinaryReal(expression->getOperation(), leftValue, rightValue);
}
failed("Unexpected operation");
}
llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
@@ -144,11 +168,8 @@ llvm::Value *ModuleBuilder::valueForBinaryBool(shared_ptr<ExpressionBinary> expr
}
}
llvm::Value *ModuleBuilder::valueForBinaryInteger(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperation()) {
llvm::Value *ModuleBuilder::valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateICmpEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
@@ -174,11 +195,8 @@ llvm::Value *ModuleBuilder::valueForBinaryInteger(shared_ptr<ExpressionBinary> e
}
}
llvm::Value *ModuleBuilder::valueForBinaryReal(shared_ptr<ExpressionBinary> expression) {
llvm::Value *leftValue = valueForExpression(expression->getLeft());
llvm::Value *rightValue = valueForExpression(expression->getRight());
switch (expression->getOperation()) {
llvm::Value *ModuleBuilder::valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
switch (operation) {
case ExpressionBinary::Operation::EQUAL:
return builder->CreateFCmpOEQ(leftValue, rightValue);
case ExpressionBinary::Operation::NOT_EQUAL:
@@ -257,7 +275,8 @@ llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVar> expression) {
llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression) {
llvm::Function *fun = funMap[expression->getName()];
failed("Function " + expression->getName() + " not defined");
if (fun == nullptr)
failed("Function " + expression->getName() + " not defined");
llvm::FunctionType *funType = fun->getFunctionType();
vector<llvm::Value*> argValues;
for (shared_ptr<Expression> &argumentExpression : expression->getArgumentExpressions()) {
@@ -274,7 +293,7 @@ llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) {
case ValueType::BOOL:
return typeBool;
case ValueType::SINT32:
return typeSInt32;
return typeSint32;
case ValueType::REAL32:
return typeReal32;
}

View File

@@ -25,7 +25,7 @@ private:
llvm::Type *typeVoid;
llvm::Type *typeBool;
llvm::IntegerType *typeSInt32;
llvm::IntegerType *typeSint32;
llvm::Type *typeReal32;
vector<shared_ptr<Statement>> statements;
@@ -43,9 +43,9 @@ private:
llvm::Value *valueForLiteral(shared_ptr<ExpressionLiteral> expression);
llvm::Value *valueForGrouping(shared_ptr<ExpressionGrouping> expression);
llvm::Value *valueForBinary(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForBinaryBool(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForBinaryInteger(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForBinaryReal(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue);
llvm::Value *valueForIfElse(shared_ptr<ExpressionIfElse> expression);
llvm::Value *valueForVar(shared_ptr<ExpressionVar> expression);
llvm::Value *valueForCall(shared_ptr<ExpressionCall> expression);

View File

@@ -103,12 +103,11 @@ shared_ptr<Statement> Parser::matchStatementFunctionDeclaration() {
}
shared_ptr<Statement> Parser::matchStatementVarDeclaration() {
if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::COLON, TokenKind::TYPE}, true, false))
if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::TYPE}, true, false))
return nullptr;
shared_ptr<Token> identifierToken = tokens.at(currentIndex);
currentIndex++;
currentIndex++; // skip colon
currentIndex++; // identifier
shared_ptr<Token> valueTypeToken = tokens.at(currentIndex);
ValueType valueType;