diff --git a/README.md b/README.md index 6b8de49..d4e41ac 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,45 @@ # plsm -Mainly functional general purpose programming language \ No newline at end of file + +Toy general purpose systems programming lanugage + +## Hello World Example (+ unicode, yeah) + +This example just looks like C without its syntax when in reality, plsm aims to be different. It shall have support for classes, lambdas, interfaces/traits and easy modularization. + +```plsm +fun write(fd : i64, msg : &u8, len : u64) i64 { + inline asm ( + "mov $0, %rax" // syscall: write + "mov $1, %rdi" // file descriptor: stdout + "mov $2, %rsi" // message to write + "mov $3, %rdx" // length of message + "syscall" // make the syscall + : + : "r"(1 as i64), "r"(1 as i64), "r"(msg), "r"(len) // input: different params + : "rax", "rdi", "rsi", "rdx" // clobbered registers + ); + + ret 0; +} + +fun exit(code : u8) i64 { + inline asm ( + "mov $0, %rax" + "mov $1, %rdi" + "syscall" + : + : "r"(60 as i64), "r"(code as i64) + : "rax", "rdi" + ); + + ret 0; +} + +fun main(argc : i32) u8 { + write(1, "Hello World!\n", 14); + write(1, "Ə Ɛ Ƒ ƒ Ɠ Ɣ ƕ Ɩ Ɨ Ƙ ƙ ƚ ƛ Ɯ Ɲ ƞ Ɵ Ơ ơ Ƣ ƣ Ƥ ƥ\n", 69); + exit(10); + + ret 0; +} +``` \ No newline at end of file diff --git a/compiler/.vscode/settings.json b/compiler/.vscode/settings.json index d211f7b..8cdcac6 100644 --- a/compiler/.vscode/settings.json +++ b/compiler/.vscode/settings.json @@ -4,6 +4,7 @@ "${workspaceFolder}/build/antlr4_runtime/src/antlr4_runtime/runtime/Cpp/runtime/src", "/usr/lib64/llvm18/include", ], + "clang-format.style": "{ColumnLimit: 110}", "files.associations": { "*.embeddedhtml": "html", "iosfwd": "cpp", diff --git a/compiler/include/AST/AST.h b/compiler/include/AST/AST.h index 1964983..36710ed 100644 --- a/compiler/include/AST/AST.h +++ b/compiler/include/AST/AST.h @@ -24,6 +24,8 @@ #include "Stmt/WhileStmt.h" #include "Type/FunctionType.h" +#include "Type/PointerType.h" #include "Type/PrimitiveType.h" #include "TypeName/NamedTypeName.h" +#include "TypeName/PointerTypeName.h" diff --git a/compiler/include/AST/Base.h b/compiler/include/AST/Base.h index f199353..ac6258a 100644 --- a/compiler/include/AST/Base.h +++ b/compiler/include/AST/Base.h @@ -27,6 +27,7 @@ class LambdaExpr; class UnExpr; class IntValue; class FloatValue; +class StringValue; class Import; class Module; @@ -37,11 +38,14 @@ class ExprStmt; class FnParam; class FnDecl; class IfStmt; +class InlineAsm; +class InlineAsmConstraint; class RetStmt; class VarDecl; class WhileStmt; class NamedTypeName; +class PointerTypeName; class ASTVisitor { public: @@ -55,6 +59,7 @@ public: virtual std::any visit(UnExpr &unExpr, std::any param) = 0; virtual std::any visit(IntValue &intValue, std::any param) = 0; virtual std::any visit(FloatValue &floatValue, std::any param) = 0; + virtual std::any visit(StringValue &stringValue, std::any param) = 0; virtual std::any visit(Import &import, std::any param) = 0; virtual std::any visit(Module &module, std::any param) = 0; @@ -65,14 +70,14 @@ public: virtual std::any visit(FnParam &fnParam, std::any param) = 0; virtual std::any visit(FnDecl &fnDecl, std::any param) = 0; virtual std::any visit(IfStmt &ifStmt, std::any param) = 0; - virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint, - std::any param) = 0; + virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint, std::any param) = 0; virtual std::any visit(InlineAsm &inlineAsm, std::any param) = 0; virtual std::any visit(RetStmt &retStmt, std::any param) = 0; virtual std::any visit(VarDecl &varDecl, std::any param) = 0; virtual std::any visit(WhileStmt &whileStmt, std::any param) = 0; virtual std::any visit(NamedTypeName &namedTypeName, std::any param) = 0; + virtual std::any visit(PointerTypeName &pointerTypeName, std::any param) = 0; }; } // namespace ast @@ -89,8 +94,7 @@ public: protected: template - static inline boost::json::value getJsonProperty(boost::json::value json, - std::string property) { + static inline boost::json::value getJsonProperty(boost::json::value json, std::string property) { boost::json::value prop; try { if (!json.as_object().contains(property)) @@ -98,8 +102,8 @@ protected: prop = json.as_object().at(property); } catch (...) { std::cout << boost::json::serialize(json) << std::endl; - throw std::runtime_error("missing property '" + property + "' in " + - typeid(CurrNode).name() + "::fromJson"); + throw std::runtime_error("missing property '" + property + "' in " + typeid(CurrNode).name() + + "::fromJson"); } return prop; @@ -111,22 +115,18 @@ protected: try { return boost::json::value_to(prop); } catch (...) { - throw std::runtime_error("invalid value for property '" + property + - "' in " + typeid(CurrNode).name() + + throw std::runtime_error("invalid value for property '" + property + "' in " + typeid(CurrNode).name() + "::fromJson"); } } template - static inline auto fromJsonProperty(boost::json::value json, - std::string property) { - return std::unique_ptr( - SubNode::fromJson(getJsonProperty(json, property))); + static inline auto fromJsonProperty(boost::json::value json, std::string property) { + return std::unique_ptr(SubNode::fromJson(getJsonProperty(json, property))); } template - static inline auto fromJsonVector(boost::json::value json, - std::string property) { + static inline auto fromJsonVector(boost::json::value json, std::string property) { auto arr = getJsonProperty(json, property).as_array(); std::vector> result; @@ -140,20 +140,16 @@ protected: class SourceRange { public: - SourceRange(const std::string &file, std::string text, - std::pair start, std::pair end) + SourceRange(const std::string &file, std::string text, std::pair start, + std::pair end) : file(file), text(text), start(start), end(end) {} const std::string file, text; const std::pair start, end; - static SourceRange unknown() { - return SourceRange("", "", {0, 0}, {0, 0}); - }; + static SourceRange unknown() { return SourceRange("", "", {0, 0}, {0, 0}); }; - static SourceRange json() { - return SourceRange("", "", {0, 0}, {0, 0}); - }; + static SourceRange json() { return SourceRange("", "", {0, 0}, {0, 0}); }; }; class TypeName; @@ -166,7 +162,7 @@ public: virtual TypeName *toTypeName() = 0; virtual bool operator==(const Type &other) = 0; - virtual bool operator!=(const Type &other) { return !(*this == other); } + virtual bool operator!=(const Type &other) { return !((*this) == other); } static Type *fromJson(boost::json::value json); }; @@ -187,9 +183,7 @@ public: const SourceRange sourceRange; - virtual std::string toJsonString() const { - return boost::json::serialize(toJson(), {}); - } + virtual std::string toJsonString() const { return boost::json::serialize(toJson(), {}); } static ASTNode *fromJson(boost::json::value json); @@ -200,8 +194,8 @@ public: virtual std::string error(const std::string &message) const { std::stringstream ss; - ss << "In file " << sourceRange.file << ":" << sourceRange.start.first - << ":" << sourceRange.start.second + 1 << "\n" + ss << "In file " << sourceRange.file << ":" << sourceRange.start.first << ":" + << sourceRange.start.second + 1 << "\n" << terminal::cyan << sourceRange.text << terminal::reset << "\n" << terminal::red << message << terminal::reset; diff --git a/compiler/include/AST/BaseASTVisitor.h b/compiler/include/AST/BaseASTVisitor.h index 7597fc4..d915c09 100644 --- a/compiler/include/AST/BaseASTVisitor.h +++ b/compiler/include/AST/BaseASTVisitor.h @@ -37,9 +37,7 @@ public: return std::any(); } - virtual std::any visit(Identifier &identifier, std::any param) override { - return std::any(); - } + virtual std::any visit(Identifier &identifier, std::any param) override { return std::any(); } virtual std::any visit(LambdaExpr &lambdaExpr, std::any param) override { if (lambdaExpr.returnTypeName.get()) @@ -62,17 +60,13 @@ public: return std::any(); } - virtual std::any visit(IntValue &intValue, std::any param) override { - return std::any(); - } + virtual std::any visit(IntValue &intValue, std::any param) override { return std::any(); } - virtual std::any visit(FloatValue &floatValue, std::any param) override { - return std::any(); - } + virtual std::any visit(FloatValue &floatValue, std::any param) override { return std::any(); } - virtual std::any visit(Import &import, std::any param) override { - return std::any(); - } + virtual std::any visit(StringValue &stringValue, std::any param) override { return std::any(); } + + virtual std::any visit(Import &import, std::any param) override { return std::any(); } virtual std::any visit(Module &module, std::any param) override { for (auto &import : module.imports) { @@ -142,8 +136,9 @@ public: return std::any(); } - virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint, - std::any param) override { + virtual std::any visit(InlineAsmConstraint &inlineAsmConstraint, std::any param) override { + if (inlineAsmConstraint.value.get()) + inlineAsmConstraint.value->accept(this, param); return std::any(); } @@ -182,8 +177,11 @@ public: return std::any(); } - virtual std::any visit(NamedTypeName &namedTypeName, - std::any param) override { + virtual std::any visit(NamedTypeName &namedTypeName, std::any param) override { return std::any(); } + + virtual std::any visit(PointerTypeName &pointerTypeName, std::any param) override { + if (pointerTypeName.baseTypeName.get()) + pointerTypeName.baseTypeName->accept(this, param); return std::any(); } }; diff --git a/compiler/include/AST/Expr/Value.h b/compiler/include/AST/Expr/Value.h index e516baa..b7785f2 100644 --- a/compiler/include/AST/Expr/Value.h +++ b/compiler/include/AST/Expr/Value.h @@ -18,7 +18,7 @@ class IntValue : public Expr { public: const std::int64_t value; - IntValue(LOC_ARG, int64_t value) : Expr(sourceRange), value(value) {} + IntValue(LOC_ARG, std::int64_t value) : Expr(sourceRange), value(value) {} virtual boost::json::value toJson() const override; static IntValue *fromJson(boost::json::value json); @@ -32,7 +32,7 @@ class FloatValue : public Expr { public: const std::double_t value; - FloatValue(LOC_ARG, double value) : Expr(sourceRange), value(value) {} + FloatValue(LOC_ARG, std::double_t value) : Expr(sourceRange), value(value) {} virtual boost::json::value toJson() const override; static FloatValue *fromJson(boost::json::value json); @@ -42,5 +42,19 @@ public: } }; +class StringValue : public Expr { +public: + const std::string value; + + StringValue(LOC_ARG, std::string value) : Expr(sourceRange), value(value) {} + + virtual boost::json::value toJson() const override; + static StringValue *fromJson(boost::json::value json); + + virtual std::any accept(ASTVisitor *visitor, std::any param) override { + return visitor->visit(*this, param); + } +}; + } // namespace ast } // namespace plsm diff --git a/compiler/include/AST/Stmt/InlineAsm.h b/compiler/include/AST/Stmt/InlineAsm.h index ed77f39..6cbf391 100644 --- a/compiler/include/AST/Stmt/InlineAsm.h +++ b/compiler/include/AST/Stmt/InlineAsm.h @@ -5,10 +5,11 @@ namespace ast { class InlineAsmConstraint : public ASTNode { public: - std::string constraint, variable; + std::string constraint; + std::unique_ptr value; - InlineAsmConstraint(LOC_ARG, std::string constraint, std::string variable) - : ASTNode(sourceRange), constraint(constraint), variable(variable) {} + InlineAsmConstraint(LOC_ARG, std::string constraint, std::unique_ptr value) + : ASTNode(sourceRange), constraint(constraint), value(std::move(value)) {} virtual boost::json::value toJson() const override; static InlineAsmConstraint *fromJson(boost::json::value json); @@ -25,12 +26,10 @@ public: std::vector> inputs; std::vector clobbers; - InlineAsm(LOC_ARG, std::string code, - std::vector> outputs, - std::vector> inputs, - std::vector clobbers) - : Stmt(sourceRange), code(code), outputs(std::move(outputs)), - inputs(std::move(inputs)), clobbers(clobbers) {} + InlineAsm(LOC_ARG, std::string code, std::vector> outputs, + std::vector> inputs, std::vector clobbers) + : Stmt(sourceRange), code(code), outputs(std::move(outputs)), inputs(std::move(inputs)), + clobbers(std::move(clobbers)) {} virtual boost::json::value toJson() const override; static InlineAsm *fromJson(boost::json::value json); diff --git a/compiler/include/AST/Type/PointerType.h b/compiler/include/AST/Type/PointerType.h new file mode 100644 index 0000000..1d91cef --- /dev/null +++ b/compiler/include/AST/Type/PointerType.h @@ -0,0 +1,29 @@ +#pragma once + +#include "AST/Base.h" + +namespace plsm { +namespace ast { + +class PointerType : public Type { +public: + const std::shared_ptr baseType; + + PointerType(const std::shared_ptr &baseType) : Type(), baseType(baseType) {} + + virtual TypeName *toTypeName() override; + + virtual bool operator==(const Type &other) override { + if (const PointerType *pt = dynamic_cast(&other)) { + return *baseType == *pt->baseType; + } + + return false; + } + + virtual boost::json::value toJson() const override; + static PointerType *fromJson(boost::json::value json); +}; + +} // namespace ast +} // namespace plsm diff --git a/compiler/include/AST/TypeName/PointerTypeName.h b/compiler/include/AST/TypeName/PointerTypeName.h new file mode 100644 index 0000000..215bc41 --- /dev/null +++ b/compiler/include/AST/TypeName/PointerTypeName.h @@ -0,0 +1,24 @@ +#pragma once + +#include "AST/Base.h" + +namespace plsm { +namespace ast { + +class PointerTypeName : public TypeName { +public: + const std::unique_ptr baseTypeName; + + PointerTypeName(LOC_ARG, std::unique_ptr baseTypeName) + : TypeName(sourceRange), baseTypeName(std::move(baseTypeName)) {} + + virtual boost::json::value toJson() const override; + static PointerTypeName *fromJson(boost::json::value json); + + virtual std::any accept(ASTVisitor *visitor, std::any param) override { + return visitor->visit(*this, param); + } +}; + +} // namespace ast +} // namespace plsm \ No newline at end of file diff --git a/compiler/plsm.g4 b/compiler/plsm.g4 index a00741a..a0b8e47 100644 --- a/compiler/plsm.g4 +++ b/compiler/plsm.g4 @@ -103,42 +103,57 @@ stmt inlineAsm returns[std::unique_ptr ast]: - 'inline' 'asm' '(' inlineAsmCode = string ( - ':' outputs += inlineAsmConstraint ( - ',' outputs += inlineAsmConstraint - )* + 'inline' 'asm' '(' (inlineAsmCode += string)+ ( + ':' ( + outputs += inlineAsmConstraint ( + ',' outputs += inlineAsmConstraint + )* + )? )? ( - ':' inputs += inlineAsmConstraint ( - ',' inputs += inlineAsmConstraint - )* - )? (':' clobbers += string ( ',' clobbers += string)*)? ')' ';' { - auto code = $ctx->inlineAsmCode->value; + ':' ( + inputs += inlineAsmConstraint ( + ',' inputs += inlineAsmConstraint + )* + )? + )? (':' (clobbers += string ( ',' clobbers += string)*)?)? ')' ';' { + std::string code = ""; + for (auto &asmCode : $ctx->inlineAsmCode) { + code += asmCode->value; + code += ";"; + } + code.pop_back(); - std::vector> outputs; - for (auto &output : $ctx->outputs) { - outputs.push_back(std::move(output->ast)); - } + std::vector> outputs; + for (auto &output : $ctx->outputs) { + outputs.push_back(std::move(output->ast)); + } - std::vector> inputs; - for (auto &input : $ctx->inputs) { - inputs.push_back(std::move(input->ast)); - } + std::vector> inputs; + for (auto &input : $ctx->inputs) { + inputs.push_back(std::move(input->ast)); + } - std::vector clobbers; - for (auto &clobber : $ctx->clobbers) { - clobbers.push_back(clobber->value); - } + std::vector clobbers; + for (auto &clobber : $ctx->clobbers) { + clobbers.push_back(clobber->value); + } - $ast = std::make_unique( - getSourceRange($ctx), code, std::move(outputs), std::move(inputs), clobbers); + $ast = std::make_unique( + getSourceRange($ctx), code, std::move(outputs), std::move(inputs), clobbers); }; inlineAsmConstraint returns[std::unique_ptr ast]: - string '(' IDENTIFIER ')' { + string '(' expr ')' { auto constraint = $ctx->string()->value; - auto variable = $ctx->IDENTIFIER()->getText(); - $ast = std::make_unique(getSourceRange($ctx), constraint, variable); + $ast = std::make_unique(getSourceRange($ctx), constraint, std::move($ctx->expr()->ast)); + }; + +inlineAsmLvalueConstraint + returns[std::unique_ptr ast]: + string '(' lvalue ')' { + auto constraint = $ctx->string()->value; + $ast = std::make_unique(getSourceRange($ctx), constraint, std::move($ctx->lvalue()->ast)); }; whileStmt @@ -185,10 +200,16 @@ implDeclAssignStmt assignStmt returns[std::unique_ptr ast]: - lval = expr '=' rval = expr ';' { + lval = lvalue '=' rval = expr ';' { $ast = std::make_unique(getSourceRange($ctx), std::move($ctx->lval->ast), std::move($ctx->rval->ast)); }; +lvalue + returns[std::unique_ptr ast]: + identifier { + $ast = ptrcast($ctx->identifier()->ast); + }; + retStmt returns[std::unique_ptr ast]: 'ret' expr ';' { @@ -361,6 +382,9 @@ factorExpr } | lambdaExpr { $ast = ptrcast($ctx->lambdaExpr()->ast); + } + | string { + $ast = std::unique_ptr((Expr *) new StringValue(getSourceRange($ctx), $ctx->string()->value)); } | '(' expr ')' { $ast = std::move($ctx->expr()->ast); @@ -394,6 +418,11 @@ typeName auto text = $ctx->IDENTIFIER()->getText(); auto named = std::make_unique(getSourceRange($ctx), text); $ast = ptrcast(named); + } + | '&' typeName { + auto typeName = std::move($ctx->typeName()->ast); + auto ptr = std::make_unique(getSourceRange($ctx), std::move(typeName)); + $ast = ptrcast(ptr); }; string @@ -404,6 +433,12 @@ string $value = decoded.as_string(); }; +identifier + returns[std::unique_ptr ast]: + IDENTIFIER { + $ast = std::make_unique(getSourceRange($ctx), $ctx->IDENTIFIER()->getText()); + }; + STRING: '"' (ESC | ~["\\\r\n])* '"'; fragment ESC: '\\' (["\\/bfnrt] | 'u' HEX HEX HEX HEX); fragment HEX: [0-9a-fA-F]; diff --git a/compiler/src/AST/Expr/Value.cpp b/compiler/src/AST/Expr/Value.cpp index a3ce274..b55036f 100644 --- a/compiler/src/AST/Expr/Value.cpp +++ b/compiler/src/AST/Expr/Value.cpp @@ -22,5 +22,11 @@ FloatValue *FloatValue::fromJson(boost::json::value json) { return new FloatValue(SourceRange::json(), json.as_double()); } +boost::json::value StringValue::toJson() const { return boost::json::value_from(value); } + +StringValue *StringValue::fromJson(boost::json::value json) { + return new StringValue(SourceRange::json(), json.as_string().c_str()); +} + } // namespace ast } // namespace plsm diff --git a/compiler/src/AST/Stmt/InlineAsm.cpp b/compiler/src/AST/Stmt/InlineAsm.cpp index 0a430e2..5e10649 100644 --- a/compiler/src/AST/Stmt/InlineAsm.cpp +++ b/compiler/src/AST/Stmt/InlineAsm.cpp @@ -8,37 +8,34 @@ boost::json::value InlineAsmConstraint::toJson() const { return { {"@type", "InlineAsmConstraint"}, {"constraint", constraint}, - {"variable", variable}, + {"value", value->toJson()}, }; } InlineAsmConstraint *InlineAsmConstraint::fromJson(boost::json::value json) { - auto constraint = - getJsonValue(json, "constraint"); - auto variable = - getJsonValue(json, "variable"); - return new InlineAsmConstraint(SourceRange::json(), constraint, variable); + auto constraint = getJsonValue(json, "constraint"); + auto value = fromJsonProperty(json, "value"); + return new InlineAsmConstraint(SourceRange::json(), constraint, std::move(value)); } boost::json::value InlineAsm::toJson() const { + boost::json::array jsonClobbers; + return { {"@type", "InlineAsm"}, {"code", code}, {"outputs", utils::mapToJson(outputs)}, {"inputs", utils::mapToJson(inputs)}, - {"clobbers", clobbers}, + {"clobbers", utils::mapToJson(clobbers, [](const std::string &clobber) { return clobber; })}, }; } InlineAsm *InlineAsm::fromJson(boost::json::value json) { auto name = getJsonValue(json, "name"); - auto outputs = - fromJsonVector(json, "outputs"); + auto outputs = fromJsonVector(json, "outputs"); auto inputs = fromJsonVector(json, "inputs"); - auto clobbers = - getJsonValue>(json, "clobbers"); - return new InlineAsm(SourceRange::json(), name, std::move(outputs), - std::move(inputs), clobbers); + auto clobbers = getJsonValue>(json, "clobbers"); + return new InlineAsm(SourceRange::json(), name, std::move(outputs), std::move(inputs), std::move(clobbers)); } } // namespace ast diff --git a/compiler/src/AST/Type/PointerType.cpp b/compiler/src/AST/Type/PointerType.cpp new file mode 100644 index 0000000..895ea81 --- /dev/null +++ b/compiler/src/AST/Type/PointerType.cpp @@ -0,0 +1,24 @@ +#include "AST/AST.h" +#include "AST/Base.h" +#include "AST/TypeName/NamedTypeName.h" +#include + +namespace plsm { +namespace ast { + +boost::json::value PointerType::toJson() const { + return { + {"@type", "PointerType"}, + {"baseType", baseType->toJson()}, + }; +} + +PointerType *PointerType::fromJson(boost::json::value json) { + auto baseType = fromJsonProperty(json, "baseType"); + return new PointerType(std::shared_ptr(baseType.release())); +} + +TypeName *PointerType::toTypeName() { return nullptr; } + +} // namespace ast +} // namespace plsm \ No newline at end of file diff --git a/compiler/src/AST/TypeName/PointerTypeName.cpp b/compiler/src/AST/TypeName/PointerTypeName.cpp new file mode 100644 index 0000000..e33e742 --- /dev/null +++ b/compiler/src/AST/TypeName/PointerTypeName.cpp @@ -0,0 +1,19 @@ +#include "AST/AST.h" + +namespace plsm { +namespace ast { + +boost::json::value PointerTypeName::toJson() const { + return { + {"@type", "PointerTypeName"}, + {"baseTypeName", baseTypeName->toJson()}, + }; +} + +PointerTypeName *PointerTypeName::fromJson(boost::json::value json) { + auto baseTypeName = fromJsonProperty(json, "baseTypeName"); + return new PointerTypeName(SourceRange::json(), std::move(baseTypeName)); +} + +} // namespace ast +} // namespace plsm \ No newline at end of file diff --git a/compiler/src/Visitors/Compiler.cpp b/compiler/src/Visitors/Compiler.cpp index 5c4f400..a70c116 100644 --- a/compiler/src/Visitors/Compiler.cpp +++ b/compiler/src/Visitors/Compiler.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -28,8 +29,7 @@ using namespace ast; namespace { -static llvm::Type *getLLVMType(llvm::LLVMContext &ctx, - const std::shared_ptr &type) { +static llvm::Type *getLLVMType(llvm::LLVMContext &ctx, const std::shared_ptr &type) { if (utils::is(type.get())) { auto primitiveType = (PrimitiveType *)type.get(); if (primitiveType->name == "i8" || primitiveType->name == "u8") @@ -48,14 +48,20 @@ static llvm::Type *getLLVMType(llvm::LLVMContext &ctx, return llvm::Type::getDoubleTy(ctx); } + else if (utils::is(type.get())) { + auto pointerType = (PointerType *)type.get(); + auto baseType = getLLVMType(ctx, pointerType->baseType); + return baseType->getPointerTo(); + } + else if (utils::is(type.get())) { auto functionType = (FunctionType *)type.get(); auto returnType = getLLVMType(ctx, functionType->returnType); std::vector llvmParams; - llvmParams.push_back(llvm::IntegerType::get(ctx, 8)->getPointerTo()); for (auto ¶mType : functionType->paramTypes) llvmParams.push_back(getLLVMType(ctx, paramType)); + llvmParams.push_back(llvm::IntegerType::get(ctx, 8)->getPointerTo()); return llvm::FunctionType::get(returnType, llvmParams, false); } @@ -72,29 +78,20 @@ class IRGenerator1 : public BaseASTVisitor { std::set &functions; public: - IRGenerator1(llvm::LLVMContext &ctx, llvm::Module &mod, - llvm::IRBuilder<> &builder, - std::map &symbolMap, - std::set &functions) - : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap), - functions(functions) {} + IRGenerator1(llvm::LLVMContext &ctx, llvm::Module &mod, llvm::IRBuilder<> &builder, + std::map &symbolMap, std::set &functions) + : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap), functions(functions) {} virtual std::any visit(VarDecl &varDecl, std::any param) override { auto llvmType = getLLVMType(ctx, varDecl.symbol->type); - - // auto global = new llvm::GlobalVariable(mod, llvmType, false, - // llvm::GlobalValue::ExternalLinkage, - // nullptr, varDecl.name); auto global = mod.getOrInsertGlobal(varDecl.name, llvmType); - symbolMap[varDecl.symbol.get()] = global; return std::any(); } virtual std::any visit(FnDecl &fnDecl, std::any param) override { - auto functionType = - (llvm::FunctionType *)getLLVMType(ctx, fnDecl.symbol->type); + auto functionType = (llvm::FunctionType *)getLLVMType(ctx, fnDecl.symbol->type); mod.getOrInsertFunction(fnDecl.name, functionType); auto function = mod.getFunction(fnDecl.name); @@ -117,53 +114,51 @@ class IRGenerator2 : public BaseASTVisitor { llvm::Value *retStore = nullptr; llvm::BasicBlock *retBlock = nullptr; - llvm::Value *rvalueForLValue = nullptr; + bool requireLValue = false; - size_t labelCounter = 0; - std::string createLabel() { return "L" + std::to_string(labelCounter++); } + llvm::StructType *closureType = nullptr; + llvm::PointerType *pointerType = nullptr; llvm::Value *wrapCallee(llvm::Value *callee) { - auto ptr = llvm::IntegerType::get(ctx, 8)->getPointerTo(); - auto structType = llvm::StructType::get(ptr, callee->getType()); + auto value = (llvm::Value *)llvm::UndefValue::get(closureType); + auto calleePointer = builder.CreatePointerCast(callee, pointerType); + value = builder.CreateInsertValue(value, calleePointer, 0); + auto nullContext = llvm::ConstantPointerNull::get(pointerType); + value = builder.CreateInsertValue(value, nullContext, 1); - auto store = builder.CreateAlloca(structType); - auto ep = builder.CreateStructGEP(structType, store, 0); - builder.CreateStore(llvm::ConstantPointerNull::get(ptr), ep); - ep = builder.CreateStructGEP(structType, store, 1); - builder.CreateStore(callee, ep); - - return store; + return value; // (f() -> asdf) -> { context, f() -> asdf } } public: - IRGenerator2(llvm::LLVMContext &ctx, llvm::Module &mod, - llvm::IRBuilder<> &builder, - std::map &symbolMap, - std::set &functions) - : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap), - functions(functions) {} + IRGenerator2(llvm::LLVMContext &ctx, llvm::Module &mod, llvm::IRBuilder<> &builder, + std::map &symbolMap, std::set &functions) + : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap), functions(functions) { + this->pointerType = llvm::IntegerType::get(ctx, 8)->getPointerTo(); + this->closureType = llvm::StructType::get(pointerType, pointerType); + } virtual std::any visit(FnDecl &fnDecl, std::any param) override { auto function = mod.getFunction(fnDecl.name); - auto block = llvm::BasicBlock::Create(ctx, createLabel(), function); + auto block = llvm::BasicBlock::Create(ctx, "", function); builder.SetInsertPoint(block); size_t i = 0; for (auto &arg : function->args()) { - if (i > 0) { - auto store = builder.CreateAlloca(arg.getType(), nullptr); - builder.CreateStore(&arg, store); - symbolMap[fnDecl.params[i - 1]->symbol.get()] = store; - } + if (i == fnDecl.params.size()) + break; + + auto store = builder.CreateAlloca(arg.getType(), nullptr); + builder.CreateStore(&arg, store); + symbolMap[fnDecl.params[i]->symbol.get()] = store; i += 1; } auto fnType = (llvm::FunctionType *)getLLVMType(ctx, fnDecl.symbol->type); retStore = builder.CreateAlloca(fnType->getReturnType()); - retBlock = llvm::BasicBlock::Create(ctx, createLabel(), function); + retBlock = llvm::BasicBlock::Create(ctx, "", function); BaseASTVisitor::visit(fnDecl, param); if (!fnDecl.body->alywasReturns()) @@ -268,8 +263,7 @@ public: } virtual std::any visit(CastExpr &castExpr, std::any param) override { - auto value = - std::any_cast(castExpr.value->accept(this, param)); + auto value = std::any_cast(castExpr.value->accept(this, param)); if (utils::is(castExpr.value->type.get()) && utils::is(castExpr.type.get())) { @@ -278,11 +272,9 @@ public: auto newType = getLLVMType(ctx, castExpr.type); - auto wasFloat = - primitiveType->name == "float" || primitiveType->name == "double"; + auto wasFloat = primitiveType->name == "float" || primitiveType->name == "double"; auto wasUnsigned = primitiveType->name[0] == 'u'; - auto willFloat = newPrimitiveType->name == "float" || - newPrimitiveType->name == "double"; + auto willFloat = newPrimitiveType->name == "float" || newPrimitiveType->name == "double"; auto willUnsigned = newPrimitiveType->name[0] == 'u'; if (wasFloat) { @@ -320,63 +312,69 @@ public: } virtual std::any visit(CallExpr &callExpr, std::any param) override { - auto callee = - std::any_cast(callExpr.callee->accept(this, param)); + auto callee = std::any_cast(callExpr.callee->accept(this, param)); - auto ptrType = llvm::IntegerType::get(ctx, 8)->getPointerTo(); - auto structType = llvm::StructType::get(ptrType, ptrType); - auto ep = builder.CreateStructGEP(structType, callee, 0); - auto callCtx = builder.CreateLoad(ptrType, ep); - - ep = builder.CreateStructGEP(structType, callee, 1); - auto realCallee = (llvm::Value *)builder.CreateLoad(ptrType, ep); - - // realCallee = builder.CreatePointerCast(realCallee, calleeType); + auto realCallee = builder.CreateExtractValue(callee, 0); + auto callCtx = builder.CreateExtractValue(callee, 1); std::vector llvmArgs; + for (auto &arg : callExpr.args) + llvmArgs.push_back(std::any_cast(arg->accept(this, param))); llvmArgs.push_back(callCtx); - for (auto &arg : callExpr.args) { - llvmArgs.push_back( - std::any_cast(arg->accept(this, param))); - } - auto calleeType = - (llvm::FunctionType *)getLLVMType(ctx, callExpr.callee->type); + auto calleeType = (llvm::FunctionType *)getLLVMType(ctx, callExpr.callee->type); return (llvm::Value *)builder.CreateCall(calleeType, realCallee, llvmArgs); } virtual std::any visit(Identifier &identifier, std::any param) override { auto value = symbolMap[identifier.symbol.get()]; - if (rvalueForLValue) { - builder.CreateStore(rvalueForLValue, value); - return nullptr; - } + if (requireLValue) + return value; - else { - if (functions.count(value)) - return wrapCallee(value); - if (utils::is(identifier.type.get())) - return value; + if (functions.count(value)) + return wrapCallee(value); + if (utils::is(identifier.type.get())) + return (llvm::Value *)builder.CreateLoad(closureType, value); - auto loadType = getLLVMType(ctx, identifier.type); - return (llvm::Value *)builder.CreateLoad(loadType, value); - } + auto loadType = getLLVMType(ctx, identifier.type); + return (llvm::Value *)builder.CreateLoad(loadType, value); } virtual std::any visit(IntValue &intValue, std::any param) override { - return (llvm::Value *)llvm::ConstantInt::get( - getLLVMType(ctx, intValue.type), intValue.value); + return (llvm::Value *)llvm::ConstantInt::get(getLLVMType(ctx, intValue.type), intValue.value); } virtual std::any visit(FloatValue &floatValue, std::any param) override { - return (llvm::Value *)llvm::ConstantFP::get( - getLLVMType(ctx, floatValue.type), floatValue.value); + return (llvm::Value *)llvm::ConstantFP::get(getLLVMType(ctx, floatValue.type), floatValue.value); + } + + virtual std::any visit(StringValue &stringValue, std::any param) override { + return (llvm::Value *)builder.CreateGlobalStringPtr(stringValue.value); + + auto valType = llvm::Type::getInt8Ty(ctx); + + auto size = stringValue.value.size() + 1; + auto store = builder.CreateAlloca(llvm::Type::getInt8Ty(ctx), size); + + size_t i = 0; + for (auto &c : stringValue.value) { + auto ep = builder.CreateConstGEP1_64(valType, store, i); + builder.CreateStore(llvm::ConstantInt::get(valType, c), ep); + i += 1; + } + + // add null terminator + auto ep = builder.CreateConstGEP1_64(valType, store, i); + builder.CreateStore(llvm::ConstantInt::get(valType, 0), ep); + + auto ptr = builder.CreatePointerCast(store, pointerType); + + return (llvm::Value *)ptr; } virtual std::any visit(RetStmt &retStmt, std::any param) override { - auto value = - std::any_cast(retStmt.value->accept(this, param)); + auto value = std::any_cast(retStmt.value->accept(this, param)); builder.CreateStore(value, retStore); builder.CreateBr(retBlock); @@ -385,16 +383,15 @@ public: } virtual std::any visit(IfStmt &ifStmt, std::any param) override { - auto cond = - std::any_cast(ifStmt.condition->accept(this, param)); + auto cond = std::any_cast(ifStmt.condition->accept(this, param)); auto fn = builder.GetInsertBlock()->getParent(); - auto ifBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); - auto elseBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); + auto ifBlock = llvm::BasicBlock::Create(ctx, "", fn); + auto elseBlock = llvm::BasicBlock::Create(ctx, "", fn); llvm::BasicBlock *endBlock = nullptr; if (!ifStmt.alywasReturns()) - endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); + endBlock = llvm::BasicBlock::Create(ctx, "", fn); builder.CreateCondBr(cond, ifBlock, elseBlock); @@ -416,15 +413,14 @@ public: virtual std::any visit(WhileStmt &whileStmt, std::any param) override { auto fn = builder.GetInsertBlock()->getParent(); - auto condBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); - auto whileBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); - auto endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn); + auto condBlock = llvm::BasicBlock::Create(ctx, "", fn); + auto whileBlock = llvm::BasicBlock::Create(ctx, "", fn); + auto endBlock = llvm::BasicBlock::Create(ctx, "", fn); builder.CreateBr(condBlock); builder.SetInsertPoint(condBlock); - auto cond = - std::any_cast(whileStmt.condition->accept(this, param)); + auto cond = std::any_cast(whileStmt.condition->accept(this, param)); builder.CreateCondBr(cond, whileBlock, endBlock); builder.SetInsertPoint(whileBlock); @@ -438,12 +434,56 @@ public: } virtual std::any visit(AssignStmt &assignStmt, std::any param) override { - auto rvalue = - std::any_cast(assignStmt.rval->accept(this, param)); + auto rvalue = std::any_cast(assignStmt.rval->accept(this, param)); - rvalueForLValue = rvalue; - auto lvalue = assignStmt.lval->accept(this, param); - rvalueForLValue = nullptr; + requireLValue = true; + auto lvalue = std::any_cast(assignStmt.lval->accept(this, param)); + requireLValue = false; + + builder.CreateStore(rvalue, lvalue); + + return std::any(); + } + + virtual std::any visit(InlineAsm &inlineAsm, std::any param) override { + // Handle the inline assembly code. + auto codeValue = builder.CreateGlobalStringPtr(inlineAsm.code); + + // Build constraints string for LLVM InlineAsm + std::string asmConstraints; + for (const auto &output : inlineAsm.outputs) + asmConstraints += "=" + output->constraint + ","; + for (const auto &input : inlineAsm.inputs) + asmConstraints += input->constraint + ","; + for (const auto &clobber : inlineAsm.clobbers) + asmConstraints += "~{" + clobber + "}" + ","; + if (!asmConstraints.empty()) + asmConstraints.pop_back(); // Remove trailing comma + + // Process inputs and prepare arguments + std::vector llvmArgs; + for (const auto &output : inlineAsm.outputs) { + requireLValue = true; + auto outputValue = std::any_cast(output->value->accept(this, param)); + requireLValue = false; + + llvmArgs.push_back(outputValue); + } + + for (const auto &input : inlineAsm.inputs) { + auto inputValue = std::any_cast(input->value->accept(this, param)); + llvmArgs.push_back(inputValue); + } + + std::vector paramTypes; + for (const auto &arg : llvmArgs) + paramTypes.push_back(arg->getType()); + + auto asmType = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), paramTypes, false); + auto inlineAsmValue = llvm::InlineAsm::get(asmType, inlineAsm.code, asmConstraints, + true /* hasSideEffects */, false /* isAlignStack */); + + builder.CreateCall(inlineAsmValue, llvmArgs); return std::any(); } @@ -466,8 +506,8 @@ static void runMPM(llvm::Module &mod) { passBuilder.crossRegisterProxies(lam, fam, gam, mam); - mpm = passBuilder.buildModuleOptimizationPipeline( - llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None); + mpm = passBuilder.buildModuleOptimizationPipeline(llvm::OptimizationLevel::O3, + llvm::ThinOrFullLTOPhase::None); mpm.run(mod, mam); @@ -477,8 +517,7 @@ static void runMPM(llvm::Module &mod) { lam.clear(); } -static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod, - llvm::IRBuilder<> &builder, +static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod, llvm::IRBuilder<> &builder, const std::string &outfile) { llvm::InitializeAllTargetInfos(); llvm::InitializeAllTargets(); @@ -495,10 +534,9 @@ static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod, if (!t) throw std::runtime_error(err); - llvm::TargetMachine *targetMachine = - t->createTargetMachine(target, "", "", llvm::TargetOptions(), - // TODO: make configurable - llvm::Reloc::PIC_); + llvm::TargetMachine *targetMachine = t->createTargetMachine(target, "", "", llvm::TargetOptions(), + // TODO: make configurable + llvm::Reloc::PIC_); mod.setDataLayout(targetMachine->createDataLayout()); @@ -514,8 +552,7 @@ static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod, pm.add(new llvm::TargetLibraryInfoWrapperPass()); pm.add(new llvm::MachineModuleInfoWrapperPass(&tm)); - bool objResult = targetMachine->addPassesToEmitFile( - pm, dest, nullptr, llvm::CodeGenFileType::ObjectFile); + bool objResult = targetMachine->addPassesToEmitFile(pm, dest, nullptr, llvm::CodeGenFileType::ObjectFile); if (objResult) throw std::runtime_error("failed to produce " + outfile); @@ -526,8 +563,7 @@ static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod, } // namespace -void compileModule(std::unique_ptr &module, - const std::string &filename) { +void compileModule(std::unique_ptr &module, const std::string &filename) { auto moduleId = filename; llvm::LLVMContext ctx; @@ -550,6 +586,9 @@ void compileModule(std::unique_ptr &module, mod.print(llvm::outs(), nullptr); llvm::outs().flush(); throw std::runtime_error("Module verification failed"); + } else { + mod.print(llvm::outs(), nullptr); + llvm::outs().flush(); } runMPM(mod); // info: does not work, programs will malfunction diff --git a/compiler/src/Visitors/NameAnalysis.cpp b/compiler/src/Visitors/NameAnalysis.cpp index a35feec..92f5937 100644 --- a/compiler/src/Visitors/NameAnalysis.cpp +++ b/compiler/src/Visitors/NameAnalysis.cpp @@ -13,8 +13,7 @@ class NameAnalysisVisitor1 : public ast::BaseASTVisitor { std::vector>> *scopes; public: - NameAnalysisVisitor1( - std::vector>> *scopes) + NameAnalysisVisitor1(std::vector>> *scopes) : scopes(scopes) {} virtual std::any visit(ast::FnDecl &fnDecl, std::any param) override { @@ -22,8 +21,7 @@ public: return std::any(); if (scopes->back().count(fnDecl.name)) { - errors::put( - fnDecl.error("redeclaration of global symbol '" + fnDecl.name + "'")); + errors::put(fnDecl.error("redeclaration of global symbol '" + fnDecl.name + "'")); return std::any(); } @@ -39,8 +37,7 @@ public: return std::any(); if (scopes->back().count(varDecl.name)) { - errors::put(varDecl.error("redeclaration of global symbol '" + - varDecl.name + "'")); + errors::put(varDecl.error("redeclaration of global symbol '" + varDecl.name + "'")); return std::any(); } @@ -55,9 +52,7 @@ public: class NameAnalysisVisitor2 : public ast::BaseASTVisitor { std::vector>> *scopes; - void push() { - scopes->push_back(std::map>()); - } + void push() { scopes->push_back(std::map>()); } void pop() { scopes->pop_back(); } @@ -72,8 +67,7 @@ class NameAnalysisVisitor2 : public ast::BaseASTVisitor { } public: - NameAnalysisVisitor2( - std::vector>> *scopes) + NameAnalysisVisitor2(std::vector>> *scopes) : scopes(scopes) {} virtual std::any visit(ast::FnParam &fnParam, std::any param) override { @@ -118,8 +112,7 @@ public: auto symbol = findSymbol(identifier.name); if (!symbol.get()) { - errors::put(identifier.error("unable to resolve identifier '" + - identifier.name + "'")); + errors::put(identifier.error("unable to resolve identifier '" + identifier.name + "'")); return std::any(); } diff --git a/compiler/src/Visitors/TypeAnalysis.cpp b/compiler/src/Visitors/TypeAnalysis.cpp index eab42b9..fc5ca87 100644 --- a/compiler/src/Visitors/TypeAnalysis.cpp +++ b/compiler/src/Visitors/TypeAnalysis.cpp @@ -32,29 +32,36 @@ static std::map> primitiveTypes = { }; static std::shared_ptr resolveTypeName(const TypeName *typeName) { + if (!typeName) + return std::shared_ptr(nullptr); + if (utils::is(typeName)) { auto named = (NamedTypeName *)typeName; if (primitiveTypes.count(named->name)) return primitiveTypes[named->name]; - errors::put( - typeName->error("unable to resolve named type '" + named->name + "'")); + errors::put(typeName->error("unable to resolve named type '" + named->name + "'")); return std::shared_ptr(nullptr); } + if (utils::is(typeName)) { + auto ptr = (PointerTypeName *)typeName; + auto baseType = resolveTypeName(ptr->baseTypeName.get()); + if (baseType.get()) + return std::make_shared(baseType); + } + // TODO: function type errors::put(typeName->error("unable to resolve type")); return std::shared_ptr(nullptr); } -static void castTo(std::unique_ptr &expr, - const std::shared_ptr &type) { +static void castTo(std::unique_ptr &expr, const std::shared_ptr &type) { if (*expr->type == *type) return; - auto cast = new CastExpr(expr->sourceRange, std::move(expr), - std::unique_ptr(type->toTypeName())); + auto cast = new CastExpr(expr->sourceRange, std::move(expr), std::unique_ptr(type->toTypeName())); cast->type = type; cast->typeName->type = type; @@ -62,19 +69,16 @@ static void castTo(std::unique_ptr &expr, expr.swap(newExpr); } -static bool tryAssignTo(std::unique_ptr &from, - const std::shared_ptr &toType) { - if (*from->type == *toType) +static bool tryAssignTo(std::unique_ptr &from, const std::shared_ptr &toType) { + if ((*from->type) == *toType) return true; - if (utils::is(from->type.get()) && - utils::is(toType.get())) { + if (utils::is(from->type.get()) && utils::is(toType.get())) { PrimitiveType *fromT = (PrimitiveType *)from->type.get(); const PrimitiveType *toT = (PrimitiveType *)toType.get(); std::map> castMatrix = { - {"i8", - {"i16", "i32", "i64", "u8", "u16", "u32", "u64", "i128", "u128"}}, + {"i8", {"i16", "i32", "i64", "u8", "u16", "u32", "u64", "i128", "u128"}}, {"i16", {"i32", "i64", "u16", "u32", "u64", "i128", "u128"}}, {"i32", {"i64", "u32", "u64", "i128", "u128"}}, {"i64", {"u64", "i128", "u128"}}, @@ -103,29 +107,39 @@ static bool tryAssignTo(std::unique_ptr &from, return false; } + else if (utils::is(from->type.get()) && utils::is(toType.get())) { + castTo(from, toType); + return true; + } + + else if (utils::is(from->type.get()) && utils::is(toType.get())) { + castTo(from, toType); + return true; + } + + else if (utils::is(from->type.get()) && utils::is(toType.get())) { + castTo(from, toType); + return true; + } + return false; } -static bool canBeCastedTo(std::unique_ptr &from, - const std::shared_ptr &toType) { +static bool canBeCastedTo(std::unique_ptr &from, const std::shared_ptr &toType) { if (*from->type == *toType) return true; - if (utils::is(from->type.get()) && - utils::is(toType.get())) { + if (utils::is(from->type.get()) && utils::is(toType.get())) { PrimitiveType *fromT = (PrimitiveType *)from->type.get(); const PrimitiveType *toT = (PrimitiveType *)toType.get(); - std::vector allNumberTypes = { - "i8", "i16", "i32", "i64", "i128", "u8", - "u16", "u32", "u64", "u128", "float", "double"}; + std::vector allNumberTypes = {"i8", "i16", "i32", "i64", "i128", "u8", + "u16", "u32", "u64", "u128", "float", "double"}; std::map> castMatrix = { - {"i8", allNumberTypes}, {"i16", allNumberTypes}, - {"i32", allNumberTypes}, {"i64", allNumberTypes}, - {"i128", allNumberTypes}, {"u8", allNumberTypes}, - {"u16", allNumberTypes}, {"u32", allNumberTypes}, - {"u64", allNumberTypes}, {"u128", allNumberTypes}, - {"float", allNumberTypes}, {"double", allNumberTypes}, + {"i8", allNumberTypes}, {"i16", allNumberTypes}, {"i32", allNumberTypes}, + {"i64", allNumberTypes}, {"i128", allNumberTypes}, {"u8", allNumberTypes}, + {"u16", allNumberTypes}, {"u32", allNumberTypes}, {"u64", allNumberTypes}, + {"u128", allNumberTypes}, {"float", allNumberTypes}, {"double", allNumberTypes}, }; if (!castMatrix.count(fromT->name)) @@ -148,12 +162,16 @@ static bool canBeCastedTo(std::unique_ptr &from, class TypeAnalysisVisitor1 : public BaseASTVisitor { public: - virtual std::any visit(NamedTypeName &namedTypeName, - std::any param) override { + virtual std::any visit(NamedTypeName &namedTypeName, std::any param) override { namedTypeName.type = resolveTypeName(&namedTypeName); return std::any(); } + virtual std::any visit(PointerTypeName &pointerTypeName, std::any param) override { + pointerTypeName.type = resolveTypeName(&pointerTypeName); + return std::any(); + } + virtual std::any visit(VarDecl &varDecl, std::any param) override { if (!varDecl.typeName.get() || !varDecl.symbol.get()) return std::any(); @@ -178,8 +196,7 @@ public: virtual std::any visit(FnDecl &fnDecl, std::any param) override { if (!(fnDecl.body.get() && fnDecl.body->alywasReturns())) { - errors::put(fnDecl.error("function '" + fnDecl.name + - "' does not always return a value")); + errors::put(fnDecl.error("function '" + fnDecl.name + "' does not always return a value")); } if (!fnDecl.symbol.get()) @@ -203,8 +220,7 @@ public: } } - fnDecl.symbol->type = - std::make_shared(paramTypes, returnType); + fnDecl.symbol->type = std::make_shared(paramTypes, returnType); return std::any(); } @@ -214,12 +230,16 @@ class TypeAnalysisVisitor2 : public BaseASTVisitor { std::shared_ptr currentReturnType = nullptr; public: - virtual std::any visit(NamedTypeName &namedTypeName, - std::any param) override { + virtual std::any visit(NamedTypeName &namedTypeName, std::any param) override { namedTypeName.type = resolveTypeName(&namedTypeName); return std::any(); } + virtual std::any visit(PointerTypeName &pointerTypeName, std::any param) override { + pointerTypeName.type = resolveTypeName(&pointerTypeName); + return std::any(); + } + virtual std::any visit(VarDecl &varDecl, std::any param) override { if (!varDecl.typeName.get() || !varDecl.symbol.get()) return std::any(); @@ -315,6 +335,12 @@ public: return std::any(); } + virtual std::any visit(StringValue &stringValue, std::any param) override { + stringValue.type = std::make_shared(std::make_shared("u8")); + + return std::any(); + } + virtual std::any visit(FloatValue &floatValue, std::any param) override { if ((float_t)floatValue.value == floatValue.value) { floatValue.type = std::make_shared("float"); @@ -353,11 +379,6 @@ public: if (!assignStmt.lval.get() || !assignStmt.rval->type.get()) return std::any(); - if (!utils::is(assignStmt.lval.get())) { - errors::put(assignStmt.error("invalid lvalue")); - return std::any(); - } - if (!tryAssignTo(assignStmt.rval, assignStmt.lval->type)) { errors::put(assignStmt.error("assignment type mismatch")); return std::any(); @@ -387,13 +408,15 @@ public: BaseASTVisitor::visit(ifStmt, param); // do not multiplicate errors - if (ifStmt.condition.get() && ifStmt.condition->type.get()) { - if (!utils::is(ifStmt.condition->type.get())) { - errors::put(ifStmt.error("condition must be of primitive type")); - } else { - // cast condition to int (to make sure jnz succeeds) - castTo(ifStmt.condition, std::make_shared("i32")); - } + if (!ifStmt.condition.get() || !ifStmt.condition->type.get()) + return std::any(); + + if (!utils::is(ifStmt.condition->type.get()) && + !utils::is(ifStmt.condition->type.get())) { + errors::put(ifStmt.error("condition must be of primitive or pointer type")); + } else { + // cast condition to int (to make sure jnz succeeds) + castTo(ifStmt.condition, std::make_shared("i32")); } return std::any(); @@ -409,7 +432,7 @@ public: return std::any(); } else { // cast condition to int (to make sure jnz succeeds) - castTo(whileStmt.condition, std::make_shared("i32")); + castTo(whileStmt.condition, std::make_shared("i64")); } } @@ -437,8 +460,7 @@ public: return std::any(); } - size_t smallerArgCount = - std::min(functionType->paramTypes.size(), callExpr.args.size()); + size_t smallerArgCount = std::min(functionType->paramTypes.size(), callExpr.args.size()); for (size_t i = 0; i < smallerArgCount; i++) { if (!callExpr.args[i].get() || !functionType->paramTypes[i].get()) @@ -495,8 +517,7 @@ public: case BinOp::DIV: { if (!tryAssignTo(binExpr.rhs, binExpr.lhs->type)) { if (!tryAssignTo(binExpr.lhs, binExpr.rhs->type)) { - errors::put( - binExpr.error("operands incompatible, explicit cast required")); + errors::put(binExpr.error("operands incompatible, explicit cast required")); return std::any(); } } @@ -506,10 +527,8 @@ public: break; } case BinOp::MOD: { - auto lhsSuccess = - tryAssignTo(binExpr.lhs, std::make_shared("i64")); - auto rhsSuccess = - tryAssignTo(binExpr.rhs, std::make_shared("i64")); + auto lhsSuccess = tryAssignTo(binExpr.lhs, std::make_shared("i64")); + auto rhsSuccess = tryAssignTo(binExpr.rhs, std::make_shared("i64")); if (!lhsSuccess || !rhsSuccess) { errors::put(binExpr.error("operands must be of integer type")); @@ -526,8 +545,7 @@ public: case BinOp::GE: { if (!tryAssignTo(binExpr.rhs, binExpr.lhs->type)) { if (!tryAssignTo(binExpr.lhs, binExpr.rhs->type)) { - errors::put( - binExpr.error("operands incompatible, explicit cast required")); + errors::put(binExpr.error("operands incompatible, explicit cast required")); return std::any(); } } @@ -549,6 +567,38 @@ public: return std::any(); } + + virtual std::any visit(InlineAsmConstraint &constraint, std::any param) override { + if (constraint.constraint.empty()) { + errors::put(constraint.error("inline assembly constraint cannot be empty")); + } + + // For simplicity, assume constraints are string-based and should match + // certain patterns. You can add regex-based validation if needed. + + BaseASTVisitor::visit(constraint, param); + + return std::any(); + } + + virtual std::any visit(InlineAsm &inlineAsm, std::any param) override { + if (inlineAsm.code.empty()) { + errors::put(inlineAsm.error("inline assembly code cannot be empty")); + return std::any(); + } + + BaseASTVisitor::visit(inlineAsm, param); + + for (const auto &clobber : inlineAsm.clobbers) { + if (clobber.empty()) { + errors::put(inlineAsm.error("clobber cannot be empty")); + } + // Optionally, check if clobber registers are valid (e.g., using a + // whitelist of registers). + } + + return std::any(); + } }; } // namespace diff --git a/examples/new b/examples/new index e943eae..636dc9b 100755 Binary files a/examples/new and b/examples/new differ diff --git a/examples/new.plsm b/examples/new.plsm index d73e5d2..b938da8 100644 --- a/examples/new.plsm +++ b/examples/new.plsm @@ -1,22 +1,35 @@ -fun addFirst(n : i32) i32 { - var result : i32; - result = 0; +fun write(fd : i64, msg : &u8, len : u64) i64 { + inline asm ( + "mov $0, %rax" // syscall: write + "mov $1, %rdi" // file descriptor: stdout + "mov $2, %rsi" // message to write + "mov $3, %rdx" // length of message + "syscall" // make the syscall + : + : "r"(1 as i64), "r"(1 as i64), "r"(msg), "r"(len) // input: different params + : "rax", "rdi", "rsi", "rdx" // clobbered registers + ); - var i : i32; - i = 0; - while (i < n) { - result = result + i; - i = i + 1; - } - - ret result; + ret 0; } -// fun f(n : i32) i32 { -// if (n < 2) ret n; -// ret f(n - 1); -// } +fun exit(code : u8) i64 { + inline asm ( + "mov $0, %rax" + "mov $1, %rdi" + "syscall" + : + : "r"(60 as i64), "r"(code as i64) + : "rax", "rdi" + ); + + ret 0; +} fun main(argc : i32) u8 { - ret addFirst(argc) as u8; + write(1, "Hello World!\n", 14); + write(1, "Ə Ɛ Ƒ ƒ Ɠ Ɣ ƕ Ɩ Ɨ Ƙ ƙ ƚ ƛ Ɯ Ɲ ƞ Ɵ Ơ ơ Ƣ ƣ Ƥ ƥ\n", 69); + exit(10); + + ret 0; }