Pass return value

This commit is contained in:
Rafał Grodziński
2025-06-08 12:13:23 +09:00
parent 53c5e2c22e
commit 88eccac667
4 changed files with 26 additions and 19 deletions

View File

@@ -38,7 +38,7 @@ void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
}
void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunctionDeclaration> statement) {
llvm::FunctionType *funType = llvm::FunctionType::get(typeSInt32, false);
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::InternalLinkage, statement->getName(), module.get());
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
@@ -224,7 +224,7 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
// Merge
fun->insert(fun->end(), mergeBlock);
builder->SetInsertPoint(mergeBlock);
llvm::PHINode *phi = builder->CreatePHI(typeForExpression(expression), valuesCount, "phii");
llvm::PHINode *phi = builder->CreatePHI(typeForValueType(expression->getValueType()), valuesCount, "phii");
phi->addIncoming(thenValue, thenBlock);
if (elseValue != nullptr)
phi->addIncoming(elseValue, elseBlock);
@@ -233,8 +233,8 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
return phi;
}
llvm::Type *ModuleBuilder::typeForExpression(shared_ptr<Expression> expression) {
switch (expression->getValueType()) {
llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) {
switch (valueType) {
case ValueType::VOID:
return typeVoid;
case ValueType::BOOL:

View File

@@ -40,7 +40,7 @@ private:
llvm::Value *valueForBinaryReal(shared_ptr<ExpressionBinary> expression);
llvm::Value *valueForIfElse(shared_ptr<ExpressionIfElse> expression);
llvm::Type *typeForExpression(shared_ptr<Expression> expression);
llvm::Type *typeForValueType(ValueType valueType);
public:
ModuleBuilder(vector<shared_ptr<Statement>> statements);

View File

@@ -74,14 +74,17 @@ shared_ptr<Statement> Parser::matchStatementBlock() {
else
statements.push_back(statement);
}
currentIndex++; // consune ';' and ':'
if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::END}))
return matchStatementInvalid();
if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE)
// consune ';' only
if (tokens.at(currentIndex)->getKind() == TokenKind::SEMICOLON) {
currentIndex++;
if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::END}))
return matchStatementInvalid();
if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE)
currentIndex++;
}
return make_shared<StatementBlock>(statements);
}
@@ -95,10 +98,11 @@ shared_ptr<Statement> Parser::matchStatementReturn() {
if (expression != nullptr && !expression->isValid())
return matchStatementInvalid();
if (tokens.at(currentIndex)->getKind() != TokenKind::NEW_LINE)
if (!tokens.at(currentIndex)->isOfKind({TokenKind::NEW_LINE, TokenKind::SEMICOLON}))
return matchStatementInvalid();
currentIndex++; // new line
if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE)
currentIndex++; // new line
return make_shared<StatementReturn>(expression);
}
@@ -289,9 +293,11 @@ shared_ptr<Expression> Parser::matchExpressionIfElse() {
// Match else blcok
shared_ptr<Statement> elseBlock;
shared_ptr<Token> lastToken = tokens.at(currentIndex-2);
// ':' marks else block
if (lastToken->getKind() == TokenKind::COLON) {
if (tokens.at(currentIndex)->getKind() == TokenKind::COLON) {
currentIndex++;
if (tokens.at(currentIndex)->getKind() == TokenKind::NEW_LINE)
currentIndex++;
elseBlock = matchStatementBlock();
if (elseBlock == nullptr)
return matchExpressionInvalid();

View File

@@ -98,9 +98,10 @@ string StatementReturn::toString(int indent) {
string value;
for (int ind=0; ind<indent; ind++)
value += " ";
value += "RETURN";
if (expression != nullptr)
value += "(" + expression->toString(0) + ")";
value += "RETURN:\n";
for (int ind=0; ind<indent+1; ind++)
value += " ";
value += expression->toString(indent+1);
value += "\n";
return value;
}