From 5da89c2e23867597435867db834f475f3fbb51ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Grodzi=C5=84ski?= Date: Mon, 16 Jun 2025 18:18:45 +0900 Subject: [PATCH] Read return types --- src/Expression.cpp | 2 +- src/Expression.h | 2 +- src/ModuleBuilder.cpp | 81 ++++++++++++++++++++++++++----------------- src/ModuleBuilder.h | 8 ++--- src/Parser.cpp | 5 ++- 5 files changed, 58 insertions(+), 40 deletions(-) diff --git a/src/Expression.cpp b/src/Expression.cpp index b85d8df..d5fd09d 100644 --- a/src/Expression.cpp +++ b/src/Expression.cpp @@ -147,7 +147,7 @@ bool ExpressionLiteral::getBoolValue() { return boolValue; } -int32_t ExpressionLiteral::getSInt32Value() { +int32_t ExpressionLiteral::getSint32Value() { return sint32Value; } diff --git a/src/Expression.h b/src/Expression.h index 6f3ec75..4fa3afe 100644 --- a/src/Expression.h +++ b/src/Expression.h @@ -38,7 +38,7 @@ private: public: ExpressionLiteral(shared_ptr token); bool getBoolValue(); - int32_t getSInt32Value(); + int32_t getSint32Value(); float getReal32Value(); string toString(int indent) override; }; diff --git a/src/ModuleBuilder.cpp b/src/ModuleBuilder.cpp index 9f26312..1cf2b26 100644 --- a/src/ModuleBuilder.cpp +++ b/src/ModuleBuilder.cpp @@ -9,7 +9,7 @@ moduleName(moduleName), sourceFileName(sourceFileName), statements(statements) { typeVoid = llvm::Type::getVoidTy(*context); typeBool = llvm::Type::getInt1Ty(*context); - typeSInt32 = llvm::Type::getInt32Ty(*context); + typeSint32 = llvm::Type::getInt32Ty(*context); typeReal32 = llvm::Type::getFloatTy(*context); } @@ -43,11 +43,36 @@ void ModuleBuilder::buildStatement(shared_ptr statement) { } void ModuleBuilder::buildFunctionDeclaration(shared_ptr statement) { - llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), false); + // get argument types + vector types; + for (pair &arg : statement->getArguments()) { + types.push_back(typeForValueType(arg.second)); + } + + // build function declaration + llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get()); funMap[statement->getName()] = fun; + + // define function body llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); builder->SetInsertPoint(block); + + // build arguments + int i=0; + for (auto &arg : fun->args()) { + string name = statement->getArguments()[i].first; + llvm::Type *type = types[i]; + arg.setName(name); + + llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name); + allocaMap[name] = alloca; + builder->CreateStore(&arg, alloca); + + i++; + } + + // build function body buildStatement(statement->getStatementBlock()); } @@ -105,7 +130,7 @@ llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr expres case ValueType::BOOL: return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true); case ValueType::SINT32: - return llvm::ConstantInt::get(typeSInt32, expression->getSInt32Value(), true); + return llvm::ConstantInt::get(typeSint32, expression->getSint32Value(), true); case ValueType::REAL32: return llvm::ConstantInt::get(typeReal32, expression->getReal32Value(), true); } @@ -116,25 +141,24 @@ llvm::Value *ModuleBuilder::valueForGrouping(shared_ptr expr } llvm::Value *ModuleBuilder::valueForBinary(shared_ptr expression) { - switch (expression->getLeft()->getValueType()) { - case ValueType::BOOL: - return valueForBinaryBool(expression); - case ValueType::SINT32: - return valueForBinaryInteger(expression); - case ValueType::REAL32: - return valueForBinaryReal(expression); - case ValueType::NONE: - return valueForBinaryInteger(expression); - default: - failed("Unexpected operation"); - } -} - -llvm::Value *ModuleBuilder::valueForBinaryBool(shared_ptr expression) { llvm::Value *leftValue = valueForExpression(expression->getLeft()); llvm::Value *rightValue = valueForExpression(expression->getRight()); - switch (expression->getOperation()) { + llvm::Type *type = leftValue->getType(); + + if (type == typeBool) { + return valueForBinaryBool(expression->getOperation(), leftValue, rightValue); + } else if (type == typeSint32 || type == typeVoid) { + return valueForBinaryInteger(expression->getOperation(), leftValue, rightValue); + } else if (type == typeReal32) { + return valueForBinaryReal(expression->getOperation(), leftValue, rightValue); + } + + failed("Unexpected operation"); +} + +llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) { + switch (operation) { case ExpressionBinary::Operation::EQUAL: return builder->CreateICmpEQ(leftValue, rightValue); case ExpressionBinary::Operation::NOT_EQUAL: @@ -144,11 +168,8 @@ llvm::Value *ModuleBuilder::valueForBinaryBool(shared_ptr expr } } -llvm::Value *ModuleBuilder::valueForBinaryInteger(shared_ptr expression) { - llvm::Value *leftValue = valueForExpression(expression->getLeft()); - llvm::Value *rightValue = valueForExpression(expression->getRight()); - - switch (expression->getOperation()) { +llvm::Value *ModuleBuilder::valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) { + switch (operation) { case ExpressionBinary::Operation::EQUAL: return builder->CreateICmpEQ(leftValue, rightValue); case ExpressionBinary::Operation::NOT_EQUAL: @@ -174,11 +195,8 @@ llvm::Value *ModuleBuilder::valueForBinaryInteger(shared_ptr e } } -llvm::Value *ModuleBuilder::valueForBinaryReal(shared_ptr expression) { - llvm::Value *leftValue = valueForExpression(expression->getLeft()); - llvm::Value *rightValue = valueForExpression(expression->getRight()); - - switch (expression->getOperation()) { +llvm::Value *ModuleBuilder::valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue) { + switch (operation) { case ExpressionBinary::Operation::EQUAL: return builder->CreateFCmpOEQ(leftValue, rightValue); case ExpressionBinary::Operation::NOT_EQUAL: @@ -257,7 +275,8 @@ llvm::Value *ModuleBuilder::valueForVar(shared_ptr expression) { llvm::Value *ModuleBuilder::valueForCall(shared_ptr expression) { llvm::Function *fun = funMap[expression->getName()]; - failed("Function " + expression->getName() + " not defined"); + if (fun == nullptr) + failed("Function " + expression->getName() + " not defined"); llvm::FunctionType *funType = fun->getFunctionType(); vector argValues; for (shared_ptr &argumentExpression : expression->getArgumentExpressions()) { @@ -274,7 +293,7 @@ llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) { case ValueType::BOOL: return typeBool; case ValueType::SINT32: - return typeSInt32; + return typeSint32; case ValueType::REAL32: return typeReal32; } diff --git a/src/ModuleBuilder.h b/src/ModuleBuilder.h index d711088..ef993be 100644 --- a/src/ModuleBuilder.h +++ b/src/ModuleBuilder.h @@ -25,7 +25,7 @@ private: llvm::Type *typeVoid; llvm::Type *typeBool; - llvm::IntegerType *typeSInt32; + llvm::IntegerType *typeSint32; llvm::Type *typeReal32; vector> statements; @@ -43,9 +43,9 @@ private: llvm::Value *valueForLiteral(shared_ptr expression); llvm::Value *valueForGrouping(shared_ptr expression); llvm::Value *valueForBinary(shared_ptr expression); - llvm::Value *valueForBinaryBool(shared_ptr expression); - llvm::Value *valueForBinaryInteger(shared_ptr expression); - llvm::Value *valueForBinaryReal(shared_ptr expression); + llvm::Value *valueForBinaryBool(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue); + llvm::Value *valueForBinaryInteger(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue); + llvm::Value *valueForBinaryReal(ExpressionBinary::Operation operation, llvm::Value *leftValue, llvm::Value *rightValue); llvm::Value *valueForIfElse(shared_ptr expression); llvm::Value *valueForVar(shared_ptr expression); llvm::Value *valueForCall(shared_ptr expression); diff --git a/src/Parser.cpp b/src/Parser.cpp index e6f8492..30fc57b 100644 --- a/src/Parser.cpp +++ b/src/Parser.cpp @@ -103,12 +103,11 @@ shared_ptr Parser::matchStatementFunctionDeclaration() { } shared_ptr Parser::matchStatementVarDeclaration() { - if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::COLON, TokenKind::TYPE}, true, false)) + if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::TYPE}, true, false)) return nullptr; shared_ptr identifierToken = tokens.at(currentIndex); - currentIndex++; - currentIndex++; // skip colon + currentIndex++; // identifier shared_ptr valueTypeToken = tokens.at(currentIndex); ValueType valueType;