further progress

This commit is contained in:
Ludwig Lehnert 2024-12-24 17:46:47 +01:00
parent c23fc2ee86
commit ceab70ed84
23 changed files with 1019 additions and 41 deletions

View File

@ -2,6 +2,7 @@
"C_Cpp.default.includePath": [ "C_Cpp.default.includePath": [
"${workspaceFolder}/include/", "${workspaceFolder}/include/",
"${workspaceFolder}/build/antlr4_runtime/src/antlr4_runtime/runtime/Cpp/runtime/src", "${workspaceFolder}/build/antlr4_runtime/src/antlr4_runtime/runtime/Cpp/runtime/src",
"/usr/lib64/llvm18/include",
], ],
"files.associations": { "files.associations": {
"*.embeddedhtml": "html", "*.embeddedhtml": "html",

View File

@ -44,7 +44,9 @@ add_dependencies(plsm antlr4_static)
target_include_directories(plsm PRIVATE ${CMAKE_SOURCE_DIR}/include) target_include_directories(plsm PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_include_directories(plsm PRIVATE ${ANTLR4_INCLUDE_DIRS}) target_include_directories(plsm PRIVATE ${ANTLR4_INCLUDE_DIRS})
target_link_directories(plsm PRIVATE ${ANTLR4_OUTPUT_DIR}) target_link_directories(plsm PRIVATE ${ANTLR4_OUTPUT_DIR})
target_link_libraries(plsm PRIVATE antlr4-runtime Boost::json) target_link_directories(plsm PRIVATE /usr/lib64/llvm18/lib)
target_include_directories(plsm PRIVATE /usr/lib64/llvm18/include)
target_link_libraries(plsm PRIVATE antlr4-runtime Boost::json LLVM-18)
add_custom_target(clean-all add_custom_target(clean-all
COMMAND ${CMAKE_COMMAND} -E rm -rf COMMAND ${CMAKE_COMMAND} -E rm -rf

View File

@ -18,6 +18,7 @@
#include "Stmt/ExprStmt.h" #include "Stmt/ExprStmt.h"
#include "Stmt/FnDecl.h" #include "Stmt/FnDecl.h"
#include "Stmt/IfStmt.h" #include "Stmt/IfStmt.h"
#include "Stmt/InlineAsm.h"
#include "Stmt/RetStmt.h" #include "Stmt/RetStmt.h"
#include "Stmt/VarDecl.h" #include "Stmt/VarDecl.h"
#include "Stmt/WhileStmt.h" #include "Stmt/WhileStmt.h"

View File

@ -65,6 +65,9 @@ public:
virtual std::any visit(FnParam &fnParam, std::any param) = 0; virtual std::any visit(FnParam &fnParam, std::any param) = 0;
virtual std::any visit(FnDecl &fnDecl, std::any param) = 0; virtual std::any visit(FnDecl &fnDecl, std::any param) = 0;
virtual std::any visit(IfStmt &ifStmt, std::any param) = 0; virtual std::any visit(IfStmt &ifStmt, std::any param) = 0;
virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint,
std::any param) = 0;
virtual std::any visit(InlineAsm &inlineAsm, std::any param) = 0;
virtual std::any visit(RetStmt &retStmt, std::any param) = 0; virtual std::any visit(RetStmt &retStmt, std::any param) = 0;
virtual std::any visit(VarDecl &varDecl, std::any param) = 0; virtual std::any visit(VarDecl &varDecl, std::any param) = 0;
virtual std::any visit(WhileStmt &whileStmt, std::any param) = 0; virtual std::any visit(WhileStmt &whileStmt, std::any param) = 0;
@ -227,6 +230,8 @@ public:
static Stmt *fromJson(boost::json::value json); static Stmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const = 0;
virtual bool isStmt() const override { return true; } virtual bool isStmt() const override { return true; }
}; };

View File

@ -142,6 +142,25 @@ public:
return std::any(); return std::any();
} }
virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint,
std::any param) override {
return std::any();
}
virtual std::any visit(InlineAsm &inlineAsm, std::any param) override {
for (auto &c : inlineAsm.outputs) {
if (c.get())
c->accept(this, param);
}
for (auto &c : inlineAsm.inputs) {
if (c.get())
c->accept(this, param);
}
return std::any();
}
virtual std::any visit(RetStmt &retStmt, std::any param) override { virtual std::any visit(RetStmt &retStmt, std::any param) override {
if (retStmt.value.get()) if (retStmt.value.get())
retStmt.value->accept(this, param); retStmt.value->accept(this, param);

View File

@ -17,6 +17,8 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static AssignStmt *fromJson(boost::json::value json); static AssignStmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const override { return false; }
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -15,6 +15,13 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static Block *fromJson(boost::json::value json); static Block *fromJson(boost::json::value json);
bool alywasReturns() const {
for (auto &stmt : stmts)
if (stmt.get() && stmt->alywasReturns())
return true;
return false;
}
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -15,6 +15,8 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static ExprStmt *fromJson(boost::json::value json); static ExprStmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const override { return false; }
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -44,6 +44,10 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static FnDecl *fromJson(boost::json::value json); static FnDecl *fromJson(boost::json::value json);
virtual bool alywasReturns() const override {
throw std::runtime_error("should not call FnDecl::alywasReturns");
}
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -19,6 +19,11 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static IfStmt *fromJson(boost::json::value json); static IfStmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const override {
return (ifBody.get() && ifBody->alywasReturns()) ||
(elseBody.get() && elseBody->alywasReturns());
}
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -0,0 +1,46 @@
#include "AST/Base.h"
namespace plsm {
namespace ast {
class InlineAsmConstraint : public ASTNode {
public:
std::string constraint, variable;
InlineAsmConstraint(LOC_ARG, std::string constraint, std::string variable)
: ASTNode(sourceRange), constraint(constraint), variable(variable) {}
virtual boost::json::value toJson() const override;
static InlineAsmConstraint *fromJson(boost::json::value json);
virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param);
}
};
class InlineAsm : public Stmt {
public:
std::string code;
std::vector<std::unique_ptr<InlineAsmConstraint>> outputs;
std::vector<std::unique_ptr<InlineAsmConstraint>> inputs;
std::vector<std::string> clobbers;
InlineAsm(LOC_ARG, std::string code,
std::vector<std::unique_ptr<InlineAsmConstraint>> outputs,
std::vector<std::unique_ptr<InlineAsmConstraint>> inputs,
std::vector<std::string> clobbers)
: Stmt(sourceRange), code(code), outputs(std::move(outputs)),
inputs(std::move(inputs)), clobbers(clobbers) {}
virtual boost::json::value toJson() const override;
static InlineAsm *fromJson(boost::json::value json);
virtual bool alywasReturns() const override { return false; }
virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param);
}
};
} // namespace ast
} // namespace plsm

View File

@ -15,6 +15,8 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static RetStmt *fromJson(boost::json::value json); static RetStmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const override { return true; }
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -20,6 +20,8 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static VarDecl *fromJson(boost::json::value json); static VarDecl *fromJson(boost::json::value json);
virtual bool alywasReturns() const override { return false; }
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -18,6 +18,10 @@ public:
virtual boost::json::value toJson() const override; virtual boost::json::value toJson() const override;
static WhileStmt *fromJson(boost::json::value json); static WhileStmt *fromJson(boost::json::value json);
virtual bool alywasReturns() const override {
return (body.get() && body->alywasReturns());
}
virtual std::any accept(ASTVisitor *visitor, std::any param) override { virtual std::any accept(ASTVisitor *visitor, std::any param) override {
return visitor->visit(*this, param); return visitor->visit(*this, param);
} }

View File

@ -0,0 +1,10 @@
#pragma once
#include "AST/AST.h"
namespace plsm {
void compileModule(std::unique_ptr<ast::Module> &module,
const std::string &filename);
} // namespace plsm

View File

@ -4,6 +4,7 @@ grammar plsm;
#include "AST/AST.h" #include "AST/AST.h"
#include "Utils.h" #include "Utils.h"
#include <memory> #include <memory>
#include <boost/json.hpp>
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
using namespace plsm::utils; using namespace plsm::utils;
@ -95,6 +96,49 @@ stmt
} }
| whileStmt { | whileStmt {
$ast = ptrcast<Stmt>($ctx->whileStmt()->ast); $ast = ptrcast<Stmt>($ctx->whileStmt()->ast);
}
| inlineAsm {
$ast = ptrcast<Stmt>($ctx->inlineAsm()->ast);
};
inlineAsm
returns[std::unique_ptr<InlineAsm> ast]:
'inline' 'asm' '(' inlineAsmCode = string (
':' outputs += inlineAsmConstraint (
',' outputs += inlineAsmConstraint
)*
)? (
':' inputs += inlineAsmConstraint (
',' inputs += inlineAsmConstraint
)*
)? (':' clobbers += string ( ',' clobbers += string)*)? ')' ';' {
auto code = $ctx->inlineAsmCode->value;
std::vector<std::unique_ptr<InlineAsmConstraint>> outputs;
for (auto &output : $ctx->outputs) {
outputs.push_back(std::move(output->ast));
}
std::vector<std::unique_ptr<InlineAsmConstraint>> inputs;
for (auto &input : $ctx->inputs) {
inputs.push_back(std::move(input->ast));
}
std::vector<std::string> clobbers;
for (auto &clobber : $ctx->clobbers) {
clobbers.push_back(clobber->value);
}
$ast = std::make_unique<InlineAsm>(
getSourceRange($ctx), code, std::move(outputs), std::move(inputs), clobbers);
};
inlineAsmConstraint
returns[std::unique_ptr<InlineAsmConstraint> ast]:
string '(' IDENTIFIER ')' {
auto constraint = $ctx->string()->value;
auto variable = $ctx->IDENTIFIER()->getText();
$ast = std::make_unique<InlineAsmConstraint>(getSourceRange($ctx), constraint, variable);
}; };
whileStmt whileStmt
@ -168,7 +212,7 @@ fnDecl
fnParam fnParam
returns[std::unique_ptr<FnParam> ast]: returns[std::unique_ptr<FnParam> ast]:
IDENTIFIER typeName { IDENTIFIER ':' typeName {
$ast = std::make_unique<FnParam>(getSourceRange($ctx), $ctx->IDENTIFIER()->getText(), std::move($ctx->typeName()->ast)); $ast = std::make_unique<FnParam>(getSourceRange($ctx), $ctx->IDENTIFIER()->getText(), std::move($ctx->typeName()->ast));
}; };
@ -216,10 +260,6 @@ binaryExpr
auto binExpr = std::make_unique<BinExpr>(getSourceRange($ctx), op, std::move($ctx->lhs->ast), std::move($ctx->rhs->ast)); auto binExpr = std::make_unique<BinExpr>(getSourceRange($ctx), op, std::move($ctx->lhs->ast), std::move($ctx->rhs->ast));
$ast = ptrcast<Expr>(binExpr); $ast = ptrcast<Expr>(binExpr);
}
| operand = binaryExpr 'as' typeName {
auto castExpr = std::make_unique<CastExpr>(getSourceRange($ctx), std::move($ctx->operand->ast), std::move($ctx->typeName()->ast));
$ast = ptrcast<Expr>(castExpr);
} }
| lhs = binaryExpr op = ( | lhs = binaryExpr op = (
'==' '=='
@ -258,6 +298,10 @@ unaryExpr
} }
| functionCall { | functionCall {
$ast = ptrcast<Expr>($ctx->functionCall()->ast); $ast = ptrcast<Expr>($ctx->functionCall()->ast);
}
| unaryExpr 'as' typeName {
auto castExpr = std::make_unique<CastExpr>(getSourceRange($ctx), std::move($ctx->unaryExpr()->ast), std::move($ctx->typeName()->ast));
$ast = ptrcast<Expr>(castExpr);
} }
| '!' unaryExpr { | '!' unaryExpr {
auto unExpr = std::make_unique<UnExpr>(getSourceRange($ctx), UnOp::NOT, std::move($ctx->unaryExpr()->ast)); auto unExpr = std::make_unique<UnExpr>(getSourceRange($ctx), UnOp::NOT, std::move($ctx->unaryExpr()->ast));
@ -314,6 +358,9 @@ factorExpr
auto text = $ctx->BOOL()->getText(); auto text = $ctx->BOOL()->getText();
auto val = std::make_unique<IntValue>(getSourceRange($ctx), text == "true" ? 1 : 0); auto val = std::make_unique<IntValue>(getSourceRange($ctx), text == "true" ? 1 : 0);
$ast = ptrcast<Expr>(val); $ast = ptrcast<Expr>(val);
}
| lambdaExpr {
$ast = ptrcast<Expr>($ctx->lambdaExpr()->ast);
} }
| '(' expr ')' { | '(' expr ')' {
$ast = std::move($ctx->expr()->ast); $ast = std::move($ctx->expr()->ast);
@ -330,6 +377,17 @@ functionCall
$ast = std::make_unique<CallExpr>(getSourceRange($ctx), std::move($ctx->callee->ast), std::move(args)); $ast = std::make_unique<CallExpr>(getSourceRange($ctx), std::move($ctx->callee->ast), std::move(args));
}; };
lambdaExpr
returns[std::unique_ptr<LambdaExpr> ast]:
'@' '(' (params += fnParam (',' params += fnParam)*)? ')' typeName '{' block '}' {
std::vector<std::unique_ptr<FnParam>> params;
for (auto &param : $ctx->params) {
params.push_back(std::move(param->ast));
}
$ast = std::make_unique<LambdaExpr>(getSourceRange($ctx), std::move(params), std::move($ctx->typeName()->ast), std::move($ctx->block()->ast));
};
typeName typeName
returns[std::unique_ptr<TypeName> ast]: returns[std::unique_ptr<TypeName> ast]:
IDENTIFIER { IDENTIFIER {
@ -338,6 +396,18 @@ typeName
$ast = ptrcast<TypeName>(named); $ast = ptrcast<TypeName>(named);
}; };
string
returns[std::string value]:
STRING {
auto encoded = $ctx->STRING()->getText();
auto decoded = boost::json::parse(encoded);
$value = decoded.as_string();
};
STRING: '"' (ESC | ~["\\\r\n])* '"';
fragment ESC: '\\' (["\\/bfnrt] | 'u' HEX HEX HEX HEX);
fragment HEX: [0-9a-fA-F];
INT: [0-9]+ | '0x' [0-9a-fA-F]+ | '0o' [0-7]+ | '0b' [01]+; INT: [0-9]+ | '0x' [0-9a-fA-F]+ | '0o' [0-7]+ | '0b' [01]+;
FLOAT: [0-9]+ '.' | [0-9]* '.' [0-9]+; FLOAT: [0-9]+ '.' | [0-9]* '.' [0-9]+;
BOOL: 'true' | 'false'; BOOL: 'true' | 'false';
@ -345,3 +415,4 @@ BOOL: 'true' | 'false';
IDENTIFIER: [a-zA-Z_] [a-zA-Z0-9_]*; IDENTIFIER: [a-zA-Z_] [a-zA-Z0-9_]*;
WHITESPACE: [ \r\n\t]+ -> skip; WHITESPACE: [ \r\n\t]+ -> skip;
COMMENT: (('//' ~( '\r' | '\n')*) | ('/*' .*? '*/')) -> skip;

View File

@ -0,0 +1,45 @@
#include "AST/AST.h"
#include "Utils.h"
namespace plsm {
namespace ast {
boost::json::value InlineAsmConstraint::toJson() const {
return {
{"@type", "InlineAsmConstraint"},
{"constraint", constraint},
{"variable", variable},
};
}
InlineAsmConstraint *InlineAsmConstraint::fromJson(boost::json::value json) {
auto constraint =
getJsonValue<InlineAsmConstraint, std::string>(json, "constraint");
auto variable =
getJsonValue<InlineAsmConstraint, std::string>(json, "variable");
return new InlineAsmConstraint(SourceRange::json(), constraint, variable);
}
boost::json::value InlineAsm::toJson() const {
return {
{"@type", "InlineAsm"},
{"code", code},
{"outputs", utils::mapToJson(outputs)},
{"inputs", utils::mapToJson(inputs)},
{"clobbers", clobbers},
};
}
InlineAsm *InlineAsm::fromJson(boost::json::value json) {
auto name = getJsonValue<InlineAsm, std::string>(json, "name");
auto outputs =
fromJsonVector<InlineAsm, InlineAsmConstraint>(json, "outputs");
auto inputs = fromJsonVector<InlineAsm, InlineAsmConstraint>(json, "inputs");
auto clobbers =
getJsonValue<InlineAsm, std::vector<std::string>>(json, "clobbers");
return new InlineAsm(SourceRange::json(), name, std::move(outputs),
std::move(inputs), clobbers);
}
} // namespace ast
} // namespace plsm

View File

@ -0,0 +1,567 @@
#include "AST/BaseASTVisitor.h"
#include "Compile.h"
#include "Utils.h"
#include <map>
#include <vector>
#include <llvm/ADT/Hashing.h>
#include <llvm/CodeGen/MachineModuleInfo.h>
#include <llvm/CodeGen/Passes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
namespace plsm {
using namespace ast;
namespace {
static llvm::Type *getLLVMType(llvm::LLVMContext &ctx,
const std::shared_ptr<Type> &type) {
if (utils::is<PrimitiveType>(type.get())) {
auto primitiveType = (PrimitiveType *)type.get();
if (primitiveType->name == "i8" || primitiveType->name == "u8")
return llvm::Type::getInt8Ty(ctx);
if (primitiveType->name == "i16" || primitiveType->name == "u16")
return llvm::Type::getInt16Ty(ctx);
if (primitiveType->name == "i32" || primitiveType->name == "u32")
return llvm::Type::getInt32Ty(ctx);
if (primitiveType->name == "i64" || primitiveType->name == "u64")
return llvm::Type::getInt64Ty(ctx);
if (primitiveType->name == "i128" || primitiveType->name == "u128")
return llvm::Type::getInt128Ty(ctx);
if (primitiveType->name == "float")
return llvm::Type::getFloatTy(ctx);
if (primitiveType->name == "double")
return llvm::Type::getDoubleTy(ctx);
}
else if (utils::is<FunctionType>(type.get())) {
auto functionType = (FunctionType *)type.get();
auto returnType = getLLVMType(ctx, functionType->returnType);
std::vector<llvm::Type *> llvmParams;
llvmParams.push_back(llvm::IntegerType::get(ctx, 8)->getPointerTo());
for (auto &paramType : functionType->paramTypes)
llvmParams.push_back(getLLVMType(ctx, paramType));
return llvm::FunctionType::get(returnType, llvmParams, false);
}
throw std::runtime_error("cannot determine llvm type");
return nullptr;
}
class IRGenerator1 : public BaseASTVisitor {
llvm::LLVMContext &ctx;
llvm::Module &mod;
llvm::IRBuilder<> &builder;
std::map<Symbol *, llvm::Value *> &symbolMap;
std::set<llvm::Value *> &functions;
public:
IRGenerator1(llvm::LLVMContext &ctx, llvm::Module &mod,
llvm::IRBuilder<> &builder,
std::map<Symbol *, llvm::Value *> &symbolMap,
std::set<llvm::Value *> &functions)
: ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap),
functions(functions) {}
virtual std::any visit(VarDecl &varDecl, std::any param) override {
auto llvmType = getLLVMType(ctx, varDecl.symbol->type);
// auto global = new llvm::GlobalVariable(mod, llvmType, false,
// llvm::GlobalValue::ExternalLinkage,
// nullptr, varDecl.name);
auto global = mod.getOrInsertGlobal(varDecl.name, llvmType);
symbolMap[varDecl.symbol.get()] = global;
return std::any();
}
virtual std::any visit(FnDecl &fnDecl, std::any param) override {
auto functionType =
(llvm::FunctionType *)getLLVMType(ctx, fnDecl.symbol->type);
mod.getOrInsertFunction(fnDecl.name, functionType);
auto function = mod.getFunction(fnDecl.name);
symbolMap[fnDecl.symbol.get()] = function;
functions.insert(function);
return std::any();
}
};
class IRGenerator2 : public BaseASTVisitor {
llvm::LLVMContext &ctx;
llvm::Module &mod;
llvm::IRBuilder<> &builder;
std::map<Symbol *, llvm::Value *> &symbolMap;
std::set<llvm::Value *> &functions;
std::set<llvm::Value *> lambdas;
llvm::Value *retStore = nullptr;
llvm::BasicBlock *retBlock = nullptr;
llvm::Value *rvalueForLValue = nullptr;
size_t labelCounter = 0;
std::string createLabel() { return "L" + std::to_string(labelCounter++); }
llvm::Value *wrapCallee(llvm::Value *callee) {
auto ptr = llvm::IntegerType::get(ctx, 8)->getPointerTo();
auto structType = llvm::StructType::get(ptr, callee->getType());
auto store = builder.CreateAlloca(structType);
auto ep = builder.CreateStructGEP(structType, store, 0);
builder.CreateStore(llvm::ConstantPointerNull::get(ptr), ep);
ep = builder.CreateStructGEP(structType, store, 1);
builder.CreateStore(callee, ep);
return store;
// (f() -> asdf) -> { context, f() -> asdf }
}
public:
IRGenerator2(llvm::LLVMContext &ctx, llvm::Module &mod,
llvm::IRBuilder<> &builder,
std::map<Symbol *, llvm::Value *> &symbolMap,
std::set<llvm::Value *> &functions)
: ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap),
functions(functions) {}
virtual std::any visit(FnDecl &fnDecl, std::any param) override {
auto function = mod.getFunction(fnDecl.name);
auto block = llvm::BasicBlock::Create(ctx, createLabel(), function);
builder.SetInsertPoint(block);
size_t i = 0;
for (auto &arg : function->args()) {
if (i > 0) {
auto store = builder.CreateAlloca(arg.getType(), nullptr);
builder.CreateStore(&arg, store);
symbolMap[fnDecl.params[i - 1]->symbol.get()] = store;
}
i += 1;
}
auto fnType = (llvm::FunctionType *)getLLVMType(ctx, fnDecl.symbol->type);
retStore = builder.CreateAlloca(fnType->getReturnType());
retBlock = llvm::BasicBlock::Create(ctx, createLabel(), function);
BaseASTVisitor::visit(fnDecl, param);
if (!fnDecl.body->alywasReturns())
builder.CreateBr(retBlock);
builder.SetInsertPoint(retBlock);
auto retVal = builder.CreateLoad(fnType->getReturnType(), retStore);
builder.CreateRet(retVal);
return std::any();
}
virtual std::any visit(VarDecl &varDecl, std::any param) override {
auto store = builder.CreateAlloca(getLLVMType(ctx, varDecl.symbol->type));
symbolMap[varDecl.symbol.get()] = store;
return std::any();
}
virtual std::any visit(BinExpr &binExpr, std::any param) override {
auto lhs = std::any_cast<llvm::Value *>(binExpr.lhs->accept(this, param));
auto rhs = std::any_cast<llvm::Value *>(binExpr.rhs->accept(this, param));
auto primitiveType = (PrimitiveType *)binExpr.type.get();
auto name = primitiveType->name;
auto isFloat = name == "float" || name == "double";
auto isUnsigned = name[0] == 'u';
switch (binExpr.op) {
case BinOp::ADD:
if (isFloat)
return builder.CreateFAdd(lhs, rhs);
return builder.CreateAdd(lhs, rhs);
case BinOp::SUB:
if (isFloat)
return builder.CreateFSub(lhs, rhs);
return builder.CreateSub(lhs, rhs);
case BinOp::MUL:
if (isFloat)
return builder.CreateFMul(lhs, rhs);
return builder.CreateMul(lhs, rhs);
case BinOp::DIV:
if (isFloat)
return builder.CreateFDiv(lhs, rhs);
if (isUnsigned)
return builder.CreateUDiv(lhs, rhs);
return builder.CreateSDiv(lhs, rhs);
case BinOp::MOD:
if (isUnsigned)
return builder.CreateURem(lhs, rhs);
return builder.CreateSRem(lhs, rhs);
case BinOp::EQ:
return builder.CreateICmpEQ(lhs, rhs);
case BinOp::NE:
return builder.CreateICmpNE(lhs, rhs);
case BinOp::LT:
if (isFloat)
return builder.CreateFCmpOGT(lhs, rhs);
if (isUnsigned)
return builder.CreateICmpULT(lhs, rhs);
return builder.CreateICmpSLT(lhs, rhs);
case BinOp::GT:
if (isFloat)
return builder.CreateFCmpOGT(lhs, rhs);
if (isUnsigned)
return builder.CreateICmpUGT(lhs, rhs);
return builder.CreateICmpSGT(lhs, rhs);
case BinOp::LE:
if (isFloat)
return builder.CreateFCmpOLE(lhs, rhs);
if (isUnsigned)
return builder.CreateICmpULE(lhs, rhs);
return builder.CreateICmpSLE(lhs, rhs);
case BinOp::GE:
if (isFloat)
return builder.CreateFCmpOGE(lhs, rhs);
if (isUnsigned)
return builder.CreateICmpUGE(lhs, rhs);
return builder.CreateICmpSGE(lhs, rhs);
case BinOp::AND:
return builder.CreateAnd(lhs, rhs);
case BinOp::OR:
return builder.CreateOr(lhs, rhs);
}
throw std::runtime_error("binop not implemented");
}
virtual std::any visit(UnExpr &unExpr, std::any param) override {
auto expr = std::any_cast<llvm::Value *>(unExpr.expr->accept(this, param));
switch (unExpr.op) {
case UnOp::NEG:
return builder.CreateNeg(expr);
case UnOp::NOT:
return builder.CreateNot(expr);
case UnOp::POS:
return expr;
}
throw std::runtime_error("unop not implemented");
}
virtual std::any visit(CastExpr &castExpr, std::any param) override {
auto value =
std::any_cast<llvm::Value *>(castExpr.value->accept(this, param));
if (utils::is<PrimitiveType>(castExpr.value->type.get()) &&
utils::is<PrimitiveType>(castExpr.type.get())) {
auto primitiveType = (PrimitiveType *)castExpr.value->type.get();
auto newPrimitiveType = (PrimitiveType *)castExpr.type.get();
auto newType = getLLVMType(ctx, castExpr.type);
auto wasFloat =
primitiveType->name == "float" || primitiveType->name == "double";
auto wasUnsigned = primitiveType->name[0] == 'u';
auto willFloat = newPrimitiveType->name == "float" ||
newPrimitiveType->name == "double";
auto willUnsigned = newPrimitiveType->name[0] == 'u';
if (wasFloat) {
if (willFloat) {
if (primitiveType->name == "double")
return builder.CreateFPExt(value, newType);
else
return builder.CreateFPTrunc(value, newType);
}
else {
if (willUnsigned)
return builder.CreateFPToUI(value, newType);
else
return builder.CreateFPToSI(value, newType);
}
}
else {
if (willFloat) {
if (wasUnsigned)
return builder.CreateUIToFP(value, newType);
else
return builder.CreateSIToFP(value, newType);
}
if (willUnsigned)
return builder.CreateZExtOrTrunc(value, newType);
else
return builder.CreateSExtOrTrunc(value, newType);
}
}
throw std::runtime_error("cast not implemented");
}
virtual std::any visit(CallExpr &callExpr, std::any param) override {
auto callee =
std::any_cast<llvm::Value *>(callExpr.callee->accept(this, param));
auto ptrType = llvm::IntegerType::get(ctx, 8)->getPointerTo();
auto structType = llvm::StructType::get(ptrType, ptrType);
auto ep = builder.CreateStructGEP(structType, callee, 0);
auto callCtx = builder.CreateLoad(ptrType, ep);
ep = builder.CreateStructGEP(structType, callee, 1);
auto realCallee = (llvm::Value *)builder.CreateLoad(ptrType, ep);
// realCallee = builder.CreatePointerCast(realCallee, calleeType);
std::vector<llvm::Value *> llvmArgs;
llvmArgs.push_back(callCtx);
for (auto &arg : callExpr.args) {
llvmArgs.push_back(
std::any_cast<llvm::Value *>(arg->accept(this, param)));
}
auto calleeType =
(llvm::FunctionType *)getLLVMType(ctx, callExpr.callee->type);
return (llvm::Value *)builder.CreateCall(calleeType, realCallee, llvmArgs);
}
virtual std::any visit(Identifier &identifier, std::any param) override {
auto value = symbolMap[identifier.symbol.get()];
if (rvalueForLValue) {
builder.CreateStore(rvalueForLValue, value);
return nullptr;
}
else {
if (functions.count(value))
return wrapCallee(value);
if (utils::is<FunctionType>(identifier.type.get()))
return value;
auto loadType = getLLVMType(ctx, identifier.type);
return (llvm::Value *)builder.CreateLoad(loadType, value);
}
}
virtual std::any visit(IntValue &intValue, std::any param) override {
return (llvm::Value *)llvm::ConstantInt::get(
getLLVMType(ctx, intValue.type), intValue.value);
}
virtual std::any visit(FloatValue &floatValue, std::any param) override {
return (llvm::Value *)llvm::ConstantFP::get(
getLLVMType(ctx, floatValue.type), floatValue.value);
}
virtual std::any visit(RetStmt &retStmt, std::any param) override {
auto value =
std::any_cast<llvm::Value *>(retStmt.value->accept(this, param));
builder.CreateStore(value, retStore);
builder.CreateBr(retBlock);
return std::any();
}
virtual std::any visit(IfStmt &ifStmt, std::any param) override {
auto cond =
std::any_cast<llvm::Value *>(ifStmt.condition->accept(this, param));
auto fn = builder.GetInsertBlock()->getParent();
auto ifBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
auto elseBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
llvm::BasicBlock *endBlock = nullptr;
if (!ifStmt.alywasReturns())
endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
builder.CreateCondBr(cond, ifBlock, elseBlock);
builder.SetInsertPoint(ifBlock);
ifStmt.ifBody->accept(this, param);
if (endBlock && !ifStmt.ifBody->alywasReturns())
builder.CreateBr(endBlock);
builder.SetInsertPoint(elseBlock);
ifStmt.elseBody->accept(this, param);
if (endBlock && !ifStmt.elseBody->alywasReturns())
builder.CreateBr(endBlock);
if (endBlock && !ifStmt.alywasReturns())
builder.SetInsertPoint(endBlock);
return std::any();
}
virtual std::any visit(WhileStmt &whileStmt, std::any param) override {
auto fn = builder.GetInsertBlock()->getParent();
auto condBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
auto whileBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
auto endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
builder.CreateBr(condBlock);
builder.SetInsertPoint(condBlock);
auto cond =
std::any_cast<llvm::Value *>(whileStmt.condition->accept(this, param));
builder.CreateCondBr(cond, whileBlock, endBlock);
builder.SetInsertPoint(whileBlock);
whileStmt.body->accept(this, param);
if (!whileStmt.body->alywasReturns())
builder.CreateBr(condBlock);
builder.SetInsertPoint(endBlock);
return std::any();
}
virtual std::any visit(AssignStmt &assignStmt, std::any param) override {
auto rvalue =
std::any_cast<llvm::Value *>(assignStmt.rval->accept(this, param));
rvalueForLValue = rvalue;
auto lvalue = assignStmt.lval->accept(this, param);
rvalueForLValue = nullptr;
return std::any();
}
};
static void runMPM(llvm::Module &mod) {
llvm::PassBuilder passBuilder;
llvm::ModuleAnalysisManager mam;
llvm::CGSCCAnalysisManager gam;
llvm::FunctionAnalysisManager fam;
llvm::LoopAnalysisManager lam;
llvm::ModulePassManager mpm;
passBuilder.registerModuleAnalyses(mam);
passBuilder.registerCGSCCAnalyses(gam);
passBuilder.registerFunctionAnalyses(fam);
passBuilder.registerLoopAnalyses(lam);
passBuilder.crossRegisterProxies(lam, fam, gam, mam);
mpm = passBuilder.buildModuleOptimizationPipeline(
llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None);
mpm.run(mod, mam);
mam.clear();
gam.clear();
fam.clear();
lam.clear();
}
static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod,
llvm::IRBuilder<> &builder,
const std::string &outfile) {
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
auto target = LLVM_DEFAULT_TARGET_TRIPLE;
mod.setTargetTriple(target);
std::string err;
const llvm::Target *t = llvm::TargetRegistry::lookupTarget(target, err);
if (!t)
throw std::runtime_error(err);
llvm::TargetMachine *targetMachine =
t->createTargetMachine(target, "", "", llvm::TargetOptions(),
// TODO: make configurable
llvm::Reloc::PIC_);
mod.setDataLayout(targetMachine->createDataLayout());
std::error_code ec;
llvm::raw_fd_ostream dest(outfile, ec, llvm::sys::fs::OF_None);
if (ec)
throw std::runtime_error(ec.message());
llvm::legacy::PassManager pm;
auto &tm = (llvm::LLVMTargetMachine &)*targetMachine;
pm.add(new llvm::TargetLibraryInfoWrapperPass());
pm.add(new llvm::MachineModuleInfoWrapperPass(&tm));
bool objResult = targetMachine->addPassesToEmitFile(
pm, dest, nullptr, llvm::CodeGenFileType::ObjectFile);
if (objResult)
throw std::runtime_error("failed to produce " + outfile);
pm.run(mod);
dest.flush();
}
} // namespace
void compileModule(std::unique_ptr<ast::Module> &module,
const std::string &filename) {
auto moduleId = filename;
llvm::LLVMContext ctx;
llvm::Module mod(moduleId, ctx);
llvm::IRBuilder<> builder(ctx);
std::map<Symbol *, llvm::Value *> symbolMap;
std::set<llvm::Value *> functions;
IRGenerator1 generator1(ctx, mod, builder, symbolMap, functions);
module->accept(&generator1, nullptr);
IRGenerator2 generator2(ctx, mod, builder, symbolMap, functions);
for (auto &stmt : module->stmts) {
if (utils::is<FnDecl>(stmt.get()))
stmt->accept(&generator2, nullptr);
}
if (llvm::verifyModule(mod, &llvm::errs())) {
mod.print(llvm::outs(), nullptr);
llvm::outs().flush();
throw std::runtime_error("Module verification failed");
}
runMPM(mod); // info: does not work, programs will malfunction
mod.print(llvm::outs(), nullptr);
llvm::outs().flush();
// std::cout << "----------------------------------------------" << std::endl;
// mod.print(llvm::outs(), nullptr);
// llvm::outs().flush();
writeToFile(ctx, mod, builder, filename + ".o");
}
} // namespace plsm

View File

@ -21,6 +21,12 @@ public:
if (!fnDecl.name.size()) if (!fnDecl.name.size())
return std::any(); return std::any();
if (scopes->back().count(fnDecl.name)) {
errors::put(
fnDecl.error("redeclaration of global symbol '" + fnDecl.name + "'"));
return std::any();
}
auto symbol = std::make_shared<ast::Symbol>(fnDecl.name); auto symbol = std::make_shared<ast::Symbol>(fnDecl.name);
fnDecl.symbol = symbol; fnDecl.symbol = symbol;
scopes->back()[fnDecl.name] = symbol; scopes->back()[fnDecl.name] = symbol;
@ -32,6 +38,12 @@ public:
if (!varDecl.name.size()) if (!varDecl.name.size())
return std::any(); return std::any();
if (scopes->back().count(varDecl.name)) {
errors::put(varDecl.error("redeclaration of global symbol '" +
varDecl.name + "'"));
return std::any();
}
auto symbol = std::make_shared<ast::Symbol>(varDecl.name); auto symbol = std::make_shared<ast::Symbol>(varDecl.name);
varDecl.symbol = symbol; varDecl.symbol = symbol;
scopes->back()[varDecl.name] = symbol; scopes->back()[varDecl.name] = symbol;
@ -137,9 +149,7 @@ void performNameAnalysis(std::unique_ptr<ast::Module> &module) {
NameAnalysisVisitor1 visitor1(&scopes); NameAnalysisVisitor1 visitor1(&scopes);
NameAnalysisVisitor2 visitor2(&scopes); NameAnalysisVisitor2 visitor2(&scopes);
for (auto &stmt : module->stmts) { module->accept(&visitor1, nullptr);
stmt->accept(&visitor1, nullptr);
}
for (auto &stmt : module->stmts) { for (auto &stmt : module->stmts) {
if (utils::is<ast::FnDecl>(stmt.get())) if (utils::is<ast::FnDecl>(stmt.get()))

View File

@ -21,10 +21,12 @@ static std::map<std::string, std::shared_ptr<PrimitiveType>> primitiveTypes = {
{"i16", std::make_shared<PrimitiveType>("i16")}, {"i16", std::make_shared<PrimitiveType>("i16")},
{"i32", std::make_shared<PrimitiveType>("i32")}, {"i32", std::make_shared<PrimitiveType>("i32")},
{"i64", std::make_shared<PrimitiveType>("i64")}, {"i64", std::make_shared<PrimitiveType>("i64")},
{"i128", std::make_shared<PrimitiveType>("i128")},
{"u8", std::make_shared<PrimitiveType>("u8")}, {"u8", std::make_shared<PrimitiveType>("u8")},
{"u16", std::make_shared<PrimitiveType>("u16")}, {"u16", std::make_shared<PrimitiveType>("u16")},
{"u32", std::make_shared<PrimitiveType>("u32")}, {"u32", std::make_shared<PrimitiveType>("u32")},
{"u64", std::make_shared<PrimitiveType>("u64")}, {"u64", std::make_shared<PrimitiveType>("u64")},
{"u128", std::make_shared<PrimitiveType>("u128")},
{"float", std::make_shared<PrimitiveType>("float")}, {"float", std::make_shared<PrimitiveType>("float")},
{"double", std::make_shared<PrimitiveType>("double")}, {"double", std::make_shared<PrimitiveType>("double")},
}; };
@ -48,6 +50,9 @@ static std::shared_ptr<Type> resolveTypeName(const TypeName *typeName) {
static void castTo(std::unique_ptr<Expr> &expr, static void castTo(std::unique_ptr<Expr> &expr,
const std::shared_ptr<Type> &type) { const std::shared_ptr<Type> &type) {
if (*expr->type == *type)
return;
auto cast = new CastExpr(expr->sourceRange, std::move(expr), auto cast = new CastExpr(expr->sourceRange, std::move(expr),
std::unique_ptr<TypeName>(type->toTypeName())); std::unique_ptr<TypeName>(type->toTypeName()));
cast->type = type; cast->type = type;
@ -68,14 +73,17 @@ static bool tryAssignTo(std::unique_ptr<Expr> &from,
const PrimitiveType *toT = (PrimitiveType *)toType.get(); const PrimitiveType *toT = (PrimitiveType *)toType.get();
std::map<std::string, std::vector<std::string>> castMatrix = { std::map<std::string, std::vector<std::string>> castMatrix = {
{"i8", {"i16", "i32", "i64", "u8", "u16", "u32", "u64"}}, {"i8",
{"i16", {"i32", "i64", "u16", "u32", "u64"}}, {"i16", "i32", "i64", "u8", "u16", "u32", "u64", "i128", "u128"}},
{"i32", {"i64", "u32", "u64"}}, {"i16", {"i32", "i64", "u16", "u32", "u64", "i128", "u128"}},
{"i64", {"u64"}}, {"i32", {"i64", "u32", "u64", "i128", "u128"}},
{"u8", {"i16", "i32", "i64", "u16", "u32", "u64"}}, {"i64", {"u64", "i128", "u128"}},
{"u16", {"i32", "i64", "u32", "u64"}}, {"i128", {"u128"}},
{"u32", {"i64", "u64"}}, {"u8", {"i16", "i32", "i64", "u16", "u32", "u64", "i128", "u128"}},
{"u64", {}}, {"u16", {"i32", "i64", "u32", "u64", "i128", "u128"}},
{"u32", {"i64", "u64", "i128", "u128"}},
{"u64", {"i128", "u128"}},
{"u128", {}},
{"float", {"double"}}, {"float", {"double"}},
{"double", {}}, {"double", {}},
}; };
@ -108,14 +116,15 @@ static bool canBeCastedTo(std::unique_ptr<Expr> &from,
PrimitiveType *fromT = (PrimitiveType *)from->type.get(); PrimitiveType *fromT = (PrimitiveType *)from->type.get();
const PrimitiveType *toT = (PrimitiveType *)toType.get(); const PrimitiveType *toT = (PrimitiveType *)toType.get();
std::vector<std::string> allNumberTypes = {"i8", "i16", "i32", "i64", std::vector<std::string> allNumberTypes = {
"u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "i128", "u8",
"float", "double"}; "u16", "u32", "u64", "u128", "float", "double"};
std::map<std::string, std::vector<std::string>> castMatrix = { std::map<std::string, std::vector<std::string>> castMatrix = {
{"i8", allNumberTypes}, {"i16", allNumberTypes}, {"i8", allNumberTypes}, {"i16", allNumberTypes},
{"i32", allNumberTypes}, {"i64", allNumberTypes}, {"i32", allNumberTypes}, {"i64", allNumberTypes},
{"u8", allNumberTypes}, {"u16", allNumberTypes}, {"i128", allNumberTypes}, {"u8", allNumberTypes},
{"u32", allNumberTypes}, {"u64", allNumberTypes}, {"u16", allNumberTypes}, {"u32", allNumberTypes},
{"u64", allNumberTypes}, {"u128", allNumberTypes},
{"float", allNumberTypes}, {"double", allNumberTypes}, {"float", allNumberTypes}, {"double", allNumberTypes},
}; };
@ -168,6 +177,11 @@ public:
} }
virtual std::any visit(FnDecl &fnDecl, std::any param) override { virtual std::any visit(FnDecl &fnDecl, std::any param) override {
if (!(fnDecl.body.get() && fnDecl.body->alywasReturns())) {
errors::put(fnDecl.error("function '" + fnDecl.name +
"' does not always return a value"));
}
if (!fnDecl.symbol.get()) if (!fnDecl.symbol.get())
return std::any(); return std::any();
@ -240,6 +254,39 @@ public:
return std::any(); return std::any();
} }
virtual std::any visit(LambdaExpr &lambdaExpr, std::any param) override {
std::shared_ptr<Type> returnType(nullptr);
if (lambdaExpr.returnTypeName.get()) {
lambdaExpr.returnTypeName->accept(this, param);
returnType = lambdaExpr.returnTypeName->type;
}
std::vector<std::shared_ptr<Type>> paramTypes;
for (auto &p : lambdaExpr.params) {
if (p.get())
p->accept(this, param);
if (p->symbol.get()) {
paramTypes.push_back(p->symbol->type);
} else {
paramTypes.push_back(std::shared_ptr<Type>(nullptr));
}
}
lambdaExpr.type = std::make_shared<FunctionType>(paramTypes, returnType);
auto prevReturnType = currentReturnType;
currentReturnType = returnType;
if (lambdaExpr.body.get()) {
lambdaExpr.body->accept(this, param);
}
currentReturnType = prevReturnType;
return std::any();
}
virtual std::any visit(Identifier &identifier, std::any param) override { virtual std::any visit(Identifier &identifier, std::any param) override {
if (!identifier.symbol.get()) if (!identifier.symbol.get())
return std::any(); return std::any();
@ -306,6 +353,11 @@ public:
if (!assignStmt.lval.get() || !assignStmt.rval->type.get()) if (!assignStmt.lval.get() || !assignStmt.rval->type.get())
return std::any(); return std::any();
if (!utils::is<Identifier>(assignStmt.lval.get())) {
errors::put(assignStmt.error("invalid lvalue"));
return std::any();
}
if (!tryAssignTo(assignStmt.rval, assignStmt.lval->type)) { if (!tryAssignTo(assignStmt.rval, assignStmt.lval->type)) {
errors::put(assignStmt.error("assignment type mismatch")); errors::put(assignStmt.error("assignment type mismatch"));
return std::any(); return std::any();
@ -376,14 +428,125 @@ public:
return std::any(); return std::any();
} }
callExpr.type = callExpr.callee->type;
auto functionType = (FunctionType *)(callExpr.callee->type.get()); auto functionType = (FunctionType *)(callExpr.callee->type.get());
callExpr.type = functionType->returnType;
if (functionType->paramTypes.size() != callExpr.args.size()) { if (functionType->paramTypes.size() != callExpr.args.size()) {
errors::put(callExpr.error("wrong number of arguments")); errors::put(callExpr.error("wrong number of arguments"));
return std::any(); return std::any();
} }
size_t smallerArgCount =
std::min(functionType->paramTypes.size(), callExpr.args.size());
for (size_t i = 0; i < smallerArgCount; i++) {
if (!callExpr.args[i].get() || !functionType->paramTypes[i].get())
continue;
if (!callExpr.args[i]->type.get())
continue;
if (!tryAssignTo(callExpr.args[i], functionType->paramTypes[i])) {
errors::put(callExpr.args[i]->error("argument type mismatch"));
}
}
return std::any();
}
virtual std::any visit(UnExpr &unExpr, std::any param) override {
BaseASTVisitor::visit(unExpr, param);
if (!unExpr.expr.get() || !unExpr.expr->type.get())
return std::any();
if (!utils::is<PrimitiveType>(unExpr.expr->type.get())) {
errors::put(unExpr.error("operand must be of primitive type"));
return std::any();
}
if (unExpr.op == UnOp::NEG) {
castTo(unExpr.expr, std::make_shared<PrimitiveType>("i32"));
}
unExpr.type = unExpr.expr->type;
return std::any();
}
virtual std::any visit(BinExpr &binExpr, std::any param) override {
BaseASTVisitor::visit(binExpr, param);
if (!binExpr.lhs.get() || !binExpr.rhs.get())
return std::any();
if (!binExpr.lhs->type.get() || !binExpr.rhs->type.get())
return std::any();
if (!utils::is<PrimitiveType>(binExpr.lhs->type.get()) ||
!utils::is<PrimitiveType>(binExpr.rhs->type.get())) {
errors::put(binExpr.error("operands must be of primitive type"));
return std::any();
}
switch (binExpr.op) {
case BinOp::ADD:
case BinOp::SUB:
case BinOp::MUL:
case BinOp::DIV: {
if (!tryAssignTo(binExpr.rhs, binExpr.lhs->type)) {
if (!tryAssignTo(binExpr.lhs, binExpr.rhs->type)) {
errors::put(
binExpr.error("operands incompatible, explicit cast required"));
return std::any();
}
}
binExpr.type = binExpr.lhs->type;
break;
}
case BinOp::MOD: {
auto lhsSuccess =
tryAssignTo(binExpr.lhs, std::make_shared<PrimitiveType>("i64"));
auto rhsSuccess =
tryAssignTo(binExpr.rhs, std::make_shared<PrimitiveType>("i64"));
if (!lhsSuccess || !rhsSuccess) {
errors::put(binExpr.error("operands must be of integer type"));
return std::any();
}
binExpr.type = std::make_shared<PrimitiveType>("i64");
}
case BinOp::EQ:
case BinOp::NE:
case BinOp::LT:
case BinOp::GT:
case BinOp::LE:
case BinOp::GE: {
if (!tryAssignTo(binExpr.rhs, binExpr.lhs->type)) {
if (!tryAssignTo(binExpr.lhs, binExpr.rhs->type)) {
errors::put(
binExpr.error("operands incompatible, explicit cast required"));
return std::any();
}
}
binExpr.type = std::make_shared<PrimitiveType>("i32");
break;
}
case BinOp::AND:
case BinOp::OR: {
castTo(binExpr.lhs, std::make_shared<PrimitiveType>("i32"));
castTo(binExpr.rhs, std::make_shared<PrimitiveType>("i32"));
binExpr.type = std::make_shared<PrimitiveType>("i32");
break;
}
}
return std::any(); return std::any();
} }
}; };
@ -394,9 +557,7 @@ void performTypeAnalysis(std::unique_ptr<Module> &module) {
TypeAnalysisVisitor1 visitor1; TypeAnalysisVisitor1 visitor1;
TypeAnalysisVisitor2 visitor2; TypeAnalysisVisitor2 visitor2;
for (auto &stmt : module->stmts) { module->accept(&visitor1, nullptr);
stmt->accept(&visitor1, std::any());
}
for (auto &stmt : module->stmts) { for (auto &stmt : module->stmts) {
if (utils::is<FnDecl>(stmt.get())) if (utils::is<FnDecl>(stmt.get()))

View File

@ -4,6 +4,7 @@
#include <sstream> #include <sstream>
#include "Analysis.h" #include "Analysis.h"
#include "Compile.h"
#include "Errors.h" #include "Errors.h"
#include "Parser.h" #include "Parser.h"
@ -30,11 +31,16 @@ int main(int argc, char *argv[]) {
try { try {
auto module = plsm::parse(argv[1], input); auto module = plsm::parse(argv[1], input);
// std::cout << module->toJsonString() << std::endl;
plsm::performNameAnalysis(module); plsm::performNameAnalysis(module);
plsm::performTypeAnalysis(module); plsm::performTypeAnalysis(module);
// std::cout << module->toJsonString() << std::endl;
if (!plsm::errors::get().size()) {
plsm::compileModule(module, std::string(argv[1]));
}
// std::cout << "\n\n"; // std::cout << "\n\n";
// std::cout << plsm::ast::Module::fromJson(module->toJson())->toJson() << // std::cout << plsm::ast::Module::fromJson(module->toJson())->toJson() <<

BIN
examples/new Executable file

Binary file not shown.

View File

@ -1,16 +1,22 @@
var asdf : i64; fun addFirst(n : i32) i32 {
var result : i32;
result = 0;
fun main(arg0 i64) i8 { var i : i32;
var a : i64; i = 0;
a = 10 + 2; while (i < n) {
result = result + i;
if (asdf > 1000) { i = i + 1;
a = 1;
} else if (asdf > 500) {
a = 2;
} else {
a = 3;
} }
ret 100; ret result;
}
// fun f(n : i32) i32 {
// if (n < 2) ret n;
// ret f(n - 1);
// }
fun main(argc : i32) u8 {
ret addFirst(argc) as u8;
} }