diff --git a/samples/test.brc b/samples/test.brc index b2f978b..02f3f71 100644 --- a/samples/test.brc +++ b/samples/test.brc @@ -49,10 +49,10 @@ i u32 <- 0, rep text[i] != 0: ;*/ main fun -> u32 - num1 u8 <- 42 - num2 s8 <- -15 + /*num1 u8 <- 42 + num2 s8 <- 3 - +15 num3 u32 <- 1234123 - num4 s32 <- -345345 + num4 s32 <- -345345*/ num5 r32 <- -42.58 ret 0 diff --git a/src/Compiler/ModuleBuilder.cpp b/src/Compiler/ModuleBuilder.cpp index ca59586..5a8c3c7 100644 --- a/src/Compiler/ModuleBuilder.cpp +++ b/src/Compiler/ModuleBuilder.cpp @@ -11,6 +11,7 @@ #include "Parser/Expression/ExpressionCall.h" #include "Parser/Expression/ExpressionIfElse.h" #include "Parser/Expression/ExpressionBinary.h" +#include "Parser/Expression/ExpressionUnary.h" #include "Parser/Expression/ExpressionBlock.h" #include "Parser/Statement/StatementFunction.h" @@ -172,6 +173,8 @@ void ModuleBuilder::buildVarDeclaration(shared_ptr statement) } } else { llvm::Value *value = valueForExpression(statement->getExpression()); + if (value == nullptr) + return; llvm::AllocaInst *alloca = builder->CreateAlloca(typeForValueType(statement->getValueType(), 0), nullptr, statement->getName()); if (!setAlloca(statement->getName(), alloca)) @@ -297,6 +300,8 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr expression return valueForExpression(dynamic_pointer_cast(expression)->getExpression()); case ExpressionKind::BINARY: return valueForBinary(dynamic_pointer_cast(expression)); + case ExpressionKind::UNARY: + return valueForUnary(dynamic_pointer_cast(expression)); case ExpressionKind::IF_ELSE: return valueForIfElse(dynamic_pointer_cast(expression)); case ExpressionKind::VAR: @@ -337,7 +342,7 @@ llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr expres case ValueTypeKind::S32: return llvm::ConstantInt::get(typeS32, expression->getS32Value(), true); case ValueTypeKind::R32: - return llvm::ConstantInt::get(typeR32, expression->getR32Value(), true); + return llvm::ConstantFP::get(typeR32, expression->getR32Value()); } } @@ -383,7 +388,7 @@ llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinaryOperation operati case ExpressionBinaryOperation::NOT_EQUAL: return builder->CreateICmpNE(leftValue, rightValue); default: - markError(0, 0, "Unexpecgted operation for boolean operands"); + markError(0, 0, "Unexpected operation for boolean operands"); return nullptr; } } @@ -469,6 +474,26 @@ llvm::Value *ModuleBuilder::valueForBinaryReal(ExpressionBinaryOperation operati } } +llvm::Value *ModuleBuilder::valueForUnary(shared_ptr expression) { + llvm::Value *value = valueForExpression(expression->getExpression()); + llvm::Type *type = value->getType(); + + // do nothing for plus + if (expression->getOperation() == ExpressionUnaryOperation::PLUS) + return value; + + if (type == typeU8 || type == typeU32) { + return builder->CreateNeg(value); + } else if (type == typeS8 || type == typeS32) { + return builder->CreateNSWNeg(value); + } else if (type == typeR32) { + return builder->CreateFNeg(value); + } + + markError(0, 0, "Unexpected operation"); + return nullptr; +} + llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr expression) { shared_ptr conditionExpression = expression->getCondition(); diff --git a/src/Compiler/ModuleBuilder.h b/src/Compiler/ModuleBuilder.h index b9f0100..872b4ad 100644 --- a/src/Compiler/ModuleBuilder.h +++ b/src/Compiler/ModuleBuilder.h @@ -23,6 +23,7 @@ class ExpressionVariable; class ExpressionCall; class ExpressionIfElse; class ExpressionBinary; +class ExpressionUnary; enum class ExpressionBinaryOperation; class Statement; @@ -86,6 +87,7 @@ private: llvm::Value *valueForBinaryUnsignedInteger(ExpressionBinaryOperation operation, llvm::Value *leftValue, llvm::Value *rightValue); llvm::Value *valueForBinarySignedInteger(ExpressionBinaryOperation operation, llvm::Value *leftValue, llvm::Value *rightValue); llvm::Value *valueForBinaryReal(ExpressionBinaryOperation operation, llvm::Value *leftValue, llvm::Value *rightValue); + llvm::Value *valueForUnary(shared_ptr expression); llvm::Value *valueForIfElse(shared_ptr expression); llvm::Value *valueForVar(shared_ptr expression); llvm::Value *valueForCall(shared_ptr expression); diff --git a/src/Lexer/Token.cpp b/src/Lexer/Token.cpp index 5fcc1df..2adddfe 100644 --- a/src/Lexer/Token.cpp +++ b/src/Lexer/Token.cpp @@ -4,21 +4,30 @@ vector Token::tokensEquality = { TokenKind::EQUAL, TokenKind::NOT_EQUAL }; + vector Token::tokensComparison = { TokenKind::LESS, TokenKind::LESS_EQUAL, TokenKind::GREATER, TokenKind::GREATER_EQUAL }; + vector Token::tokensTerm = { TokenKind::PLUS, TokenKind::MINUS }; + vector Token::tokensFactor = { TokenKind::STAR, TokenKind::SLASH, TokenKind::PERCENT }; + +vector Token::tokensUnary = { + TokenKind::PLUS, + TokenKind::MINUS +}; + vector Token::tokensBinary = { TokenKind::EQUAL, TokenKind::NOT_EQUAL, @@ -35,6 +44,7 @@ vector Token::tokensBinary = { TokenKind::SLASH, TokenKind::PERCENT }; + vector Token::tokensLiteral = { TokenKind::BOOL, TokenKind::INTEGER_DEC, diff --git a/src/Lexer/Token.h b/src/Lexer/Token.h index c6ee86a..4b1ab57 100644 --- a/src/Lexer/Token.h +++ b/src/Lexer/Token.h @@ -65,6 +65,7 @@ public: static vector tokensComparison; static vector tokensTerm; static vector tokensFactor; + static vector tokensUnary; static vector tokensBinary; static vector tokensLiteral; diff --git a/src/Logger.cpp b/src/Logger.cpp index a2c9543..6cfbd4f 100644 --- a/src/Logger.cpp +++ b/src/Logger.cpp @@ -20,6 +20,7 @@ #include "Parser/Expression/Expression.h" #include "Parser/Expression/ExpressionBinary.h" +#include "Parser/Expression/ExpressionUnary.h" #include "Parser/Expression/ExpressionIfElse.h" #include "Parser/Expression/ExpressionVariable.h" #include "Parser/Expression/ExpressionGrouping.h" @@ -344,7 +345,9 @@ string Logger::toString(shared_ptr statement) { string Logger::toString(shared_ptr expression) { switch (expression->getKind()) { case ExpressionKind::BINARY: - return toString(dynamic_pointer_cast(expression)); + return toString(dynamic_pointer_cast(expression)); + case ExpressionKind::UNARY: + return toString(dynamic_pointer_cast(expression)); case ExpressionKind::IF_ELSE: return toString(dynamic_pointer_cast(expression)); case ExpressionKind::VAR: @@ -391,6 +394,17 @@ string Logger::toString(shared_ptr expression) { } } +string Logger::toString(shared_ptr expression) { + switch (expression->getOperation()) { + case ExpressionUnaryOperation::PLUS: + return "+" + toString(expression->getExpression()); + case ExpressionUnaryOperation::MINUS: + return "-" + toString(expression->getExpression()); + case ExpressionUnaryOperation::INVALID: + return "{INVALID}"; + } +} + string Logger::toString(shared_ptr expression) { string text; diff --git a/src/Logger.h b/src/Logger.h index 3101911..72e6746 100644 --- a/src/Logger.h +++ b/src/Logger.h @@ -20,6 +20,7 @@ class StatementExpression; class Expression; class ExpressionBinary; +class ExpressionUnary; class ExpressionIfElse; class ExpressionVariable; class ExpressionGrouping; @@ -51,6 +52,7 @@ private: static string toString(shared_ptr expression); static string toString(shared_ptr expression); + static string toString(shared_ptr expression); static string toString(shared_ptr expression); static string toString(shared_ptr expression); static string toString(shared_ptr expression); diff --git a/src/Parser/Expression/ExpressionUnary.cpp b/src/Parser/Expression/ExpressionUnary.cpp index 420cdd9..ddc570d 100644 --- a/src/Parser/Expression/ExpressionUnary.cpp +++ b/src/Parser/Expression/ExpressionUnary.cpp @@ -7,12 +7,15 @@ Expression(ExpressionKind::UNARY, nullptr), expression(expression) { switch (token->getKind()) { case TokenKind::PLUS: operation = ExpressionUnaryOperation::PLUS; + valueType = expression->getValueType(); break; case TokenKind::MINUS: operation = ExpressionUnaryOperation::MINUS; + valueType = expression->getValueType(); break; default: operation = ExpressionUnaryOperation::INVALID; + valueType = nullptr; break; } } diff --git a/src/Parser/Parser.cpp b/src/Parser/Parser.cpp index 9906eeb..cd9b957 100644 --- a/src/Parser/Parser.cpp +++ b/src/Parser/Parser.cpp @@ -12,6 +12,7 @@ #include "Parser/Expression/ExpressionVariable.h" #include "Parser/Expression/ExpressionCall.h" #include "Parser/Expression/ExpressionIfElse.h" +#include "Parser/Expression/ExpressionUnary.h" #include "Parser/Expression/ExpressionBinary.h" #include "Parser/Expression/ExpressionBlock.h" @@ -723,7 +724,7 @@ shared_ptr Parser::matchTerm() { } shared_ptr Parser::matchFactor() { - shared_ptr expression = matchPrimary(); + shared_ptr expression = matchUnary(); if (expression == nullptr) return nullptr; @@ -733,6 +734,19 @@ shared_ptr Parser::matchFactor() { return expression; } +shared_ptr Parser::matchUnary() { + shared_ptr token = tokens.at(currentIndex); + + if (tryMatchingTokenKinds(Token::tokensUnary, false, true)) { + shared_ptr expression = matchPrimary(); + if (expression == nullptr) + return nullptr; + return make_shared(token, expression); + } + + return matchPrimary(); +} + shared_ptr Parser::matchPrimary() { shared_ptr expression; int errorsCount = errors.size(); diff --git a/src/Parser/Parser.h b/src/Parser/Parser.h index 8e5a8d5..0e27960 100644 --- a/src/Parser/Parser.h +++ b/src/Parser/Parser.h @@ -43,7 +43,8 @@ private: shared_ptr matchComparison(); // <, <=, >, >= shared_ptr matchTerm(); // +, - shared_ptr matchFactor(); // *, /, % - shared_ptr matchPrimary(); // integer, () + shared_ptr matchUnary(); // +, - + shared_ptr matchPrimary(); // literal, () shared_ptr matchExpressionGrouping(); shared_ptr matchExpressionLiteral();