From 20a3adcca2af562a5bc9653c3f233fbd4f31f90a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Grodzi=C5=84ski?= Date: Mon, 16 Jun 2025 11:29:30 +0900 Subject: [PATCH] Parse function arguments and return type --- src/Expression.cpp | 16 +++++----- src/Lexer.cpp | 44 ++++++++++---------------- src/Lexer.h | 2 +- src/ModuleBuilder.cpp | 4 +-- src/Parser.cpp | 73 +++++++++++++++++++++++++++++++------------ src/Parser.h | 1 + src/Statement.cpp | 10 ++++-- src/Statement.h | 4 ++- src/Token.cpp | 2 ++ src/Types.h | 3 +- 10 files changed, 96 insertions(+), 63 deletions(-) diff --git a/src/Expression.cpp b/src/Expression.cpp index d22840f..9b0ea31 100644 --- a/src/Expression.cpp +++ b/src/Expression.cpp @@ -23,7 +23,7 @@ string Expression::toString(int indent) { // // ExpressionBinary ExpressionBinary::ExpressionBinary(shared_ptr token, shared_ptr left, shared_ptr right): -Expression(ExpressionKind::BINARY, ValueType::VOID), left(left), right(right) { +Expression(ExpressionKind::BINARY, ValueType::NONE), left(left), right(right) { // Types must match if (left->getValueType() != right->getValueType()) exit(1); @@ -124,7 +124,7 @@ string ExpressionBinary::toString(int indent) { // // ExpressionLiteral ExpressionLiteral::ExpressionLiteral(shared_ptr token): -Expression(ExpressionKind::LITERAL, ValueType::VOID) { +Expression(ExpressionKind::LITERAL, ValueType::NONE) { switch (token->getKind()) { case TokenKind::BOOL: boolValue = token->getLexme().compare("true") == 0; @@ -157,8 +157,8 @@ float ExpressionLiteral::getReal32Value() { string ExpressionLiteral::toString(int indent) { switch (valueType) { - case ValueType::VOID: - return "VOID"; + case ValueType::NONE: + return "NONE"; case ValueType::BOOL: return boolValue ? "true" : "false"; case ValueType::SINT32: @@ -185,7 +185,7 @@ string ExpressionGrouping::toString(int indent) { // // ExpressionIfElse ExpressionIfElse::ExpressionIfElse(shared_ptr condition, shared_ptr thenBlock, shared_ptr elseBlock): -Expression(ExpressionKind::IF_ELSE, ValueType::VOID), condition(condition), thenBlock(thenBlock), elseBlock(elseBlock) { +Expression(ExpressionKind::IF_ELSE, ValueType::NONE), condition(condition), thenBlock(thenBlock), elseBlock(elseBlock) { // Condition must evaluate to bool if (condition->getValueType() != ValueType::BOOL) exit(1); @@ -199,7 +199,7 @@ Expression(ExpressionKind::IF_ELSE, ValueType::VOID), condition(condition), then exit(1); // get type or default to void - valueType = thenExpression ? thenExpression->getValueType() : ValueType::VOID; + valueType = thenExpression ? thenExpression->getValueType() : ValueType::NONE; } shared_ptr ExpressionIfElse::getCondition() { @@ -235,7 +235,7 @@ string ExpressionIfElse::toString(int indent) { // // ExpressionVar ExpressionVar::ExpressionVar(string name): -Expression(ExpressionKind::VAR, ValueType::VOID), name(name) { +Expression(ExpressionKind::VAR, ValueType::NONE), name(name) { } string ExpressionVar::getName() { @@ -249,7 +249,7 @@ string ExpressionVar::toString(int indent) { // // ExpressionInvalid ExpressionInvalid::ExpressionInvalid(shared_ptr token): -Expression(ExpressionKind::INVALID, ValueType::VOID), token(token) { +Expression(ExpressionKind::INVALID, ValueType::NONE), token(token) { } shared_ptr ExpressionInvalid::getToken() { diff --git a/src/Lexer.cpp b/src/Lexer.cpp index 08ca571..787f5d3 100644 --- a/src/Lexer.cpp +++ b/src/Lexer.cpp @@ -112,6 +112,10 @@ shared_ptr Lexer::nextToken() { if (token != nullptr) return token; + token = match(TokenKind::COMMA, ",", false); + if (token != nullptr) + return token; + token = match(TokenKind::COLON, ":", false); if (token != nullptr) return token; @@ -191,6 +195,7 @@ shared_ptr Lexer::nextToken() { if (token != nullptr) return token; + // literal token = match(TokenKind::BOOL, "true", true); if (token != nullptr) return token; @@ -199,7 +204,6 @@ shared_ptr Lexer::nextToken() { if (token != nullptr) return token; - // literal token = matchReal(); if (token != nullptr) return token; @@ -208,11 +212,20 @@ shared_ptr Lexer::nextToken() { if (token != nullptr) return token; - // identifier - token = matchType(); + // type + token = match(TokenKind::TYPE, "bool", true); if (token != nullptr) return token; + token = match(TokenKind::TYPE, "sint32", true); + if (token != nullptr) + return token; + + token = match(TokenKind::TYPE, "real32", true); + if (token != nullptr) + return token; + + // identifier token = matchIdentifier(); if (token != nullptr) return token; @@ -280,30 +293,6 @@ shared_ptr Lexer::matchReal() { return token; } -shared_ptr Lexer::matchType() { - bool isVarDec = tokens.size() >= 2 && - tokens.at(tokens.size() - 1)->getKind() == TokenKind::COLON && - tokens.at(tokens.size() - 2)->getKind() == TokenKind::IDENTIFIER; - - bool isFunDec = tokens.size() >= 1 && - tokens.at(tokens.size() - 1)->getKind() == TokenKind::RIGHT_ARROW; - - if (!isVarDec && !isFunDec) - return nullptr; - - int nextIndex = currentIndex; - while (nextIndex < source.length() && isIdentifier(nextIndex)) - nextIndex++; - - if (nextIndex == currentIndex || !isSeparator(nextIndex)) - return nullptr; - - string lexme = source.substr(currentIndex, nextIndex - currentIndex); - shared_ptr token = make_shared(TokenKind::TYPE, lexme, currentLine, currentColumn); - advanceWithToken(token); - return token; -} - shared_ptr Lexer::matchIdentifier() { int nextIndex = currentIndex; @@ -365,6 +354,7 @@ bool Lexer::isSeparator(int index) { case '>': case '(': case ')': + case ',': case ':': case ';': case ' ': diff --git a/src/Lexer.h b/src/Lexer.h index 6dfe653..08bfa2d 100644 --- a/src/Lexer.h +++ b/src/Lexer.h @@ -20,7 +20,7 @@ private: shared_ptr match(TokenKind kind, string lexme, bool needsSeparator); shared_ptr matchInteger(); shared_ptr matchReal(); - shared_ptr matchType(); + //shared_ptr matchType(); shared_ptr matchIdentifier(); shared_ptr matchEnd(); shared_ptr matchInvalid(); diff --git a/src/ModuleBuilder.cpp b/src/ModuleBuilder.cpp index 2381de1..3a80e16 100644 --- a/src/ModuleBuilder.cpp +++ b/src/ModuleBuilder.cpp @@ -97,7 +97,7 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr expression llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr expression) { switch (expression->getValueType()) { - case ValueType::VOID: + case ValueType::NONE: return llvm::UndefValue::get(typeVoid); case ValueType::BOOL: return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true); @@ -257,7 +257,7 @@ llvm::Value *ModuleBuilder::valueForVar(shared_ptr expression) { llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) { switch (valueType) { - case ValueType::VOID: + case ValueType::NONE: return typeVoid; case ValueType::BOOL: return typeBool; diff --git a/src/Parser.cpp b/src/Parser.cpp index 119b067..7b8f513 100644 --- a/src/Parser.cpp +++ b/src/Parser.cpp @@ -45,32 +45,51 @@ shared_ptr Parser::nextStatement() { } shared_ptr Parser::matchStatementFunctionDeclaration() { - if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::COLON, TokenKind::FUNCTION}, true, false)) + if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::FUNCTION}, true, false)) return nullptr; shared_ptr identifierToken = tokens.at(currentIndex); currentIndex++; - currentIndex++; // skip colon currentIndex++; // skip fun - // Return type - ValueType returnType = ValueType::VOID; - if (tryMatchingTokenKinds({TokenKind::RIGHT_ARROW}, true, true)) { - shared_ptr valueTypeToken = tokens.at(currentIndex); - - if (valueTypeToken->getLexme().compare("bool") == 0) - returnType = ValueType::BOOL; - else if (valueTypeToken->getLexme().compare("sint32") == 0) - returnType = ValueType::SINT32; - else if (valueTypeToken->getLexme().compare("real32") == 0) - returnType = ValueType::REAL32; - else - return matchStatementInvalid("Expected return type"); - - currentIndex++; // type + // Get arguments + vector> arguments; + if (tryMatchingTokenKinds({TokenKind::COLON}, true, true)) { + do { + tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // skip new line + if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::TYPE}, true, false)) + return matchStatementInvalid("Expected function argument"); + shared_ptr identifierToken = tokens.at(currentIndex); + currentIndex++; // identifier + shared_ptr typeToken = tokens.at(currentIndex); + currentIndex++; // type + optional argumentType = valueTypeForToken(typeToken); + if (!argumentType) + return matchStatementInvalid("Invalid argument type"); + + arguments.push_back(pair(identifierToken->getLexme(), *argumentType)); + } while (tryMatchingTokenKinds({TokenKind::COMMA}, true, true)); + } + + // consume optional new line + tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); + + // Return type + ValueType returnType = ValueType::NONE; + if (tryMatchingTokenKinds({TokenKind::RIGHT_ARROW}, true, true)) { + shared_ptr typeToken = tokens.at(currentIndex); + optional type = valueTypeForToken(typeToken); + if (!type) + return matchStatementInvalid("Expected return type"); + returnType = *type; + + currentIndex++; // type + + // consume new line + if (!tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true)) + return matchStatementInvalid("Expected new line after function declaration"); } - currentIndex++; // new line shared_ptr statementBlock = matchStatementBlock({TokenKind::SEMICOLON}, true); if (statementBlock == nullptr) return matchStatementInvalid(); @@ -80,7 +99,7 @@ shared_ptr Parser::matchStatementFunctionDeclaration() { if(!tryMatchingTokenKinds({TokenKind::NEW_LINE}, false, true)) return matchStatementInvalid("Expected a new line after a function declaration"); - return make_shared(identifierToken->getLexme(), returnType, dynamic_pointer_cast(statementBlock)); + return make_shared(identifierToken->getLexme(), arguments, returnType, dynamic_pointer_cast(statementBlock)); } shared_ptr Parser::matchStatementVarDeclaration() { @@ -94,7 +113,7 @@ shared_ptr Parser::matchStatementVarDeclaration() { ValueType valueType; if (valueTypeToken->getLexme().compare("bool") == 0) - valueType = ValueType::BOOL; + valueType = ValueType::BOOL; else if (valueTypeToken->getLexme().compare("sint32") == 0) valueType = ValueType::SINT32; else if (valueTypeToken->getLexme().compare("real32") == 0) @@ -395,3 +414,17 @@ bool Parser::tryMatchingTokenKinds(vector kinds, bool shouldMatchAll, return false; } } + +optional Parser::valueTypeForToken(shared_ptr token) { + if (token->getKind() != TokenKind::TYPE) + return {}; + + if (token->getLexme().compare("bool") == 0) + return ValueType::BOOL; + else if (token->getLexme().compare("sint32") == 0) + return ValueType::SINT32; + else if (token->getLexme().compare("real32") == 0) + return ValueType::REAL32; + + return {}; +} diff --git a/src/Parser.h b/src/Parser.h index 39d1ce7..f25faad 100644 --- a/src/Parser.h +++ b/src/Parser.h @@ -37,6 +37,7 @@ private: shared_ptr matchExpressionInvalid(); bool tryMatchingTokenKinds(vector kinds, bool shouldMatchAll, bool shouldAdvance); + optional valueTypeForToken(shared_ptr token); public: Parser(vector> tokens); diff --git a/src/Statement.cpp b/src/Statement.cpp index 35185a1..a53f14d 100644 --- a/src/Statement.cpp +++ b/src/Statement.cpp @@ -2,7 +2,7 @@ string valueTypeToString(ValueType valueType) { switch (valueType) { - case ValueType::VOID: + case ValueType::NONE: return "NONE"; case ValueType::BOOL: return "BOOL"; @@ -32,14 +32,18 @@ string Statement::toString(int indent) { // // StatementFunctionDeclaration -StatementFunctionDeclaration::StatementFunctionDeclaration(string name, ValueType returnValueType, shared_ptr statementBlock): -Statement(StatementKind::FUNCTION_DECLARATION), name(name), returnValueType(returnValueType), statementBlock(statementBlock) { +StatementFunctionDeclaration::StatementFunctionDeclaration(string name, vector> arguments, ValueType returnValueType, shared_ptr statementBlock): +Statement(StatementKind::FUNCTION_DECLARATION), name(name), arguments(arguments), returnValueType(returnValueType), statementBlock(statementBlock) { } string StatementFunctionDeclaration::getName() { return name; } +vector> StatementFunctionDeclaration::getArguments() { + return arguments; +} + ValueType StatementFunctionDeclaration::getReturnValueType() { return returnValueType; } diff --git a/src/Statement.h b/src/Statement.h index 0e00d99..2e23bd8 100644 --- a/src/Statement.h +++ b/src/Statement.h @@ -34,12 +34,14 @@ public: class StatementFunctionDeclaration: public Statement { private: string name; + vector> arguments; ValueType returnValueType; shared_ptr statementBlock; public: - StatementFunctionDeclaration(string name, ValueType returnValueType, shared_ptr statementBlock); + StatementFunctionDeclaration(string name, vector> arguments, ValueType returnValueType, shared_ptr statementBlock); string getName(); + vector> getArguments(); ValueType getReturnValueType(); shared_ptr getStatementBlock(); string toString(int indent) override; diff --git a/src/Token.cpp b/src/Token.cpp index aca1ad4..fb1e6a7 100644 --- a/src/Token.cpp +++ b/src/Token.cpp @@ -104,6 +104,8 @@ string Token::toString() { return "("; case TokenKind::RIGHT_PAREN: return ")"; + case TokenKind::COMMA: + return ","; case TokenKind::COLON: return ":"; case TokenKind::SEMICOLON: diff --git a/src/Types.h b/src/Types.h index 9169ef0..9bbe44a 100644 --- a/src/Types.h +++ b/src/Types.h @@ -17,6 +17,7 @@ enum class TokenKind { LEFT_PAREN, RIGHT_PAREN, + COMMA, COLON, SEMICOLON, QUESTION, @@ -58,7 +59,7 @@ enum class StatementKind { }; enum class ValueType { - VOID, + NONE, BOOL, SINT32, REAL32