Parse function arguments and return type

This commit is contained in:
Rafał Grodziński
2025-06-16 11:29:30 +09:00
parent 8579de4fba
commit 20a3adcca2
10 changed files with 96 additions and 63 deletions

View File

@@ -23,7 +23,7 @@ string Expression::toString(int indent) {
// //
// ExpressionBinary // ExpressionBinary
ExpressionBinary::ExpressionBinary(shared_ptr<Token> token, shared_ptr<Expression> left, shared_ptr<Expression> right): ExpressionBinary::ExpressionBinary(shared_ptr<Token> token, shared_ptr<Expression> left, shared_ptr<Expression> right):
Expression(ExpressionKind::BINARY, ValueType::VOID), left(left), right(right) { Expression(ExpressionKind::BINARY, ValueType::NONE), left(left), right(right) {
// Types must match // Types must match
if (left->getValueType() != right->getValueType()) if (left->getValueType() != right->getValueType())
exit(1); exit(1);
@@ -124,7 +124,7 @@ string ExpressionBinary::toString(int indent) {
// //
// ExpressionLiteral // ExpressionLiteral
ExpressionLiteral::ExpressionLiteral(shared_ptr<Token> token): ExpressionLiteral::ExpressionLiteral(shared_ptr<Token> token):
Expression(ExpressionKind::LITERAL, ValueType::VOID) { Expression(ExpressionKind::LITERAL, ValueType::NONE) {
switch (token->getKind()) { switch (token->getKind()) {
case TokenKind::BOOL: case TokenKind::BOOL:
boolValue = token->getLexme().compare("true") == 0; boolValue = token->getLexme().compare("true") == 0;
@@ -157,8 +157,8 @@ float ExpressionLiteral::getReal32Value() {
string ExpressionLiteral::toString(int indent) { string ExpressionLiteral::toString(int indent) {
switch (valueType) { switch (valueType) {
case ValueType::VOID: case ValueType::NONE:
return "VOID"; return "NONE";
case ValueType::BOOL: case ValueType::BOOL:
return boolValue ? "true" : "false"; return boolValue ? "true" : "false";
case ValueType::SINT32: case ValueType::SINT32:
@@ -185,7 +185,7 @@ string ExpressionGrouping::toString(int indent) {
// //
// ExpressionIfElse // ExpressionIfElse
ExpressionIfElse::ExpressionIfElse(shared_ptr<Expression> condition, shared_ptr<StatementBlock> thenBlock, shared_ptr<StatementBlock> elseBlock): ExpressionIfElse::ExpressionIfElse(shared_ptr<Expression> condition, shared_ptr<StatementBlock> thenBlock, shared_ptr<StatementBlock> elseBlock):
Expression(ExpressionKind::IF_ELSE, ValueType::VOID), condition(condition), thenBlock(thenBlock), elseBlock(elseBlock) { Expression(ExpressionKind::IF_ELSE, ValueType::NONE), condition(condition), thenBlock(thenBlock), elseBlock(elseBlock) {
// Condition must evaluate to bool // Condition must evaluate to bool
if (condition->getValueType() != ValueType::BOOL) if (condition->getValueType() != ValueType::BOOL)
exit(1); exit(1);
@@ -199,7 +199,7 @@ Expression(ExpressionKind::IF_ELSE, ValueType::VOID), condition(condition), then
exit(1); exit(1);
// get type or default to void // get type or default to void
valueType = thenExpression ? thenExpression->getValueType() : ValueType::VOID; valueType = thenExpression ? thenExpression->getValueType() : ValueType::NONE;
} }
shared_ptr<Expression> ExpressionIfElse::getCondition() { shared_ptr<Expression> ExpressionIfElse::getCondition() {
@@ -235,7 +235,7 @@ string ExpressionIfElse::toString(int indent) {
// //
// ExpressionVar // ExpressionVar
ExpressionVar::ExpressionVar(string name): ExpressionVar::ExpressionVar(string name):
Expression(ExpressionKind::VAR, ValueType::VOID), name(name) { Expression(ExpressionKind::VAR, ValueType::NONE), name(name) {
} }
string ExpressionVar::getName() { string ExpressionVar::getName() {
@@ -249,7 +249,7 @@ string ExpressionVar::toString(int indent) {
// //
// ExpressionInvalid // ExpressionInvalid
ExpressionInvalid::ExpressionInvalid(shared_ptr<Token> token): ExpressionInvalid::ExpressionInvalid(shared_ptr<Token> token):
Expression(ExpressionKind::INVALID, ValueType::VOID), token(token) { Expression(ExpressionKind::INVALID, ValueType::NONE), token(token) {
} }
shared_ptr<Token> ExpressionInvalid::getToken() { shared_ptr<Token> ExpressionInvalid::getToken() {

View File

@@ -112,6 +112,10 @@ shared_ptr<Token> Lexer::nextToken() {
if (token != nullptr) if (token != nullptr)
return token; return token;
token = match(TokenKind::COMMA, ",", false);
if (token != nullptr)
return token;
token = match(TokenKind::COLON, ":", false); token = match(TokenKind::COLON, ":", false);
if (token != nullptr) if (token != nullptr)
return token; return token;
@@ -191,6 +195,7 @@ shared_ptr<Token> Lexer::nextToken() {
if (token != nullptr) if (token != nullptr)
return token; return token;
// literal
token = match(TokenKind::BOOL, "true", true); token = match(TokenKind::BOOL, "true", true);
if (token != nullptr) if (token != nullptr)
return token; return token;
@@ -199,7 +204,6 @@ shared_ptr<Token> Lexer::nextToken() {
if (token != nullptr) if (token != nullptr)
return token; return token;
// literal
token = matchReal(); token = matchReal();
if (token != nullptr) if (token != nullptr)
return token; return token;
@@ -208,11 +212,20 @@ shared_ptr<Token> Lexer::nextToken() {
if (token != nullptr) if (token != nullptr)
return token; return token;
// identifier // type
token = matchType(); token = match(TokenKind::TYPE, "bool", true);
if (token != nullptr) if (token != nullptr)
return token; return token;
token = match(TokenKind::TYPE, "sint32", true);
if (token != nullptr)
return token;
token = match(TokenKind::TYPE, "real32", true);
if (token != nullptr)
return token;
// identifier
token = matchIdentifier(); token = matchIdentifier();
if (token != nullptr) if (token != nullptr)
return token; return token;
@@ -280,30 +293,6 @@ shared_ptr<Token> Lexer::matchReal() {
return token; return token;
} }
// Tries to lex a TYPE token at the current position.
//
// A type is only accepted when the tokens lexed so far end in either
// IDENTIFIER COLON (named `isVarDec` in the original) or RIGHT_ARROW
// (named `isFunDec` in the original). The lexeme is the run of identifier
// characters starting at currentIndex; it must be non-empty and must be
// followed by a separator.
//
// Returns the new TYPE token, or nullptr when no type can be matched here.
shared_ptr<Token> Lexer::matchType() {
    const auto tokenCount = tokens.size();

    // Context check: which tokens precede the current position.
    const bool followsNamedColon = tokenCount >= 2 &&
        tokens.at(tokenCount - 1)->getKind() == TokenKind::COLON &&
        tokens.at(tokenCount - 2)->getKind() == TokenKind::IDENTIFIER;
    const bool followsArrow = tokenCount >= 1 &&
        tokens.at(tokenCount - 1)->getKind() == TokenKind::RIGHT_ARROW;
    if (!(followsNamedColon || followsArrow))
        return nullptr;

    // Scan forward over identifier characters.
    int endIndex = currentIndex;
    for (; endIndex < source.length() && isIdentifier(endIndex); endIndex++) {
    }

    // Reject an empty run, or a run that is not terminated by a separator.
    const bool nothingScanned = endIndex == currentIndex;
    if (nothingScanned || !isSeparator(endIndex))
        return nullptr;

    string lexme = source.substr(currentIndex, endIndex - currentIndex);
    auto token = make_shared<Token>(TokenKind::TYPE, lexme, currentLine, currentColumn);
    // NOTE(review): presumably moves the lexer position past the token's lexeme — confirm in advanceWithToken.
    advanceWithToken(token);
    return token;
}
shared_ptr<Token> Lexer::matchIdentifier() { shared_ptr<Token> Lexer::matchIdentifier() {
int nextIndex = currentIndex; int nextIndex = currentIndex;
@@ -365,6 +354,7 @@ bool Lexer::isSeparator(int index) {
case '>': case '>':
case '(': case '(':
case ')': case ')':
case ',':
case ':': case ':':
case ';': case ';':
case ' ': case ' ':

View File

@@ -20,7 +20,7 @@ private:
shared_ptr<Token> match(TokenKind kind, string lexme, bool needsSeparator); shared_ptr<Token> match(TokenKind kind, string lexme, bool needsSeparator);
shared_ptr<Token> matchInteger(); shared_ptr<Token> matchInteger();
shared_ptr<Token> matchReal(); shared_ptr<Token> matchReal();
shared_ptr<Token> matchType(); //shared_ptr<Token> matchType();
shared_ptr<Token> matchIdentifier(); shared_ptr<Token> matchIdentifier();
shared_ptr<Token> matchEnd(); shared_ptr<Token> matchEnd();
shared_ptr<Token> matchInvalid(); shared_ptr<Token> matchInvalid();

View File

@@ -97,7 +97,7 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression
llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) { llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
switch (expression->getValueType()) { switch (expression->getValueType()) {
case ValueType::VOID: case ValueType::NONE:
return llvm::UndefValue::get(typeVoid); return llvm::UndefValue::get(typeVoid);
case ValueType::BOOL: case ValueType::BOOL:
return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true); return llvm::ConstantInt::get(typeBool, expression->getBoolValue(), true);
@@ -257,7 +257,7 @@ llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVar> expression) {
llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) { llvm::Type *ModuleBuilder::typeForValueType(ValueType valueType) {
switch (valueType) { switch (valueType) {
case ValueType::VOID: case ValueType::NONE:
return typeVoid; return typeVoid;
case ValueType::BOOL: case ValueType::BOOL:
return typeBool; return typeBool;

View File

@@ -45,32 +45,51 @@ shared_ptr<Statement> Parser::nextStatement() {
} }
shared_ptr<Statement> Parser::matchStatementFunctionDeclaration() { shared_ptr<Statement> Parser::matchStatementFunctionDeclaration() {
if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::COLON, TokenKind::FUNCTION}, true, false)) if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::FUNCTION}, true, false))
return nullptr; return nullptr;
shared_ptr<Token> identifierToken = tokens.at(currentIndex); shared_ptr<Token> identifierToken = tokens.at(currentIndex);
currentIndex++; currentIndex++;
currentIndex++; // skip colon
currentIndex++; // skip fun currentIndex++; // skip fun
// Return type // Get arguments
ValueType returnType = ValueType::VOID; vector<pair<string, ValueType>> arguments;
if (tryMatchingTokenKinds({TokenKind::RIGHT_ARROW}, true, true)) { if (tryMatchingTokenKinds({TokenKind::COLON}, true, true)) {
shared_ptr<Token> valueTypeToken = tokens.at(currentIndex); do {
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // skip new line
if (valueTypeToken->getLexme().compare("bool") == 0) if (!tryMatchingTokenKinds({TokenKind::IDENTIFIER, TokenKind::TYPE}, true, false))
returnType = ValueType::BOOL; return matchStatementInvalid("Expected function argument");
else if (valueTypeToken->getLexme().compare("sint32") == 0) shared_ptr<Token> identifierToken = tokens.at(currentIndex);
returnType = ValueType::SINT32; currentIndex++; // identifier
else if (valueTypeToken->getLexme().compare("real32") == 0) shared_ptr<Token> typeToken = tokens.at(currentIndex);
returnType = ValueType::REAL32; currentIndex++; // type
else optional<ValueType> argumentType = valueTypeForToken(typeToken);
return matchStatementInvalid("Expected return type"); if (!argumentType)
return matchStatementInvalid("Invalid argument type");
currentIndex++; // type
arguments.push_back(pair<string, ValueType>(identifierToken->getLexme(), *argumentType));
} while (tryMatchingTokenKinds({TokenKind::COMMA}, true, true));
}
// consume optional new line
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true);
// Return type
ValueType returnType = ValueType::NONE;
if (tryMatchingTokenKinds({TokenKind::RIGHT_ARROW}, true, true)) {
shared_ptr<Token> typeToken = tokens.at(currentIndex);
optional<ValueType> type = valueTypeForToken(typeToken);
if (!type)
return matchStatementInvalid("Expected return type");
returnType = *type;
currentIndex++; // type
// consume new line
if (!tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true))
return matchStatementInvalid("Expected new line after function declaration");
} }
currentIndex++; // new line
shared_ptr<Statement> statementBlock = matchStatementBlock({TokenKind::SEMICOLON}, true); shared_ptr<Statement> statementBlock = matchStatementBlock({TokenKind::SEMICOLON}, true);
if (statementBlock == nullptr) if (statementBlock == nullptr)
return matchStatementInvalid(); return matchStatementInvalid();
@@ -80,7 +99,7 @@ shared_ptr<Statement> Parser::matchStatementFunctionDeclaration() {
if(!tryMatchingTokenKinds({TokenKind::NEW_LINE}, false, true)) if(!tryMatchingTokenKinds({TokenKind::NEW_LINE}, false, true))
return matchStatementInvalid("Expected a new line after a function declaration"); return matchStatementInvalid("Expected a new line after a function declaration");
return make_shared<StatementFunctionDeclaration>(identifierToken->getLexme(), returnType, dynamic_pointer_cast<StatementBlock>(statementBlock)); return make_shared<StatementFunctionDeclaration>(identifierToken->getLexme(), arguments, returnType, dynamic_pointer_cast<StatementBlock>(statementBlock));
} }
shared_ptr<Statement> Parser::matchStatementVarDeclaration() { shared_ptr<Statement> Parser::matchStatementVarDeclaration() {
@@ -94,7 +113,7 @@ shared_ptr<Statement> Parser::matchStatementVarDeclaration() {
ValueType valueType; ValueType valueType;
if (valueTypeToken->getLexme().compare("bool") == 0) if (valueTypeToken->getLexme().compare("bool") == 0)
valueType = ValueType::BOOL; valueType = ValueType::BOOL;
else if (valueTypeToken->getLexme().compare("sint32") == 0) else if (valueTypeToken->getLexme().compare("sint32") == 0)
valueType = ValueType::SINT32; valueType = ValueType::SINT32;
else if (valueTypeToken->getLexme().compare("real32") == 0) else if (valueTypeToken->getLexme().compare("real32") == 0)
@@ -395,3 +414,17 @@ bool Parser::tryMatchingTokenKinds(vector<TokenKind> kinds, bool shouldMatchAll,
return false; return false;
} }
} }
// Translates a TYPE token into the ValueType its lexeme names.
//
// Recognised lexemes: "bool" -> BOOL, "sint32" -> SINT32, "real32" -> REAL32.
// Returns an empty optional when the token's kind is not TokenKind::TYPE or
// when the lexeme is none of the above.
optional<ValueType> Parser::valueTypeForToken(shared_ptr<Token> token) {
    if (token->getKind() == TokenKind::TYPE) {
        const auto lexme = token->getLexme();

        if (lexme == "bool")
            return ValueType::BOOL;
        if (lexme == "sint32")
            return ValueType::SINT32;
        if (lexme == "real32")
            return ValueType::REAL32;
    }

    // Not a type token, or an unknown type name.
    return {};
}

View File

@@ -37,6 +37,7 @@ private:
shared_ptr<ExpressionInvalid> matchExpressionInvalid(); shared_ptr<ExpressionInvalid> matchExpressionInvalid();
bool tryMatchingTokenKinds(vector<TokenKind> kinds, bool shouldMatchAll, bool shouldAdvance); bool tryMatchingTokenKinds(vector<TokenKind> kinds, bool shouldMatchAll, bool shouldAdvance);
optional<ValueType> valueTypeForToken(shared_ptr<Token> token);
public: public:
Parser(vector<shared_ptr<Token>> tokens); Parser(vector<shared_ptr<Token>> tokens);

View File

@@ -2,7 +2,7 @@
string valueTypeToString(ValueType valueType) { string valueTypeToString(ValueType valueType) {
switch (valueType) { switch (valueType) {
case ValueType::VOID: case ValueType::NONE:
return "NONE"; return "NONE";
case ValueType::BOOL: case ValueType::BOOL:
return "BOOL"; return "BOOL";
@@ -32,14 +32,18 @@ string Statement::toString(int indent) {
// //
// StatementFunctionDeclaration // StatementFunctionDeclaration
StatementFunctionDeclaration::StatementFunctionDeclaration(string name, ValueType returnValueType, shared_ptr<StatementBlock> statementBlock): StatementFunctionDeclaration::StatementFunctionDeclaration(string name, vector<pair<string, ValueType>> arguments, ValueType returnValueType, shared_ptr<StatementBlock> statementBlock):
Statement(StatementKind::FUNCTION_DECLARATION), name(name), returnValueType(returnValueType), statementBlock(statementBlock) { Statement(StatementKind::FUNCTION_DECLARATION), name(name), arguments(arguments), returnValueType(returnValueType), statementBlock(statementBlock) {
} }
string StatementFunctionDeclaration::getName() { string StatementFunctionDeclaration::getName() {
return name; return name;
} }
// Returns the function's arguments as (name, value type) pairs.
// The vector is returned by value, so the caller receives its own copy.
vector<pair<string, ValueType>> StatementFunctionDeclaration::getArguments() {
    return this->arguments;
}
ValueType StatementFunctionDeclaration::getReturnValueType() { ValueType StatementFunctionDeclaration::getReturnValueType() {
return returnValueType; return returnValueType;
} }

View File

@@ -34,12 +34,14 @@ public:
class StatementFunctionDeclaration: public Statement { class StatementFunctionDeclaration: public Statement {
private: private:
string name; string name;
vector<pair<string, ValueType>> arguments;
ValueType returnValueType; ValueType returnValueType;
shared_ptr<StatementBlock> statementBlock; shared_ptr<StatementBlock> statementBlock;
public: public:
StatementFunctionDeclaration(string name, ValueType returnValueType, shared_ptr<StatementBlock> statementBlock); StatementFunctionDeclaration(string name, vector<pair<string, ValueType>> arguments, ValueType returnValueType, shared_ptr<StatementBlock> statementBlock);
string getName(); string getName();
vector<pair<string, ValueType>> getArguments();
ValueType getReturnValueType(); ValueType getReturnValueType();
shared_ptr<StatementBlock> getStatementBlock(); shared_ptr<StatementBlock> getStatementBlock();
string toString(int indent) override; string toString(int indent) override;

View File

@@ -104,6 +104,8 @@ string Token::toString() {
return "("; return "(";
case TokenKind::RIGHT_PAREN: case TokenKind::RIGHT_PAREN:
return ")"; return ")";
case TokenKind::COMMA:
return ",";
case TokenKind::COLON: case TokenKind::COLON:
return ":"; return ":";
case TokenKind::SEMICOLON: case TokenKind::SEMICOLON:

View File

@@ -17,6 +17,7 @@ enum class TokenKind {
LEFT_PAREN, LEFT_PAREN,
RIGHT_PAREN, RIGHT_PAREN,
COMMA,
COLON, COLON,
SEMICOLON, SEMICOLON,
QUESTION, QUESTION,
@@ -58,7 +59,7 @@ enum class StatementKind {
}; };
enum class ValueType { enum class ValueType {
VOID, NONE,
BOOL, BOOL,
SINT32, SINT32,
REAL32 REAL32