From c2dfcf3e7d27b5a97c3327af73fdc93c0ef49307 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Sun, 24 Oct 2021 13:51:58 -0400 Subject: [PATCH] Fix JIT engine --- codon/jit/engine.cpp | 13 ++-- codon/parser/visitors/translate/translate.cpp | 1 + .../visitors/typecheck/typecheck_infer.cpp | 2 - codon/sir/func.h | 14 ++-- codon/sir/llvm/llvisitor.cpp | 64 +++++++++---------- codon/sir/llvm/llvisitor.h | 28 ++++++-- codon/sir/util/cloning.cpp | 2 +- codon/sir/util/matching.cpp | 2 +- test/sir/func.cpp | 6 +- test/sir/util/matching.cpp | 6 +- 10 files changed, 78 insertions(+), 60 deletions(-) diff --git a/codon/jit/engine.cpp b/codon/jit/engine.cpp index 62a94995..695c2ced 100644 --- a/codon/jit/engine.cpp +++ b/codon/jit/engine.cpp @@ -104,13 +104,13 @@ JIT::JIT(ir::Module *module) void JIT::init() { module->accept(*llvisitor); auto pair = llvisitor->takeModule(); - auto rt = engine->getMainJITDylib().createResourceTracker(); + //auto rt = engine->getMainJITDylib().createResourceTracker(); llvm::cantFail( - engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))})); + engine->addModule({std::move(pair.second), std::move(pair.first)})); auto func = llvm::cantFail(engine->lookup("main")); auto *main = (MainFunc *)func.getAddress(); (*main)(0, nullptr); - llvm::cantFail(rt->remove()); + //llvm::cantFail(rt->remove()); } void JIT::run(const ir::Func *input, const std::vector &newGlobals) { @@ -120,13 +120,14 @@ void JIT::run(const ir::Func *input, const std::vector &newGlobals) { llvisitor->registerGlobal(var); input->accept(*llvisitor); auto pair = llvisitor->takeModule(); - auto rt = engine->getMainJITDylib().createResourceTracker(); + //auto rt = engine->getMainJITDylib().createResourceTracker(); + llvm::StripDebugInfo(*pair.second); llvm::cantFail( - engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))})); + engine->addModule({std::move(pair.second), std::move(pair.first)})); auto func = llvm::cantFail(engine->lookup(name)); auto *repl = (InputFunc *)func.getAddress(); (*repl)(); - llvm::cantFail(rt->remove()); + //llvm::cantFail(rt->remove()); } } // namespace jit diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 7cd40a6d..b85ad39c 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -31,6 +31,7 @@ ir::Func *TranslateVisitor::apply(std::shared_ptr cache, StmtPtr stmts) { auto irType = cache->module->unsafeGetFuncType( fnName, cache->classes["void"].realizations["void"]->ir, {}, false); main->realize(irType, {}); + main->setJIT(); } else { main = cast(cache->module->getMainFunc()); char buf[PATH_MAX + 1]; diff --git a/codon/parser/visitors/typecheck/typecheck_infer.cpp b/codon/parser/visitors/typecheck/typecheck_infer.cpp index ecbde243..8cbf8d47 100644 --- a/codon/parser/visitors/typecheck/typecheck_infer.cpp +++ b/codon/parser/visitors/typecheck/typecheck_infer.cpp @@ -220,8 +220,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { r->ir = ctx->cache->module->Nr(type->realizedName()); } else { r->ir = ctx->cache->module->Nr(type->realizedName()); - if (ast->attributes.has(Attr::ForceRealize)) - ir::cast(r->ir)->setBuiltin(); } auto parent = type->funcParent; diff --git a/codon/sir/func.h b/codon/sir/func.h index c8cecfc9..967ae362 100644 --- a/codon/sir/func.h +++ b/codon/sir/func.h @@ -88,8 +88,8 @@ private: std::list symbols; /// the function body Value *body = nullptr; - /// whether the function is builtin - bool builtin = false; + /// whether the function is a JIT input + bool jit = false; public: static const char NodeId; @@ -136,11 +136,11 @@ public: /// @param b the new body void setBody(Flow *b) { body = b; } - /// @return true if the function is builtin - bool isBuiltin() const { return builtin; } - /// Changes the function's builtin status. - /// @param v true if builtin, false otherwise - void setBuiltin(bool v = true) { builtin = v; } + /// @return true if the function is a JIT input + bool isJIT() const { return jit; } + /// Changes the function's JIT input status. + /// @param v true if JIT input, false otherwise + void setJIT(bool v = true) { jit = v; } protected: std::vector doGetUsedValues() const override { diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index 629a15e6..4d427965 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -29,8 +29,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags) : util::ConstVisitor(), context(std::make_unique()), M(), - B(std::make_unique>(*context)), func(nullptr), block(nullptr), - value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), + moduleId(0), B(std::make_unique>(*context)), func(nullptr), + block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), db(debug, jit, flags), plugins(nullptr) { llvm::InitializeAllTargets(); llvm::InitializeAllTargetMCs(); @@ -75,22 +75,27 @@ LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags) llvm::initializeTypePromotionPass(registry); } +llvm::GlobalValue::LinkageTypes LLVMVisitor::getDefaultLinkage() { + return db.jit ? llvm::GlobalValue::ExternalLinkage + : llvm::GlobalValue::PrivateLinkage; +} + void LLVMVisitor::registerGlobal(const Var *var) { if (!var->isGlobal()) return; if (auto *f = cast(var)) { makeLLVMFunction(f); - funcs.insert(f, func); + insertFunc(f, func); } else { llvm::Type *llvmType = getLLVMType(var->getType()); if (llvmType->isVoidTy()) { - vars.insert(var, getDummyVoidValue()); + insertVar(var, getDummyVoidValue()); } else { auto *storage = new llvm::GlobalVariable( - *M, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage, + *M, llvmType, /*isConstant=*/false, getDefaultLinkage(), llvm::Constant::getNullValue(llvmType), var->getName()); - vars.insert(var, storage); + insertVar(var, storage); // debug info auto *srcInfo = getSrcInfo(var); @@ -107,24 +112,18 @@ void LLVMVisitor::registerGlobal(const Var *var) { } llvm::Value *LLVMVisitor::getVar(const Var *var) { - llvm::Value *val = vars[var]; + std::pair p = vars.get(var); if (db.jit && var->isGlobal()) { - if (val) { - llvm::Module *m = nullptr; - if (auto *x = llvm::dyn_cast(val)) - m = x->getModule(); - else if (auto *x = llvm::dyn_cast(val)) - m = x->getParent(); - - if (m != M.get()) { - // see if it's in the M already + if (auto *val = p.first) { + if (p.second != moduleId) { + // see if it's in the module already auto name = var->getName(); if (auto *global = M->getNamedValue(name)) return global; llvm::Type *llvmType = getLLVMType(var->getType()); auto *storage = new llvm::GlobalVariable(*M, llvmType, /*isConstant=*/false, - llvm::GlobalVariable::ExternalLinkage, + llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, name); storage->setExternallyInitialized(true); @@ -138,7 +137,7 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) { getDIType(var->getType()), /*IsLocalToUnit=*/true); storage->addDebugInfo(debugVar); - vars.insert(var, storage); + insertVar(var, storage); return storage; } } else { @@ -146,15 +145,15 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) { return vars[var]; } } - return val; + return p.first; } llvm::Function *LLVMVisitor::getFunc(const Func *func) { - llvm::Function *f = funcs[func]; + std::pair p = funcs.get(func); if (db.jit) { - if (f) { - if (f->getParent() != M.get()) { - // see if it's in the M already + if (auto *f = p.first) { + if (p.second != moduleId) { + // see if it's in the module already if (auto *g = M->getFunction(f->getName())) return g; @@ -162,7 +161,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) { llvm::Function::ExternalLinkage, f->getName(), M.get()); g->copyAttributesFrom(f); - funcs.insert(func, g); + insertFunc(func, g); return g; } } else { @@ -170,7 +169,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) { return funcs[func]; } } - return f; + return p.first; } std::unique_ptr LLVMVisitor::makeModule(llvm::LLVMContext &context, @@ -205,6 +204,7 @@ LLVMVisitor::takeModule(const SrcInfo *src) { auto currentModule = std::move(M); context = std::make_unique(); M = makeModule(*context, src); + ++moduleId; return {std::move(currentContext), std::move(currentModule)}; } @@ -599,7 +599,7 @@ void LLVMVisitor::visit(const Module *x) { B->CreateBr(loopBlock); B->SetInsertPoint(exitBlock); - llvm::Value *argStorage = vars[x->getArgVar()]; + llvm::Value *argStorage = getVar(x->getArgVar()); seqassert(argStorage, "argument storage missing"); B->CreateStore(arr, argStorage); B->CreateCall(initFunc, B->getInt32(db.debug ? 1 : 0)); @@ -752,7 +752,7 @@ void LLVMVisitor::visit(const InternalFunc *x) { auto *funcType = cast(x->getType()); std::vector argTypes(funcType->begin(), funcType->end()); - func->setLinkage(llvm::GlobalValue::PrivateLinkage); + func->setLinkage(getDefaultLinkage()); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); std::vector args; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { @@ -915,7 +915,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { seqassert(!fail, "linking failed"); func = M->getFunction(getNameForFunction(x)); seqassert(func, "function not linked in"); - func->setLinkage(llvm::GlobalValue::PrivateLinkage); + func->setLinkage(getDefaultLinkage()); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); } @@ -929,7 +929,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) { func->setLinkage(llvm::GlobalValue::ExternalLinkage); } else { - func->setLinkage(llvm::GlobalValue::PrivateLinkage); + func->setLinkage(getDefaultLinkage()); } if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) { func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); @@ -956,7 +956,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { const Var *var = *varIter; llvm::Value *storage = B->CreateAlloca(getLLVMType(var->getType())); B->CreateStore(argIter, storage); - vars.insert(var, storage); + insertVar(var, storage); // debug info auto *srcInfo = getSrcInfo(var); @@ -977,10 +977,10 @@ void LLVMVisitor::visit(const BodiedFunc *x) { for (auto *var : *x) { llvm::Type *llvmType = getLLVMType(var->getType()); if (llvmType->isVoidTy()) { - vars.insert(var, getDummyVoidValue()); + insertVar(var, getDummyVoidValue()); } else { llvm::Value *storage = B->CreateAlloca(llvmType); - vars.insert(var, storage); + insertVar(var, storage); // debug info auto *srcInfo = getSrcInfo(var); diff --git a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index ccef8a43..81a62b46 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -15,17 +15,22 @@ namespace ir { class LLVMVisitor : public util::ConstVisitor { private: - template using CacheBase = std::unordered_map; + template + using CacheBase = std::unordered_map>; template class Cache : public CacheBase { public: using CacheBase::CacheBase; - V *operator[](const K *key) { + std::pair get(const K *key) { auto it = CacheBase::find(key->getId()); - return (it != CacheBase::end()) ? it->second : nullptr; + return (it != CacheBase::end()) ? it->second : std::make_pair(nullptr, 0); } - void insert(const K *key, V *value) { CacheBase::emplace(key->getId(), value); } + V *operator[](const K *key) { return get(key).first; } + + void insert(const K *key, V *value, int64_t id) { + CacheBase::emplace(key->getId(), std::make_pair(value, id)); + } }; struct CoroData { @@ -113,6 +118,8 @@ private: std::unique_ptr context; /// Module we are compiling std::unique_ptr M; + /// Module ID + int64_t moduleId; /// LLVM IR builder used for constructing LLVM IR std::unique_ptr> B; /// Current function we are compiling @@ -180,12 +187,19 @@ private: void runLLVMPipeline(); llvm::Value *getVar(const Var *var); + void insertVar(const Var *var, llvm::Value *x) { vars.insert(var, x, moduleId); } llvm::Function *getFunc(const Func *func); + void insertFunc(const Func *func, llvm::Function *x) { + funcs.insert(func, x, moduleId); + } llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); } public: static std::string getNameForFunction(const Func *x) { - if (auto *externalFunc = cast(x)) { + auto *bodiedFunc = cast(x); + if (bodiedFunc && bodiedFunc->isJIT()) { + return x->getName(); + } else if (isA(x)) { return x->getUnmangledName(); } else { return x->referenceString(); @@ -244,6 +258,10 @@ public: /// @param var the global variable (or function) to register void registerGlobal(const Var *var); + /// Returns the default LLVM linkage type for the module. + /// @return LLVM linkage type + llvm::GlobalValue::LinkageTypes getDefaultLinkage(); + /// Returns a new LLVM module initialized for the host /// architecture. /// @param context LLVM context used for creating module diff --git a/codon/sir/util/cloning.cpp b/codon/sir/util/cloning.cpp index 74269f28..5acd086a 100644 --- a/codon/sir/util/cloning.cpp +++ b/codon/sir/util/cloning.cpp @@ -35,7 +35,7 @@ void CloneVisitor::visit(const BodiedFunc *v) { if (v->getBody()) res->setBody(clone(v->getBody())); - res->setBuiltin(v->isBuiltin()); + res->setJIT(v->isJIT()); result = res; } diff --git a/codon/sir/util/matching.cpp b/codon/sir/util/matching.cpp index 38f6aedf..82e51a9f 100644 --- a/codon/sir/util/matching.cpp +++ b/codon/sir/util/matching.cpp @@ -47,7 +47,7 @@ public: result = compareFuncs(x, y) && std::equal(x->begin(), x->end(), y->begin(), y->end(), [this](auto *x, auto *y) { return process(x, y); }) && - process(x->getBody(), y->getBody()) && x->isBuiltin() == y->isBuiltin(); + process(x->getBody(), y->getBody()) && x->isJIT() == y->isJIT(); } VISIT(ExternalFunc); void handle(const ExternalFunc *x, const ExternalFunc *y) { diff --git a/test/sir/func.cpp b/test/sir/func.cpp index dfa57923..e3615633 100644 --- a/test/sir/func.cpp +++ b/test/sir/func.cpp @@ -32,8 +32,8 @@ TEST_F(SIRCoreTest, FuncRealizationAndVarInsertionEraseAndIterators) { TEST_F(SIRCoreTest, BodiedFuncQueryAndReplace) { auto *fn = module->Nr(); fn->realize(module->unsafeGetDummyFuncType(), {}); - fn->setBuiltin(); - ASSERT_TRUE(fn->isBuiltin()); + fn->setJIT(); + ASSERT_TRUE(fn->isJIT()); auto *body = fn->getBody(); ASSERT_FALSE(body); @@ -63,7 +63,7 @@ TEST_F(SIRCoreTest, BodiedFuncCloning) { auto *fn = module->Nr("fn"); fn->realize(module->unsafeGetDummyFuncType(), {}); - fn->setBuiltin(); + fn->setJIT(); fn->setBody(module->Nr()); ASSERT_TRUE(util::match(fn, cv->clone(fn))); } diff --git a/test/sir/util/matching.cpp b/test/sir/util/matching.cpp index 38257f5f..00aef9ba 100644 --- a/test/sir/util/matching.cpp +++ b/test/sir/util/matching.cpp @@ -25,8 +25,8 @@ TEST_F(SIRCoreTest, MatchingEquivalentFunc) { auto *second = module->Nr(); second->realize(module->unsafeGetDummyFuncType(), {}); - first->setBuiltin(); - second->setBuiltin(); + first->setJIT(); + second->setJIT(); ASSERT_TRUE(util::match(first, second)); } @@ -58,7 +58,7 @@ TEST_F(SIRCoreTest, MatchingNonEquivalentFunc) { auto *second = module->Nr(); second->realize(module->unsafeGetDummyFuncType(), {}); - first->setBuiltin(); + first->setJIT(); ASSERT_FALSE(util::match(first, second)); }