From 88eccac66772f463bd2396335806935071b5d54d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Grodzi=C5=84ski?= Date: Sun, 8 Jun 2025 12:13:23 +0900 Subject: [PATCH] Pass return value --- src/ModuleBuilder.cpp | 8 ++++---- src/ModuleBuilder.h | 2 +- src/Parser.cpp | 28 +++++++++++++++++----------- src/Statement.cpp | 7 ++++--- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/ModuleBuilder.cpp b/src/ModuleBuilder.cpp index c7f5079..7e399b8 100644 --- a/src/ModuleBuilder.cpp +++ b/src/ModuleBuilder.cpp @@ -38,7 +38,7 @@ void ModuleBuilder::buildStatement(shared_ptr statement) { } void ModuleBuilder::buildFunctionDeclaration(shared_ptr statement) { - llvm::FunctionType *funType = llvm::FunctionType::get(typeSInt32, false); + llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), false); llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get()); llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun); builder->SetInsertPoint(block); @@ -224,7 +224,7 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr expressi // Merge fun->insert(fun->end(), mergeBlock); builder->SetInsertPoint(mergeBlock); - llvm::PHINode *phi = builder->CreatePHI(typeForExpression(expression), valuesCount, "phii"); + llvm::PHINode *phi = builder->CreatePHI(typeForValueType(expression->getValueType()), valuesCount, "phii"); phi->addIncoming(thenValue, thenBlock); if (elseValue != nullptr) phi->addIncoming(elseValue, elseBlock); @@ -233,8 +233,8 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr expressi return phi; } -llvm::Type *ModuleBuilder::typeForExpression(shared_ptr expression) { - switch (expression->getValueType()) { +llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) { + switch (valueType) { case ValueType::VOID: return typeVoid; case ValueType::BOOL: diff --git a/src/ModuleBuilder.h b/src/ModuleBuilder.h index 2feb4f0..bc16b83 100644 --- a/src/ModuleBuilder.h +++ b/src/ModuleBuilder.h @@ -40,7 +40,7 @@ private: llvm::Value *valueForBinaryReal(shared_ptr expression); llvm::Value *valueForIfElse(shared_ptr expression); - llvm::Type *typeForExpression(shared_ptr expression); + llvm::Type *typeForValueType(ValueType valueType); public: ModuleBuilder(vector> statements); diff --git a/src/Parser.cpp b/src/Parser.cpp index a7ad080..fed88a0 100644 --- a/src/Parser.cpp +++ b/src/Parser.cpp @@ -74,14 +74,17 @@ shared_ptr Parser::matchStatementBlock() { else statements.push_back(statement); } - currentIndex++; // consune ';' and ':' - - if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::END})) - return matchStatementInvalid(); - - if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE) + // consune ';' only + if (tokens.at(currentIndex)->getKind() == TokenKind::SEMICOLON) { currentIndex++; + if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::END})) + return matchStatementInvalid(); + + if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE) + currentIndex++; + } + return make_shared(statements); } @@ -95,10 +98,11 @@ shared_ptr Parser::matchStatementReturn() { if (expression != nullptr && !expression->isValid()) return matchStatementInvalid(); - if (tokens.at(currentIndex)->getKind() != TokenKind::NEW_LINE) + if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::SEMICOLON})) return matchStatementInvalid(); - currentIndex++; // new line + if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE) + currentIndex++; // new line return make_shared(expression); } @@ -289,9 +293,11 @@ shared_ptr Parser::matchExpressionIfElse() { // Match else blcok shared_ptr elseBlock; - shared_ptr lastToken = tokens.at(currentIndex-2); - // ':' marks else block - if (lastToken->getKind() == TokenKind::COLON) { + if (tokens.at(currentIndex)->getKind() == TokenKind::COLON) { + currentIndex++; + if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE) + currentIndex++; + elseBlock = matchStatementBlock(); if (elseBlock == nullptr) return matchExpressionInvalid(); diff --git a/src/Statement.cpp b/src/Statement.cpp index 94bc2bc..d4a2cab 100644 --- a/src/Statement.cpp +++ b/src/Statement.cpp @@ -98,9 +98,10 @@ string StatementReturn::toString(int indent) { string value; for (int ind=0; indtoString(0) + ")"; + value += "RETURN:\n"; + for (int ind=0; indtoString(indent+1); value += "\n"; return value; }