1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Inline JIT functions

This commit is contained in:
A. R. Shajii 2022-01-20 11:22:32 -05:00
parent 6f4e24fb00
commit f51a53e2fe
4 changed files with 74 additions and 40 deletions

View File

@ -56,6 +56,12 @@ void display(const codon::error::ParserErrorInfo &e) {
}
}
void initLogFlags(const llvm::cl::opt<std::string> &log) {
codon::getLogger().parse(log);
if (auto *d = getenv("CODON_DEBUG"))
codon::getLogger().parse(std::string(d));
}
enum BuildKind { LLVM, Bitcode, Object, Executable, Detect };
enum OptMode { Debug, Release };
} // namespace
@ -103,9 +109,7 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
llvm::cl::opt<std::string> log("log", llvm::cl::desc("Enable given log streams"));
llvm::cl::ParseCommandLineOptions(args.size(), args.data());
codon::getLogger().parse(log);
if (auto *d = getenv("CODON_DEBUG"))
codon::getLogger().parse(std::string(d));
initLogFlags(log);
auto &exts = supportedExtensions();
if (input != "-" && std::find_if(exts.begin(), exts.end(), [&](auto &ext) {
@ -207,7 +211,9 @@ std::string jitExec(codon::jit::JIT *jit, const std::string &code) {
int jitMode(const std::vector<const char *> &args) {
llvm::cl::list<std::string> plugins("plugin",
llvm::cl::desc("Load specified plugin"));
llvm::cl::opt<std::string> log("log", llvm::cl::desc("Enable given log streams"));
llvm::cl::ParseCommandLineOptions(args.size(), args.data());
initLogFlags(log);
codon::jit::JIT jit(args[0]);
// load plugins

View File

@ -44,7 +44,7 @@ llvm::Error JIT::init() {
pm->run(module);
module->accept(*llvisitor);
auto pair = llvisitor->takeModule();
auto pair = llvisitor->takeModule(module);
if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)}))
return err;
@ -62,12 +62,18 @@ llvm::Expected<std::string> JIT::run(const ir::Func *input) {
auto *module = compiler->getModule();
auto *pm = compiler->getPassManager();
auto *llvisitor = compiler->getLLVMVisitor();
Timer t1("jit/ir");
pm->run(module);
t1.log();
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
llvisitor->processNewGlobals(module);
auto pair = llvisitor->takeModule();
Timer t2("jit/llvm");
auto pair = llvisitor->takeModule(module);
t2.log();
Timer t3("jit/engine");
if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)}))
return std::move(err);
@ -76,6 +82,8 @@ llvm::Expected<std::string> JIT::run(const ir::Func *input) {
return std::move(err);
auto *repl = (InputFunc *)func->getAddress();
t3.log();
try {
(*repl)();
} catch (const JITError &e) {

View File

@ -113,25 +113,6 @@ void LLVMVisitor::registerGlobal(const Var *var) {
}
}
void LLVMVisitor::processNewGlobals(Module *module) {
std::vector<Func *> newFuncs;
for (auto *var : *module) {
if (!var->isGlobal())
continue;
auto id = var->getId();
auto *func = cast<Func>(var);
bool isNewFunc = (func && funcs.find(id) == funcs.end());
if (isNewFunc || (!func && vars.find(id) == vars.end()))
registerGlobal(var);
if (isNewFunc)
newFuncs.push_back(func);
}
for (auto *func : newFuncs) {
func->accept(*this);
}
}
llvm::Value *LLVMVisitor::getVar(const Var *var) {
auto it = vars.find(var->getId());
if (db.jit && var->isGlobal()) {
@ -230,7 +211,34 @@ std::unique_ptr<llvm::Module> LLVMVisitor::makeModule(llvm::LLVMContext &context
}
std::pair<std::unique_ptr<llvm::Module>, std::unique_ptr<llvm::LLVMContext>>
LLVMVisitor::takeModule(const SrcInfo *src) {
LLVMVisitor::takeModule(Module *module, const SrcInfo *src) {
// process any new functions or globals
if (module) {
std::unordered_set<id_t> funcsToProcess;
for (auto *var : *module) {
auto id = var->getId();
if (auto *func = cast<Func>(var)) {
if (funcs.find(id) != funcs.end())
continue;
else
funcsToProcess.insert(id);
} else {
if (vars.find(id) != vars.end())
continue;
}
registerGlobal(var);
}
for (auto *var : *module) {
if (auto *func = cast<Func>(var)) {
if (funcsToProcess.find(func->getId()) != funcsToProcess.end()) {
process(func);
}
}
}
}
db.builder->finalize();
auto currentContext = std::move(context);
auto currentModule = std::move(M);
@ -240,10 +248,25 @@ LLVMVisitor::takeModule(const SrcInfo *src) {
func = nullptr;
block = nullptr;
value = nullptr;
for (auto &it : vars)
it.second = nullptr;
for (auto &it : funcs)
it.second = nullptr;
for (auto it = funcs.begin(); it != funcs.end();) {
if (it->second && it->second->hasPrivateLinkage()) {
it = funcs.erase(it);
} else {
it->second = nullptr;
++it;
}
}
for (auto it = vars.begin(); it != vars.end();) {
if (it->second && !llvm::isa<llvm::GlobalValue>(it->second)) {
it = vars.erase(it);
} else {
it->second = nullptr;
++it;
}
}
coro.reset();
loops.clear();
trycatch.clear();
@ -824,7 +847,7 @@ void LLVMVisitor::visit(const InternalFunc *x) {
auto *funcType = cast<FuncType>(x->getType());
std::vector<Type *> argTypes(funcType->begin(), funcType->end());
func->setLinkage(getDefaultLinkage());
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
std::vector<llvm::Value *> args;
for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
@ -987,7 +1010,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
seqassert(!fail, "linking failed");
func = M->getFunction(getNameForFunction(x));
seqassert(func, "function not linked in");
func->setLinkage(getDefaultLinkage());
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
func->setSubprogram(getDISubprogramForFunc(x));
@ -1011,10 +1034,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
setDebugInfoForNode(x);
auto *fnAttributes = x->getAttribute<KeyValueAttribute>();
if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) {
if (x->isJIT() ||
(fnAttributes && fnAttributes->has("std.internal.attributes.export"))) {
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
} else {
func->setLinkage(getDefaultLinkage());
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
}
if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) {
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);

View File

@ -275,11 +275,6 @@ public:
/// @param var the global variable (or function) to register
void registerGlobal(const Var *var);
/// Processes new globals that were not previously
/// compiled. Used in JIT mode.
/// @param module the IR module
void processNewGlobals(Module *module);
/// Returns the default LLVM linkage type for the module.
/// @return LLVM linkage type
llvm::GlobalValue::LinkageTypes getDefaultLinkage();
@ -295,10 +290,11 @@ public:
/// Returns the current module/LLVM context and replaces them
/// with new, fresh ones. References to variables or functions
/// from the old module will be included as "external".
/// @param module the IR module
/// @param src source information for the new module
/// @return the current module/context, replaced internally
std::pair<std::unique_ptr<llvm::Module>, std::unique_ptr<llvm::LLVMContext>>
takeModule(const SrcInfo *src = nullptr);
takeModule(Module *module, const SrcInfo *src = nullptr);
/// Sets current debug info based on a given node.
/// @param node the node whose debug info to use