Merge pull request #21 from rafalgrodzinski/1-add-block-scoping-for-variables

Add block scoping for variables
This commit is contained in:
Rafał
2025-07-05 23:13:07 +09:00
committed by GitHub
11 changed files with 241 additions and 60 deletions

1
.gitignore vendored
View File

@@ -3,5 +3,4 @@
brb
.vscode/settings.json
*.dSYM
*.brc
build/

34
samples/fib.brc Normal file
View File

@@ -0,0 +1,34 @@
@extern putchar fun: character sint32 -> sint32
fib fun: number sint32 -> sint32
ret if number < 2:
number
else
fib(number - 1) + fib(number - 2)
;
;
printNum fun: number sint32
biggest sint32 <- 10
rep biggest <= number: biggest <- biggest * 10
biggest <- biggest / 10
rep biggest > 0:
digit sint32 <- number / biggest
putchar(digit + '0')
number <- number % biggest
biggest <- biggest / 10
;
;
// Print first 20 fibonaci numbers
main fun -> sint32
rep i sint32 <- 0, i < 20:
res sint32 <- fib(i)
printNum(res)
putchar('\n')
i <- i + 1
;
ret 0
;

View File

@@ -1,5 +1,7 @@
#include "ModuleBuilder.h"
#include "Error.h"
#include "Logger.h"
#include "Parser/ValueType.h"
#include "Parser/Expression/ExpressionGrouping.h"
@@ -33,9 +35,16 @@ moduleName(moduleName), sourceFileName(sourceFileName), statements(statements) {
}
shared_ptr<llvm::Module> ModuleBuilder::getModule() {
for (shared_ptr<Statement> &statement : statements) {
scopes.push(Scope());
for (shared_ptr<Statement> &statement : statements)
buildStatement(statement);
if (!errors.empty()) {
for (shared_ptr<Error> &error : errors)
Logger::print(error);
exit(1);
}
return module;
}
@@ -66,7 +75,7 @@ void ModuleBuilder::buildStatement(shared_ptr<Statement> statement) {
buildExpression(dynamic_pointer_cast<StatementExpression>(statement));
return;
default:
failWithMessage("Unexpected statement");
markError(0, 0, "Unexpected statement");
}
}
@@ -80,12 +89,15 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
// build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun;
if (!setFun(statement->getName(), fun))
return;
// define function body
llvm::BasicBlock *block = llvm::BasicBlock::Create(*context, statement->getName(), fun);
builder->SetInsertPoint(block);
scopes.push(Scope());
// build arguments
int i=0;
for (auto &arg : fun->args()) {
@@ -94,7 +106,8 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
arg.setName(name);
llvm::AllocaInst *alloca = builder->CreateAlloca(type, nullptr, name);
allocaMap[name] = alloca;
if (!setAlloca(name, alloca))
return;
builder->CreateStore(&arg, alloca);
i++;
@@ -103,24 +116,28 @@ void ModuleBuilder::buildFunctionDeclaration(shared_ptr<StatementFunction> state
// build function body
buildStatement(statement->getStatementBlock());
scopes.pop();
// verify
string errorMessage;
llvm::raw_string_ostream llvmErrorMessage(errorMessage);
if (llvm::verifyFunction(*fun, &llvmErrorMessage))
failWithMessage(errorMessage);
markError(0, 0, errorMessage);
}
void ModuleBuilder::buildVarDeclaration(shared_ptr<StatementVariable> statement) {
llvm::Value *value = valueForExpression(statement->getExpression());
llvm::AllocaInst *alloca = builder->CreateAlloca(typeForValueType(statement->getValueType()), nullptr, statement->getName());
allocaMap[statement->getName()] = alloca;
if (!setAlloca(statement->getName(), alloca))
return;
builder->CreateStore(value, alloca);
}
void ModuleBuilder::buildAssignment(shared_ptr<StatementAssignment> statement) {
llvm::AllocaInst *alloca = allocaMap[statement->getName()];
llvm::AllocaInst *alloca = getAlloca(statement->getName());
if (alloca == nullptr)
failWithMessage("Variable " + statement->getName() + " not defined");
return;
llvm::Value *value = valueForExpression(statement->getExpression());
builder->CreateStore(value, alloca);
@@ -151,6 +168,8 @@ void ModuleBuilder::buildLoop(shared_ptr<StatementRepeat> statement) {
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(*context, "loopBody");
llvm::BasicBlock *afterBlock = llvm::BasicBlock::Create(*context, "loopPost");
scopes.push(Scope());
// loop init
if (initStatement != nullptr)
buildStatement(statement->getInitStatement());
@@ -181,6 +200,8 @@ void ModuleBuilder::buildLoop(shared_ptr<StatementRepeat> statement) {
// loop post
fun->insert(fun->end(), afterBlock);
builder->SetInsertPoint(afterBlock);
scopes.pop();
}
void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFunction> statement) {
@@ -193,7 +214,8 @@ void ModuleBuilder::buildMetaExternFunction(shared_ptr<StatementMetaExternFuncti
// build function declaration
llvm::FunctionType *funType = llvm::FunctionType::get(typeForValueType(statement->getReturnValueType()), types, false);
llvm::Function *fun = llvm::Function::Create(funType, llvm::GlobalValue::ExternalLinkage, statement->getName(), module.get());
funMap[statement->getName()] = fun;
if (!setFun(statement->getName(), fun))
return;
// build arguments
int i=0;
@@ -223,11 +245,15 @@ llvm::Value *ModuleBuilder::valueForExpression(shared_ptr<Expression> expression
case ExpressionKind::CALL:
return valueForCall(dynamic_pointer_cast<ExpressionCall>(expression));
default:
failWithMessage("Unexpected expression");
markError(0, 0, "Unexpected expression");
return nullptr;
}
}
llvm::Value *ModuleBuilder::valueForLiteral(shared_ptr<ExpressionLiteral> expression) {
if (expression->getValueType() == nullptr)
return llvm::UndefValue::get(typeVoid);
switch (expression->getValueType()->getKind()) {
case ValueTypeKind::NONE:
return llvm::UndefValue::get(typeVoid);
@@ -258,7 +284,8 @@ llvm::Value *ModuleBuilder::valueForBinary(shared_ptr<ExpressionBinary> expressi
return valueForBinaryReal(expression->getOperation(), leftValue, rightValue);
}
failWithMessage("Unexpected operation");
markError(0, 0, "Unexpected operation");
return nullptr;
}
llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinaryOperation operation, llvm::Value *leftValue, llvm::Value *rightValue) {
@@ -268,7 +295,8 @@ llvm::Value *ModuleBuilder::valueForBinaryBool(ExpressionBinaryOperation operati
case ExpressionBinaryOperation::NOT_EQUAL:
return builder->CreateICmpNE(leftValue, rightValue);
default:
failWithMessage("Undefined operation for boolean operands");
markError(0, 0, "Unexpecgted operation for boolean operands");
return nullptr;
}
}
@@ -340,13 +368,16 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
builder->CreateCondBr(conditionValue, thenBlock, elseBlock);
// Then
scopes.push(Scope());
builder->SetInsertPoint(thenBlock);
buildStatement(expression->getThenBlock()->getStatementBlock());
llvm::Value *thenValue = valueForExpression(expression->getThenBlock()->getResultStatementExpression()->getExpression());
builder->CreateBr(mergeBlock);
thenBlock = builder->GetInsertBlock();
scopes.pop();
// Else
scopes.push(Scope());
fun->insert(fun->end(), elseBlock);
builder->SetInsertPoint(elseBlock);
llvm::Value *elseValue = nullptr;
@@ -357,6 +388,7 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
}
builder->CreateBr(mergeBlock);
elseBlock = builder->GetInsertBlock();
scopes.pop();
// Merge
fun->insert(fun->end(), mergeBlock);
@@ -375,17 +407,17 @@ llvm::Value *ModuleBuilder::valueForIfElse(shared_ptr<ExpressionIfElse> expressi
}
llvm::Value *ModuleBuilder::valueForVar(shared_ptr<ExpressionVariable> expression) {
llvm::AllocaInst *alloca = allocaMap[expression->getName()];
llvm::AllocaInst *alloca = getAlloca(expression->getName());
if (alloca == nullptr)
failWithMessage("Variable " + expression->getName() + " not defined");
return nullptr;
return builder->CreateLoad(alloca->getAllocatedType(), alloca, expression->getName());
}
llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression) {
llvm::Function *fun = funMap[expression->getName()];
llvm::Function *fun = getFun(expression->getName());
if (fun == nullptr)
failWithMessage("Function " + expression->getName() + " not defined");
return nullptr;
llvm::FunctionType *funType = fun->getFunctionType();
vector<llvm::Value*> argValues;
for (shared_ptr<Expression> &argumentExpression : expression->getArgumentExpressions()) {
@@ -395,6 +427,54 @@ llvm::Value *ModuleBuilder::valueForCall(shared_ptr<ExpressionCall> expression)
return builder->CreateCall(funType, fun, llvm::ArrayRef(argValues));
}
bool ModuleBuilder::setAlloca(string name, llvm::AllocaInst *alloca) {
if (scopes.top().allocaMap[name] != nullptr) {
markError(0, 0, format("Variable \"{}\" already defined", name));
return false;
}
scopes.top().allocaMap[name] = alloca;
return true;
}
llvm::AllocaInst* ModuleBuilder::getAlloca(string name) {
stack<Scope> scopes = this->scopes;
while (!scopes.empty()) {
llvm::AllocaInst *alloca = scopes.top().allocaMap[name];
if (alloca != nullptr)
return alloca;
scopes.pop();
}
markError(0, 0, format("Variable \"{}\" not defined in scope", name));
return nullptr;
}
bool ModuleBuilder::setFun(string name, llvm::Function *fun) {
if (scopes.top().funMap[name] != nullptr) {
markError(0, 0, format("Function \"{}\" already defined", name));
return false;
}
scopes.top().funMap[name] = fun;
return true;
}
llvm::Function* ModuleBuilder::getFun(string name) {
stack<Scope> scopes = this->scopes;
while (!scopes.empty()) {
llvm::Function *fun = scopes.top().funMap[name];
if (fun != nullptr)
return fun;
scopes.pop();
}
markError(0, 0, format("Function \"{}\" not defined in scope", name));
return nullptr;
}
llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType) {
switch (valueType->getKind()) {
case ValueTypeKind::NONE:
@@ -408,7 +488,6 @@ llvm::Type *ModuleBuilder::typeForValueType(shared_ptr<ValueType> valueType) {
}
}
void ModuleBuilder::failWithMessage(string message) {
cerr << "Error! Building module \"" << moduleName << "\" from \"" + sourceFileName + "\" failed:" << endl << message << endl;
exit(1);
void ModuleBuilder::markError(int line, int column, string message) {
errors.push_back(Error::builderError(line, column, message));
}

View File

@@ -2,6 +2,7 @@
#define MODULE_BUILDER_H
#include <map>
#include <stack>
#include <llvm/IR/Module.h>
#include <llvm/IR/IRBuilder.h>
@@ -10,6 +11,7 @@
#include <llvm/Support/raw_ostream.h>
#include <llvm/IR/Verifier.h>
class Error;
class ValueType;
class Expression;
@@ -33,8 +35,14 @@ class StatementBlock;
using namespace std;
typedef struct {
map<string, llvm::AllocaInst*> allocaMap;
map<string, llvm::Function*> funMap;
} Scope;
class ModuleBuilder {
private:
vector<shared_ptr<Error>> errors;
string moduleName;
string sourceFileName;
@@ -48,8 +56,7 @@ private:
llvm::Type *typeReal32;
vector<shared_ptr<Statement>> statements;
map<string, llvm::AllocaInst*> allocaMap;
map<string, llvm::Function*> funMap;
stack<Scope> scopes;
void buildStatement(shared_ptr<Statement> statement);
void buildFunctionDeclaration(shared_ptr<StatementFunction> statement);
@@ -72,8 +79,15 @@ private:
llvm::Value *valueForVar(shared_ptr<ExpressionVariable> expression);
llvm::Value *valueForCall(shared_ptr<ExpressionCall> expression);
bool setAlloca(string name, llvm::AllocaInst *alloca);
llvm::AllocaInst *getAlloca(string name);
bool setFun(string name, llvm::Function *fun);
llvm::Function *getFun(string name);
llvm::Type *typeForValueType(shared_ptr<ValueType> valueType);
void failWithMessage(string message);
void markError(int line, int column, string message);
public:
ModuleBuilder(string moduleName, string sourceFileName, vector<shared_ptr<Statement>> statements);

View File

@@ -1,10 +1,43 @@
#include "Error.h"
Error::Error(int line, int column, string lexme) :
kind(ErrorKind::LEXER_ERROR), line(line), column(column), lexme(lexme) { }
shared_ptr<Error> Error::lexerError(int line, int column, string lexme) {
return make_shared<Error>(
ErrorKind::LEXER_ERROR,
line,
column,
lexme,
nullptr,
optional<TokenKind>(),
optional<string>()
);
}
Error::Error(shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message) :
kind(ErrorKind::PARSER_ERROR), actualToken(actualToken), expectedTokenKind(expectedTokenKind), message(message) { }
shared_ptr<Error> Error::parserError(shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message) {
return make_shared<Error>(
ErrorKind::PARSER_ERROR,
0,
0,
optional<string>(),
actualToken,
expectedTokenKind,
message
);
}
shared_ptr<Error> Error::builderError(int line, int column, string message) {
return make_shared<Error>(
ErrorKind::BUILDER_ERROR,
line,
column,
optional<string>(),
nullptr,
optional<TokenKind>(),
message
);
}
Error::Error(ErrorKind kind, int line, int column, optional<string> lexme, shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message):
kind(kind), line(line), column(column), lexme(lexme), actualToken(actualToken), expectedTokenKind(expectedTokenKind), message(message) { }
ErrorKind Error::getKind() {
return kind;
@@ -18,7 +51,7 @@ int Error::getColumn() {
return column;
}
string Error::getLexme() {
optional<string> Error::getLexme() {
return lexme;
}

View File

@@ -10,7 +10,8 @@ using namespace std;
enum class ErrorKind {
LEXER_ERROR,
PARSER_ERROR
PARSER_ERROR,
BUILDER_ERROR
};
class Error {
@@ -19,21 +20,24 @@ private:
int line;
int column;
string lexme;
optional<string> lexme;
shared_ptr<Token> actualToken;
optional<TokenKind> expectedTokenKind;
optional<string> message;
public:
Error(int line, int column, string lexme);
Error(shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message);
static shared_ptr<Error> lexerError(int line, int column, string lexme);
static shared_ptr<Error> parserError(shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message);
static shared_ptr<Error> builderError(int line, int column, string message);
Error(ErrorKind kind, int line, int column, optional<string> lexme, shared_ptr<Token> actualToken, optional<TokenKind> expectedTokenKind, optional<string> message);
ErrorKind getKind();
int getLine();
int getColumn();
string getLexme();
optional<string> getLexme();
shared_ptr<Token> getActualToken();
optional<TokenKind> getExpectedTokenKind();

View File

@@ -1,9 +1,10 @@
#include "Lexer.h"
#include "Token.h"
#include "Error.h"
#include "Logger.h"
#include "Token.h"
Lexer::Lexer(string source):
source(source) { }
@@ -506,5 +507,5 @@ void Lexer::markError() {
} else {
lexme = "EOF";
}
errors.push_back(make_shared<Error>(currentLine, startColumn, lexme));
errors.push_back(Error::lexerError(currentLine, startColumn, lexme));
}

View File

@@ -2,6 +2,8 @@
#include <iostream>
#include "Error.h"
#include "Lexer/Token.h"
#include "Parser/ValueType.h"
@@ -24,8 +26,6 @@
#include "Parser/Expression/ExpressionCall.h"
#include "Parser/Expression/ExpressionBlock.h"
#include "Error.h"
string Logger::toString(shared_ptr<Token> token) {
switch (token->getKind()) {
case TokenKind::PLUS:
@@ -367,6 +367,9 @@ string Logger::toString(shared_ptr<ExpressionGrouping> expression) {
}
string Logger::toString(shared_ptr<ExpressionLiteral> expression) {
if (expression->getValueType() == nullptr)
return "?";
switch (expression->getValueType()->getKind()) {
case ValueTypeKind::NONE:
return "NONE";
@@ -392,6 +395,8 @@ string Logger::toString(shared_ptr<ExpressionCall> expression) {
string Logger::toString(shared_ptr<ExpressionBlock> expression) {
string text;
text += toString(expression->getStatementBlock());
if (!text.empty())
text += '\n';
if (expression->getResultStatementExpression() != nullptr)
text += toString(expression->getResultStatementExpression());
return text;
@@ -416,10 +421,12 @@ void Logger::print(vector<shared_ptr<Statement>> statements) {
void Logger::print(shared_ptr<Error> error) {
string message;
switch (error->getKind()) {
case ErrorKind::LEXER_ERROR:
message = format("Unexpected token \"{}\" at line: {}, column: {}", error->getLexme(), error->getLine() + 1, error->getColumn() + 1);
case ErrorKind::LEXER_ERROR: {
string lexme = error->getLexme() ? *(error->getLexme()) : "";
message = format("Unexpected token \"{}\" at line: {}, column: {}", lexme, error->getLine() + 1, error->getColumn() + 1);
break;
case ErrorKind::PARSER_ERROR:
}
case ErrorKind::PARSER_ERROR: {
shared_ptr<Token> token = error->getActualToken();
optional<TokenKind> expectedTokenKind = error->getExpectedTokenKind();
optional<string> errorMessage = error->getMessage();
@@ -438,6 +445,12 @@ void Logger::print(shared_ptr<Error> error) {
if (errorMessage)
message += format(". {}", *errorMessage);
break;
}
case ErrorKind::BUILDER_ERROR: {
string errorMessage = error->getMessage() ? *(error->getMessage()) : "";
message = format("Error at line {}, column {}: {}", error->getLine(), error->getColumn(), errorMessage);
break;
}
}
cout << message << endl;
}

View File

@@ -54,13 +54,15 @@ Expression(ExpressionKind::BINARY, nullptr), operation(ExpressionBinaryOperation
break;
}
// Types must match
if (left->getValueType() != right->getValueType())
valueType = ValueType::NONE;
if (left->getValueType() != nullptr && right->getValueType() != nullptr) {
// Types must match
if (left->getValueType() != right->getValueType())
valueType = ValueType::NONE;
// Booleans can only do = or !=
if (valueType->getKind() == ValueTypeKind::BOOL && (token->getKind() != TokenKind::EQUAL || token->getKind() != TokenKind::NOT_EQUAL))
valueType = ValueType::NONE;
// Booleans can only do = or !=
if (valueType->getKind() == ValueTypeKind::BOOL && (token->getKind() != TokenKind::EQUAL || token->getKind() != TokenKind::NOT_EQUAL))
valueType = ValueType::NONE;
}
}
ExpressionBinaryOperation ExpressionBinary::getOperation() {

View File

@@ -501,19 +501,21 @@ shared_ptr<Expression> Parser::matchExpressionCall() {
currentIndex++; // left parenthesis
vector<shared_ptr<Expression>> argumentExpressions;
do {
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // optional new line
shared_ptr<Expression> argumentExpression = nextExpression();
if (argumentExpression == nullptr)
return nullptr;
argumentExpressions.push_back(argumentExpression);
} while (tryMatchingTokenKinds({TokenKind::COMMA}, true, true));
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // optional new line
if (!tryMatchingTokenKinds({TokenKind::RIGHT_PAREN}, true, true)) {
markError(TokenKind::RIGHT_PAREN, {});
return nullptr;
do {
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // optional new line
shared_ptr<Expression> argumentExpression = nextExpression();
if (argumentExpression == nullptr)
return nullptr;
argumentExpressions.push_back(argumentExpression);
} while (tryMatchingTokenKinds({TokenKind::COMMA}, true, true));
tryMatchingTokenKinds({TokenKind::NEW_LINE}, true, true); // optional new line
if (!tryMatchingTokenKinds({TokenKind::RIGHT_PAREN}, true, true)) {
markError(TokenKind::RIGHT_PAREN, {});
return nullptr;
}
}
return make_shared<ExpressionCall>(identifierToken->getLexme(), argumentExpressions);
@@ -653,5 +655,5 @@ void Parser::markError(optional<TokenKind> expectedTokenKind, optional<string> m
while (!tryMatchingTokenKinds(safeKinds, false, true))
currentIndex++;
errors.push_back(make_shared<Error>(actualToken, expectedTokenKind, message));
errors.push_back(Error::parserError(actualToken, expectedTokenKind, message));
}

View File

@@ -14,9 +14,9 @@ using namespace std;
class Parser {
private:
vector<shared_ptr<Error>> errors;
vector<shared_ptr<Token>> tokens;
int currentIndex = 0;
vector<shared_ptr<Error>> errors;
shared_ptr<Statement> nextStatement();
shared_ptr<Statement> nextInBlockStatement();