From f51a53e2fe0ccf89a7545ac1a73bf583ba6474ec Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Thu, 20 Jan 2022 11:22:32 -0500 Subject: [PATCH] Inline JIT functions --- codon/app/main.cpp | 12 ++++-- codon/compiler/jit.cpp | 14 +++++-- codon/sir/llvm/llvisitor.cpp | 80 +++++++++++++++++++++++------------- codon/sir/llvm/llvisitor.h | 8 +--- 4 files changed, 74 insertions(+), 40 deletions(-) diff --git a/codon/app/main.cpp b/codon/app/main.cpp index e3fcb0b8..da32bfa8 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -56,6 +56,12 @@ void display(const codon::error::ParserErrorInfo &e) { } } +void initLogFlags(const llvm::cl::opt &log) { + codon::getLogger().parse(log); + if (auto *d = getenv("CODON_DEBUG")) + codon::getLogger().parse(std::string(d)); +} + enum BuildKind { LLVM, Bitcode, Object, Executable, Detect }; enum OptMode { Debug, Release }; } // namespace @@ -103,9 +109,7 @@ std::unique_ptr processSource(const std::vector & llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); - codon::getLogger().parse(log); - if (auto *d = getenv("CODON_DEBUG")) - codon::getLogger().parse(std::string(d)); + initLogFlags(log); auto &exts = supportedExtensions(); if (input != "-" && std::find_if(exts.begin(), exts.end(), [&](auto &ext) { @@ -207,7 +211,9 @@ std::string jitExec(codon::jit::JIT *jit, const std::string &code) { int jitMode(const std::vector &args) { llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); + llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); + initLogFlags(log); codon::jit::JIT jit(args[0]); // load plugins diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index 4f3b01e7..881e5a82 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -44,7 +44,7 @@ llvm::Error JIT::init() { pm->run(module); module->accept(*llvisitor); - auto pair = llvisitor->takeModule(); + auto pair = llvisitor->takeModule(module); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) return err; @@ -62,12 +62,18 @@ llvm::Expected JIT::run(const ir::Func *input) { auto *module = compiler->getModule(); auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); + + Timer t1("jit/ir"); pm->run(module); + t1.log(); const std::string name = ir::LLVMVisitor::getNameForFunction(input); - llvisitor->processNewGlobals(module); - auto pair = llvisitor->takeModule(); + Timer t2("jit/llvm"); + auto pair = llvisitor->takeModule(module); + t2.log(); + + Timer t3("jit/engine"); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) return std::move(err); @@ -76,6 +82,8 @@ llvm::Expected JIT::run(const ir::Func *input) { return std::move(err); auto *repl = (InputFunc *)func->getAddress(); + t3.log(); + try { (*repl)(); } catch (const JITError &e) { diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index d108a810..269094ea 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -113,25 +113,6 @@ void LLVMVisitor::registerGlobal(const Var *var) { } } -void LLVMVisitor::processNewGlobals(Module *module) { - std::vector newFuncs; - for (auto *var : *module) { - if (!var->isGlobal()) - continue; - auto id = var->getId(); - auto *func = cast(var); - bool isNewFunc = (func && funcs.find(id) == funcs.end()); - if (isNewFunc || (!func && vars.find(id) == vars.end())) - registerGlobal(var); - if (isNewFunc) - newFuncs.push_back(func); - } - - for (auto *func : newFuncs) { - func->accept(*this); - } -} - llvm::Value *LLVMVisitor::getVar(const Var *var) { auto it = vars.find(var->getId()); if (db.jit && var->isGlobal()) { @@ -230,7 +211,34 @@ std::unique_ptr LLVMVisitor::makeModule(llvm::LLVMContext &context } std::pair, std::unique_ptr> -LLVMVisitor::takeModule(const SrcInfo *src) { +LLVMVisitor::takeModule(Module *module, const SrcInfo *src) { + // process any new functions or globals + if (module) { + std::unordered_set funcsToProcess; + for (auto *var : *module) { + auto id = var->getId(); + if (auto *func = cast(var)) { + if (funcs.find(id) != funcs.end()) + continue; + else + funcsToProcess.insert(id); + } else { + if (vars.find(id) != vars.end()) + continue; + } + + registerGlobal(var); + } + + for (auto *var : *module) { + if (auto *func = cast(var)) { + if (funcsToProcess.find(func->getId()) != funcsToProcess.end()) { + process(func); + } + } + } + } + db.builder->finalize(); auto currentContext = std::move(context); auto currentModule = std::move(M); @@ -240,10 +248,25 @@ LLVMVisitor::takeModule(const SrcInfo *src) { func = nullptr; block = nullptr; value = nullptr; - for (auto &it : vars) - it.second = nullptr; - for (auto &it : funcs) - it.second = nullptr; + + for (auto it = funcs.begin(); it != funcs.end();) { + if (it->second && it->second->hasPrivateLinkage()) { + it = funcs.erase(it); + } else { + it->second = nullptr; + ++it; + } + } + + for (auto it = vars.begin(); it != vars.end();) { + if (it->second && !llvm::isa(it->second)) { + it = vars.erase(it); + } else { + it->second = nullptr; + ++it; + } + } + coro.reset(); loops.clear(); trycatch.clear(); @@ -824,7 +847,7 @@ void LLVMVisitor::visit(const InternalFunc *x) { auto *funcType = cast(x->getType()); std::vector argTypes(funcType->begin(), funcType->end()); - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); std::vector args; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { @@ -987,7 +1010,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { seqassert(!fail, "linking failed"); func = M->getFunction(getNameForFunction(x)); seqassert(func, "function not linked in"); - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); func->setSubprogram(getDISubprogramForFunc(x)); @@ -1011,10 +1034,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) { setDebugInfoForNode(x); auto *fnAttributes = x->getAttribute(); - if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) { + if (x->isJIT() || + (fnAttributes && fnAttributes->has("std.internal.attributes.export"))) { func->setLinkage(llvm::GlobalValue::ExternalLinkage); } else { - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); } if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) { func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); diff --git a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index 4f28252b..ab20fcc6 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -275,11 +275,6 @@ public: /// @param var the global variable (or function) to register void registerGlobal(const Var *var); - /// Processes new globals that were not previously - /// compiled. Used in JIT mode. - /// @param module the IR module - void processNewGlobals(Module *module); - /// Returns the default LLVM linkage type for the module. /// @return LLVM linkage type llvm::GlobalValue::LinkageTypes getDefaultLinkage(); @@ -295,10 +290,11 @@ public: /// Returns the current module/LLVM context and replaces them /// with new, fresh ones. References to variables or functions /// from the old module will be included as "external". + /// @param module the IR module /// @param src source information for the new module /// @return the current module/context, replaced internally std::pair, std::unique_ptr> - takeModule(const SrcInfo *src = nullptr); + takeModule(Module *module, const SrcInfo *src = nullptr); /// Sets current debug info based on a given node. /// @param node the node whose debug info to use