#include "AST/BaseASTVisitor.h"
#include "Compile.h"
#include "Utils.h"

// NOTE(review): the names of the angle-bracket headers were not legible in the
// reviewed text. The list below covers every std/LLVM name this file uses;
// confirm it against the build (in particular the header that provides
// LLVM_DEFAULT_TARGET_TRIPLE for the LLVM version in use).
#include <any>
#include <cstddef>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>

#include <llvm/Analysis/CGSCCPassManager.h>
#include <llvm/Analysis/LoopAnalysisManager.h>
#include <llvm/Analysis/TargetLibraryInfo.h>
#include <llvm/CodeGen/MachineModuleInfo.h>
#include <llvm/Config/llvm-config.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
#include <llvm/MC/TargetRegistry.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/CodeGen.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>

namespace plsm {
using namespace ast;

namespace {

// NOTE(review): template arguments (std::shared_ptr<Type>,
// std::map<Symbol *, llvm::Value *>, utils::is<...>, std::any_cast<...>, ...)
// were not legible in the reviewed text and are reconstructed from how the
// values are used. Confirm the AST class names `Type` and `Symbol` against the
// AST headers.

/// Maps each AST symbol to the LLVM value that holds it (global, function or
/// stack slot).
using SymbolMap = std::map<Symbol *, llvm::Value *>;

/// The LLVM values that are functions declared at module level.
using FunctionSet = std::set<llvm::Value *>;

/// True when the primitive type is named "float" or "double".
static bool isFloatType(const PrimitiveType &type) {
  return type.name == "float" || type.name == "double";
}

/// True when the primitive type's name starts with 'u' (u8, u16, ...).
static bool isUnsignedType(const PrimitiveType &type) {
  return type.name[0] == 'u';
}

/// Translates an AST type into the corresponding LLVM type.
///
/// Primitive integer types map to the LLVM integer of the same width (signed
/// and unsigned share one LLVM type), "float"/"double" map to the LLVM
/// floating-point types. Function types get an extra leading `i8*` parameter
/// (the call context) in front of the declared parameters.
///
/// @throws std::runtime_error if the type is neither a known primitive nor a
///         function type.
static llvm::Type *getLLVMType(llvm::LLVMContext &ctx,
                               const std::shared_ptr<Type> &type) {
  if (utils::is<PrimitiveType>(type.get())) {
    const auto &name = static_cast<PrimitiveType *>(type.get())->name;

    if (name == "i8" || name == "u8")
      return llvm::Type::getInt8Ty(ctx);
    if (name == "i16" || name == "u16")
      return llvm::Type::getInt16Ty(ctx);
    if (name == "i32" || name == "u32")
      return llvm::Type::getInt32Ty(ctx);
    if (name == "i64" || name == "u64")
      return llvm::Type::getInt64Ty(ctx);
    if (name == "i128" || name == "u128")
      return llvm::Type::getInt128Ty(ctx);
    if (name == "float")
      return llvm::Type::getFloatTy(ctx);
    if (name == "double")
      return llvm::Type::getDoubleTy(ctx);
  } else if (utils::is<FunctionType>(type.get())) {
    auto *functionType = static_cast<FunctionType *>(type.get());
    auto *returnType = getLLVMType(ctx, functionType->returnType);

    std::vector<llvm::Type *> llvmParams;
    llvmParams.reserve(functionType->paramTypes.size() + 1);

    // Every function takes the call context pointer as its first parameter.
    llvmParams.push_back(llvm::IntegerType::get(ctx, 8)->getPointerTo());
    for (auto &paramType : functionType->paramTypes)
      llvmParams.push_back(getLLVMType(ctx, paramType));

    return llvm::FunctionType::get(returnType, llvmParams, false);
  }

  throw std::runtime_error("cannot determine llvm type");
}

/// First code generation pass: declares every module-level variable and
/// function in the LLVM module and records them in the symbol map, so that
/// the second pass can reference symbols regardless of declaration order.
class IRGenerator1 : public BaseASTVisitor {
  llvm::LLVMContext &ctx;
  llvm::Module &mod;
  llvm::IRBuilder<> &builder;

  SymbolMap &symbolMap;
  FunctionSet &functions;

public:
  IRGenerator1(llvm::LLVMContext &ctx, llvm::Module &mod,
               llvm::IRBuilder<> &builder, SymbolMap &symbolMap,
               FunctionSet &functions)
      : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap),
        functions(functions) {}

  /// Declares the variable as a global of the module.
  virtual std::any visit(VarDecl &varDecl, std::any param) override {
    auto *llvmType = getLLVMType(ctx, varDecl.symbol->type);
    auto *global = mod.getOrInsertGlobal(varDecl.name, llvmType);
    symbolMap[varDecl.symbol.get()] = global;

    return std::any();
  }

  /// Declares the function (without a body) and registers it as a function
  /// symbol.
  virtual std::any visit(FnDecl &fnDecl, std::any param) override {
    auto *functionType =
        llvm::cast<llvm::FunctionType>(getLLVMType(ctx, fnDecl.symbol->type));

    mod.getOrInsertFunction(fnDecl.name, functionType);
    auto *function = mod.getFunction(fnDecl.name);

    symbolMap[fnDecl.symbol.get()] = function;
    functions.insert(function);

    return std::any();
  }
};

/// Second code generation pass: emits the bodies of the functions declared by
/// IRGenerator1.
///
/// Expression visitors return an `llvm::Value *` wrapped in `std::any`;
/// statement visitors return an empty `std::any`.
class IRGenerator2 : public BaseASTVisitor {
  llvm::LLVMContext &ctx;
  llvm::Module &mod;
  llvm::IRBuilder<> &builder;

  SymbolMap &symbolMap;
  FunctionSet &functions;

  // Slot holding the return value of the function currently being emitted and
  // the single block that loads it and returns.
  llvm::Value *retStore = nullptr;
  llvm::BasicBlock *retBlock = nullptr;

  // When non-null, visiting an Identifier stores this value into the
  // identifier instead of loading from it (see visit(AssignStmt &)).
  llvm::Value *rvalueForLValue = nullptr;

  size_t labelCounter = 0;

  /// Produces a fresh basic block label: "L0", "L1", ...
  std::string createLabel() { return "L" + std::to_string(labelCounter++); }

  /// Wraps a plain function in a stack-allocated `{ context, function }`
  /// pair with a null context, which is the callee layout that
  /// visit(CallExpr &) reads.
  llvm::Value *wrapCallee(llvm::Value *callee) {
    auto *ptr = llvm::IntegerType::get(ctx, 8)->getPointerTo();
    auto *structType = llvm::StructType::get(ptr, callee->getType());

    auto *store = builder.CreateAlloca(structType);

    auto *contextSlot = builder.CreateStructGEP(structType, store, 0);
    builder.CreateStore(llvm::ConstantPointerNull::get(ptr), contextSlot);

    auto *calleeSlot = builder.CreateStructGEP(structType, store, 1);
    builder.CreateStore(callee, calleeSlot);

    return store;
  }

  /// Emits the instruction for one binary operator.
  ///
  /// @param isFloat    operands are floating-point values
  /// @param isUnsigned integer operands are to be treated as unsigned
  /// @throws std::runtime_error for an operator that is not handled.
  llvm::Value *emitBinOp(BinOp op, llvm::Value *lhs, llvm::Value *rhs,
                         bool isFloat, bool isUnsigned) {
    switch (op) {
    case BinOp::ADD:
      return isFloat ? builder.CreateFAdd(lhs, rhs)
                     : builder.CreateAdd(lhs, rhs);
    case BinOp::SUB:
      return isFloat ? builder.CreateFSub(lhs, rhs)
                     : builder.CreateSub(lhs, rhs);
    case BinOp::MUL:
      return isFloat ? builder.CreateFMul(lhs, rhs)
                     : builder.CreateMul(lhs, rhs);
    case BinOp::DIV:
      if (isFloat)
        return builder.CreateFDiv(lhs, rhs);
      return isUnsigned ? builder.CreateUDiv(lhs, rhs)
                        : builder.CreateSDiv(lhs, rhs);
    case BinOp::MOD:
      // Integer remainder instructions are invalid on floating-point
      // operands, so floats need frem.
      if (isFloat)
        return builder.CreateFRem(lhs, rhs);
      return isUnsigned ? builder.CreateURem(lhs, rhs)
                        : builder.CreateSRem(lhs, rhs);

    case BinOp::EQ:
      // icmp is invalid on floating-point operands.
      return isFloat ? builder.CreateFCmpOEQ(lhs, rhs)
                     : builder.CreateICmpEQ(lhs, rhs);
    case BinOp::NE:
      // Unordered "not equal" so that NaN != x is true, as in C.
      return isFloat ? builder.CreateFCmpUNE(lhs, rhs)
                     : builder.CreateICmpNE(lhs, rhs);
    case BinOp::LT:
      // Must be "less than" for floats (this used to emit OGT).
      if (isFloat)
        return builder.CreateFCmpOLT(lhs, rhs);
      return isUnsigned ? builder.CreateICmpULT(lhs, rhs)
                        : builder.CreateICmpSLT(lhs, rhs);
    case BinOp::GT:
      if (isFloat)
        return builder.CreateFCmpOGT(lhs, rhs);
      return isUnsigned ? builder.CreateICmpUGT(lhs, rhs)
                        : builder.CreateICmpSGT(lhs, rhs);
    case BinOp::LE:
      if (isFloat)
        return builder.CreateFCmpOLE(lhs, rhs);
      return isUnsigned ? builder.CreateICmpULE(lhs, rhs)
                        : builder.CreateICmpSLE(lhs, rhs);
    case BinOp::GE:
      if (isFloat)
        return builder.CreateFCmpOGE(lhs, rhs);
      return isUnsigned ? builder.CreateICmpUGE(lhs, rhs)
                        : builder.CreateICmpSGE(lhs, rhs);

    case BinOp::AND:
      return builder.CreateAnd(lhs, rhs);
    case BinOp::OR:
      return builder.CreateOr(lhs, rhs);
    }

    throw std::runtime_error("binop not implemented");
  }

public:
  IRGenerator2(llvm::LLVMContext &ctx, llvm::Module &mod,
               llvm::IRBuilder<> &builder, SymbolMap &symbolMap,
               FunctionSet &functions)
      : ctx(ctx), mod(mod), builder(builder), symbolMap(symbolMap),
        functions(functions) {}

  /// Emits the body of a function: spills the parameters to stack slots,
  /// visits the body, and finishes with a single return block that loads the
  /// value written by the return statements.
  virtual std::any visit(FnDecl &fnDecl, std::any param) override {
    auto *function = mod.getFunction(fnDecl.name);

    auto *entryBlock = llvm::BasicBlock::Create(ctx, createLabel(), function);
    builder.SetInsertPoint(entryBlock);

    for (auto &arg : function->args()) {
      // Argument 0 is the call context, which has no AST parameter.
      if (arg.getArgNo() == 0)
        continue;

      auto *store = builder.CreateAlloca(arg.getType(), nullptr);
      builder.CreateStore(&arg, store);
      symbolMap[fnDecl.params[arg.getArgNo() - 1]->symbol.get()] = store;
    }

    auto *fnType =
        llvm::cast<llvm::FunctionType>(getLLVMType(ctx, fnDecl.symbol->type));
    retStore = builder.CreateAlloca(fnType->getReturnType());
    retBlock = llvm::BasicBlock::Create(ctx, createLabel(), function);

    BaseASTVisitor::visit(fnDecl, param);

    // A body that returns on every path has already branched to retBlock.
    if (!fnDecl.body->alywasReturns())
      builder.CreateBr(retBlock);

    builder.SetInsertPoint(retBlock);
    auto *retVal = builder.CreateLoad(fnType->getReturnType(), retStore);
    builder.CreateRet(retVal);

    return std::any();
  }

  /// Allocates a stack slot for a local variable.
  virtual std::any visit(VarDecl &varDecl, std::any param) override {
    auto *store = builder.CreateAlloca(getLLVMType(ctx, varDecl.symbol->type));
    symbolMap[varDecl.symbol.get()] = store;
    return std::any();
  }

  /// Emits a binary expression and returns the resulting value.
  virtual std::any visit(BinExpr &binExpr, std::any param) override {
    auto *lhs = std::any_cast<llvm::Value *>(binExpr.lhs->accept(this, param));
    auto *rhs = std::any_cast<llvm::Value *>(binExpr.rhs->accept(this, param));

    // NOTE(review): the expression type is assumed to be a primitive type, as
    // in the original code; nothing here checks it.
    auto *primitiveType = static_cast<PrimitiveType *>(binExpr.type.get());

    // The float/integer instruction choice follows the actual LLVM operand
    // type, so it is correct even if the expression's own type is not the
    // operand type (e.g. for comparisons).
    const bool isFloat = lhs->getType()->isFloatingPointTy();
    const bool isUnsigned = isUnsignedType(*primitiveType);

    return emitBinOp(binExpr.op, lhs, rhs, isFloat, isUnsigned);
  }

  /// Emits a unary expression and returns the resulting value.
  virtual std::any visit(UnExpr &unExpr, std::any param) override {
    auto *expr = std::any_cast<llvm::Value *>(unExpr.expr->accept(this, param));

    switch (unExpr.op) {
    case UnOp::NEG:
      // Integer negation (0 - x) is invalid on floating-point operands.
      if (expr->getType()->isFloatingPointTy())
        return builder.CreateFNeg(expr);
      return builder.CreateNeg(expr);
    case UnOp::NOT:
      return builder.CreateNot(expr);
    case UnOp::POS:
      return expr;
    }

    throw std::runtime_error("unop not implemented");
  }

  /// Emits a conversion between two primitive types.
  ///
  /// @throws std::runtime_error if either side is not a primitive type.
  virtual std::any visit(CastExpr &castExpr, std::any param) override {
    auto *value =
        std::any_cast<llvm::Value *>(castExpr.value->accept(this, param));

    if (!utils::is<PrimitiveType>(castExpr.value->type.get()) ||
        !utils::is<PrimitiveType>(castExpr.type.get()))
      throw std::runtime_error("cast not implemented");

    auto *srcType = static_cast<PrimitiveType *>(castExpr.value->type.get());
    auto *dstType = static_cast<PrimitiveType *>(castExpr.type.get());
    auto *llvmDstType = getLLVMType(ctx, castExpr.type);

    const bool srcIsFloat = isFloatType(*srcType);
    const bool dstIsFloat = isFloatType(*dstType);

    if (srcIsFloat && dstIsFloat) {
      // CreateFPCast picks fpext, fptrunc or no-op from the two widths. The
      // previous code chose them the wrong way round (fpext for a double
      // source), which is invalid IR for double -> float.
      return builder.CreateFPCast(value, llvmDstType);
    }

    if (srcIsFloat) {
      return isUnsignedType(*dstType)
                 ? builder.CreateFPToUI(value, llvmDstType)
                 : builder.CreateFPToSI(value, llvmDstType);
    }

    if (dstIsFloat) {
      return isUnsignedType(*srcType)
                 ? builder.CreateUIToFP(value, llvmDstType)
                 : builder.CreateSIToFP(value, llvmDstType);
    }

    // Integer widening preserves the value only when the extension follows
    // the signedness of the *source*: u8 255 -> i32 must give 255, and
    // i8 -1 -> u32 must give 0xFFFFFFFF.
    return isUnsignedType(*srcType)
               ? builder.CreateZExtOrTrunc(value, llvmDstType)
               : builder.CreateSExtOrTrunc(value, llvmDstType);
  }

  /// Emits a call. The callee value is a pointer to a `{ context, function }`
  /// pair; the context is passed as the first argument.
  virtual std::any visit(CallExpr &callExpr, std::any param) override {
    auto *callee =
        std::any_cast<llvm::Value *>(callExpr.callee->accept(this, param));

    auto *ptrType = llvm::IntegerType::get(ctx, 8)->getPointerTo();
    auto *structType = llvm::StructType::get(ptrType, ptrType);

    auto *contextSlot = builder.CreateStructGEP(structType, callee, 0);
    auto *callCtx = builder.CreateLoad(ptrType, contextSlot);

    auto *calleeSlot = builder.CreateStructGEP(structType, callee, 1);
    llvm::Value *realCallee = builder.CreateLoad(ptrType, calleeSlot);

    std::vector<llvm::Value *> llvmArgs;
    llvmArgs.reserve(callExpr.args.size() + 1);
    llvmArgs.push_back(callCtx);
    for (auto &arg : callExpr.args)
      llvmArgs.push_back(std::any_cast<llvm::Value *>(arg->accept(this, param)));

    auto *calleeType = llvm::cast<llvm::FunctionType>(
        getLLVMType(ctx, callExpr.callee->type));
    llvm::Value *call = builder.CreateCall(calleeType, realCallee, llvmArgs);
    return call;
  }

  /// Reads an identifier, or writes it when an assignment is in progress.
  ///
  /// @throws std::runtime_error if no LLVM value was generated for the
  ///         identifier's symbol.
  virtual std::any visit(Identifier &identifier, std::any param) override {
    // Look the symbol up without inserting: operator[] would silently add a
    // null entry and the null pointer would then be handed to LLVM.
    const auto found = symbolMap.find(identifier.symbol.get());
    if (found == symbolMap.end() || !found->second)
      throw std::runtime_error("no llvm value for identifier symbol");
    llvm::Value *value = found->second;

    if (rvalueForLValue) {
      builder.CreateStore(rvalueForLValue, value);
      return std::any();
    }

    // Module-level functions are turned into a callable pair on use.
    if (functions.count(value))
      return wrapCallee(value);

    // Other function-typed identifiers are passed on as they are.
    if (utils::is<FunctionType>(identifier.type.get()))
      return value;

    auto *loadType = getLLVMType(ctx, identifier.type);
    llvm::Value *loaded = builder.CreateLoad(loadType, value);
    return loaded;
  }

  /// Emits an integer constant of the literal's type.
  virtual std::any visit(IntValue &intValue, std::any param) override {
    llvm::Value *constant = llvm::ConstantInt::get(
        getLLVMType(ctx, intValue.type), intValue.value);
    return constant;
  }

  /// Emits a floating-point constant of the literal's type.
  virtual std::any visit(FloatValue &floatValue, std::any param) override {
    llvm::Value *constant = llvm::ConstantFP::get(
        getLLVMType(ctx, floatValue.type), floatValue.value);
    return constant;
  }

  /// Stores the returned value in the function's return slot and jumps to the
  /// function's return block.
  virtual std::any visit(RetStmt &retStmt, std::any param) override {
    auto *value =
        std::any_cast<llvm::Value *>(retStmt.value->accept(this, param));

    builder.CreateStore(value, retStore);
    builder.CreateBr(retBlock);

    return std::any();
  }

  /// Emits an if/else statement. The join block is only created when at least
  /// one branch can fall through.
  virtual std::any visit(IfStmt &ifStmt, std::any param) override {
    auto *cond =
        std::any_cast<llvm::Value *>(ifStmt.condition->accept(this, param));

    auto *fn = builder.GetInsertBlock()->getParent();
    auto *ifBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
    auto *elseBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);

    llvm::BasicBlock *endBlock = nullptr;
    if (!ifStmt.alywasReturns())
      endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);

    builder.CreateCondBr(cond, ifBlock, elseBlock);

    builder.SetInsertPoint(ifBlock);
    ifStmt.ifBody->accept(this, param);
    if (endBlock && !ifStmt.ifBody->alywasReturns())
      builder.CreateBr(endBlock);

    builder.SetInsertPoint(elseBlock);
    ifStmt.elseBody->accept(this, param);
    if (endBlock && !ifStmt.elseBody->alywasReturns())
      builder.CreateBr(endBlock);

    // Code following the statement continues in the join block.
    if (endBlock)
      builder.SetInsertPoint(endBlock);

    return std::any();
  }

  /// Emits a while loop as condition block, body block and exit block.
  virtual std::any visit(WhileStmt &whileStmt, std::any param) override {
    auto *fn = builder.GetInsertBlock()->getParent();
    auto *condBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
    auto *whileBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);
    auto *endBlock = llvm::BasicBlock::Create(ctx, createLabel(), fn);

    builder.CreateBr(condBlock);

    builder.SetInsertPoint(condBlock);
    auto *cond =
        std::any_cast<llvm::Value *>(whileStmt.condition->accept(this, param));
    builder.CreateCondBr(cond, whileBlock, endBlock);

    builder.SetInsertPoint(whileBlock);
    whileStmt.body->accept(this, param);
    if (!whileStmt.body->alywasReturns())
      builder.CreateBr(condBlock);

    builder.SetInsertPoint(endBlock);

    return std::any();
  }

  /// Emits an assignment: evaluates the right-hand side, then visits the
  /// left-hand side in "store" mode.
  virtual std::any visit(AssignStmt &assignStmt, std::any param) override {
    auto *rvalue =
        std::any_cast<llvm::Value *>(assignStmt.rval->accept(this, param));

    rvalueForLValue = rvalue;
    assignStmt.lval->accept(this, param);
    rvalueForLValue = nullptr;

    return std::any();
  }
};

/// Runs the O3 module optimization pipeline of the new pass manager on `mod`.
static void runMPM(llvm::Module &mod) {
  // The analysis managers hold references to each other. LLVM documents that
  // they must be declared in this order so that they are destroyed in the
  // correct order.
  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager gam;
  llvm::ModuleAnalysisManager mam;

  llvm::PassBuilder passBuilder;
  passBuilder.registerModuleAnalyses(mam);
  passBuilder.registerCGSCCAnalyses(gam);
  passBuilder.registerFunctionAnalyses(fam);
  passBuilder.registerLoopAnalyses(lam);
  passBuilder.crossRegisterProxies(lam, fam, gam, mam);

  llvm::ModulePassManager mpm = passBuilder.buildModuleOptimizationPipeline(
      llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None);

  mpm.run(mod, mam);

  mam.clear();
  gam.clear();
  fam.clear();
  lam.clear();
}

/// Compiles `mod` to an object file for the default target triple.
///
/// @param outfile path of the object file to write
/// @throws std::runtime_error if the target cannot be found, the target
///         machine cannot be created, the file cannot be opened, or the
///         target cannot emit object files.
static void writeToFile(llvm::LLVMContext &ctx, llvm::Module &mod,
                        llvm::IRBuilder<> &builder,
                        const std::string &outfile) {
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmParsers();
  llvm::InitializeAllAsmPrinters();

  auto target = LLVM_DEFAULT_TARGET_TRIPLE;
  mod.setTargetTriple(target);

  std::string err;
  const llvm::Target *t = llvm::TargetRegistry::lookupTarget(target, err);
  if (!t)
    throw std::runtime_error(err);

  // createTargetMachine hands ownership to the caller; hold it in a
  // unique_ptr so it is released on every path, including the throws below.
  std::unique_ptr<llvm::TargetMachine> targetMachine(
      t->createTargetMachine(target, "", "", llvm::TargetOptions(),
                             // TODO: make configurable
                             llvm::Reloc::PIC_));
  if (!targetMachine)
    throw std::runtime_error("failed to create target machine");

  mod.setDataLayout(targetMachine->createDataLayout());

  std::error_code ec;
  llvm::raw_fd_ostream dest(outfile, ec, llvm::sys::fs::OF_None);
  if (ec)
    throw std::runtime_error(ec.message());

  llvm::legacy::PassManager pm;

  auto &tm = static_cast<llvm::LLVMTargetMachine &>(*targetMachine);

  // The legacy pass manager takes ownership of the passes added to it.
  pm.add(new llvm::TargetLibraryInfoWrapperPass());
  pm.add(new llvm::MachineModuleInfoWrapperPass(&tm));

  // addPassesToEmitFile returns true on failure.
  const bool emitFailed = targetMachine->addPassesToEmitFile(
      pm, dest, nullptr, llvm::CodeGenFileType::ObjectFile);
  if (emitFailed)
    throw std::runtime_error("failed to produce " + outfile);

  pm.run(mod);
  dest.flush();
}

} // namespace

/// Generates LLVM IR for `module`, verifies and optimizes it, prints the
/// optimized IR to stdout and writes the object file `<filename>.o`.
///
/// @throws std::runtime_error if the generated module fails verification (the
///         IR is printed to stdout first) or if code generation fails.
void compileModule(std::unique_ptr<ast::Module> &module,
                   const std::string &filename) {
  auto moduleId = filename;

  llvm::LLVMContext ctx;
  llvm::Module mod(moduleId, ctx);
  llvm::IRBuilder<> builder(ctx);

  SymbolMap symbolMap;
  FunctionSet functions;

  // Pass 1: declare all module-level symbols.
  IRGenerator1 generator1(ctx, mod, builder, symbolMap, functions);
  module->accept(&generator1, nullptr);

  // Pass 2: emit function bodies only; module-level variables were already
  // created as globals by pass 1.
  IRGenerator2 generator2(ctx, mod, builder, symbolMap, functions);
  for (auto &stmt : module->stmts) {
    if (utils::is<FnDecl>(stmt.get()))
      stmt->accept(&generator2, nullptr);
  }

  if (llvm::verifyModule(mod, &llvm::errs())) {
    mod.print(llvm::outs(), nullptr);
    llvm::outs().flush();
    throw std::runtime_error("Module verification failed");
  }

  runMPM(mod); // info: does not work, programs will malfunction

  mod.print(llvm::outs(), nullptr);
  llvm::outs().flush();

  writeToFile(ctx, mod, builder, filename + ".o");
}

} // namespace plsm