1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Fix JIT engine

This commit is contained in:
A. R. Shajii 2021-10-24 13:51:58 -04:00
parent f17c8d953c
commit c2dfcf3e7d
10 changed files with 78 additions and 60 deletions

View File

@ -104,13 +104,13 @@ JIT::JIT(ir::Module *module)
void JIT::init() {
module->accept(*llvisitor);
auto pair = llvisitor->takeModule();
auto rt = engine->getMainJITDylib().createResourceTracker();
//auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::cantFail(
engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))}));
engine->addModule({std::move(pair.second), std::move(pair.first)}));
auto func = llvm::cantFail(engine->lookup("main"));
auto *main = (MainFunc *)func.getAddress();
(*main)(0, nullptr);
llvm::cantFail(rt->remove());
//llvm::cantFail(rt->remove());
}
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
@ -120,13 +120,14 @@ void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
llvisitor->registerGlobal(var);
input->accept(*llvisitor);
auto pair = llvisitor->takeModule();
auto rt = engine->getMainJITDylib().createResourceTracker();
//auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::StripDebugInfo(*pair.second);
llvm::cantFail(
engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))}));
engine->addModule({std::move(pair.second), std::move(pair.first)}));
auto func = llvm::cantFail(engine->lookup(name));
auto *repl = (InputFunc *)func.getAddress();
(*repl)();
llvm::cantFail(rt->remove());
//llvm::cantFail(rt->remove());
}
} // namespace jit

View File

@ -31,6 +31,7 @@ ir::Func *TranslateVisitor::apply(std::shared_ptr<Cache> cache, StmtPtr stmts) {
auto irType = cache->module->unsafeGetFuncType(
fnName, cache->classes["void"].realizations["void"]->ir, {}, false);
main->realize(irType, {});
main->setJIT();
} else {
main = cast<ir::BodiedFunc>(cache->module->getMainFunc());
char buf[PATH_MAX + 1];

View File

@ -220,8 +220,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
r->ir = ctx->cache->module->Nr<ir::ExternalFunc>(type->realizedName());
} else {
r->ir = ctx->cache->module->Nr<ir::BodiedFunc>(type->realizedName());
if (ast->attributes.has(Attr::ForceRealize))
ir::cast<ir::BodiedFunc>(r->ir)->setBuiltin();
}
auto parent = type->funcParent;

View File

@ -88,8 +88,8 @@ private:
std::list<Var *> symbols;
/// the function body
Value *body = nullptr;
/// whether the function is builtin
bool builtin = false;
/// whether the function is a JIT input
bool jit = false;
public:
static const char NodeId;
@ -136,11 +136,11 @@ public:
/// @param b the new body
void setBody(Flow *b) { body = b; }
/// @return true if the function is builtin
bool isBuiltin() const { return builtin; }
/// Changes the function's builtin status.
/// @param v true if builtin, false otherwise
void setBuiltin(bool v = true) { builtin = v; }
/// @return true if the function is a JIT input
bool isJIT() const { return jit; }
/// Changes the function's JIT input status.
/// @param v true if JIT input, false otherwise
void setJIT(bool v = true) { jit = v; }
protected:
std::vector<Value *> doGetUsedValues() const override {

View File

@ -29,8 +29,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
: util::ConstVisitor(), context(std::make_unique<llvm::LLVMContext>()), M(),
B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr), block(nullptr),
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
moduleId(0), B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr),
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
db(debug, jit, flags), plugins(nullptr) {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
@ -75,22 +75,27 @@ LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
llvm::initializeTypePromotionPass(registry);
}
llvm::GlobalValue::LinkageTypes LLVMVisitor::getDefaultLinkage() {
return db.jit ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::PrivateLinkage;
}
void LLVMVisitor::registerGlobal(const Var *var) {
if (!var->isGlobal())
return;
if (auto *f = cast<Func>(var)) {
makeLLVMFunction(f);
funcs.insert(f, func);
insertFunc(f, func);
} else {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue());
insertVar(var, getDummyVoidValue());
} else {
auto *storage = new llvm::GlobalVariable(
*M, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage,
*M, llvmType, /*isConstant=*/false, getDefaultLinkage(),
llvm::Constant::getNullValue(llvmType), var->getName());
vars.insert(var, storage);
insertVar(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
@ -107,24 +112,18 @@ void LLVMVisitor::registerGlobal(const Var *var) {
}
llvm::Value *LLVMVisitor::getVar(const Var *var) {
llvm::Value *val = vars[var];
std::pair<llvm::Value *, int64_t> p = vars.get(var);
if (db.jit && var->isGlobal()) {
if (val) {
llvm::Module *m = nullptr;
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
m = x->getModule();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
m = x->getParent();
if (m != M.get()) {
// see if it's in the M already
if (auto *val = p.first) {
if (p.second != moduleId) {
// see if it's in the module already
auto name = var->getName();
if (auto *global = M->getNamedValue(name))
return global;
llvm::Type *llvmType = getLLVMType(var->getType());
auto *storage = new llvm::GlobalVariable(*M, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage,
llvm::GlobalValue::ExternalLinkage,
/*Initializer=*/nullptr, name);
storage->setExternallyInitialized(true);
@ -138,7 +137,7 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) {
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
vars.insert(var, storage);
insertVar(var, storage);
return storage;
}
} else {
@ -146,15 +145,15 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) {
return vars[var];
}
}
return val;
return p.first;
}
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
llvm::Function *f = funcs[func];
std::pair<llvm::Function *, int64_t> p = funcs.get(func);
if (db.jit) {
if (f) {
if (f->getParent() != M.get()) {
// see if it's in the M already
if (auto *f = p.first) {
if (p.second != moduleId) {
// see if it's in the module already
if (auto *g = M->getFunction(f->getName()))
return g;
@ -162,7 +161,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) {
llvm::Function::ExternalLinkage, f->getName(),
M.get());
g->copyAttributesFrom(f);
funcs.insert(func, g);
insertFunc(func, g);
return g;
}
} else {
@ -170,7 +169,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) {
return funcs[func];
}
}
return f;
return p.first;
}
std::unique_ptr<llvm::Module> LLVMVisitor::makeModule(llvm::LLVMContext &context,
@ -205,6 +204,7 @@ LLVMVisitor::takeModule(const SrcInfo *src) {
auto currentModule = std::move(M);
context = std::make_unique<llvm::LLVMContext>();
M = makeModule(*context, src);
++moduleId;
return {std::move(currentContext), std::move(currentModule)};
}
@ -599,7 +599,7 @@ void LLVMVisitor::visit(const Module *x) {
B->CreateBr(loopBlock);
B->SetInsertPoint(exitBlock);
llvm::Value *argStorage = vars[x->getArgVar()];
llvm::Value *argStorage = getVar(x->getArgVar());
seqassert(argStorage, "argument storage missing");
B->CreateStore(arr, argStorage);
B->CreateCall(initFunc, B->getInt32(db.debug ? 1 : 0));
@ -752,7 +752,7 @@ void LLVMVisitor::visit(const InternalFunc *x) {
auto *funcType = cast<FuncType>(x->getType());
std::vector<Type *> argTypes(funcType->begin(), funcType->end());
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
func->setLinkage(getDefaultLinkage());
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
std::vector<llvm::Value *> args;
for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
@ -915,7 +915,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
seqassert(!fail, "linking failed");
func = M->getFunction(getNameForFunction(x));
seqassert(func, "function not linked in");
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
func->setLinkage(getDefaultLinkage());
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
}
@ -929,7 +929,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) {
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
} else {
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
func->setLinkage(getDefaultLinkage());
}
if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) {
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
@ -956,7 +956,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
const Var *var = *varIter;
llvm::Value *storage = B->CreateAlloca(getLLVMType(var->getType()));
B->CreateStore(argIter, storage);
vars.insert(var, storage);
insertVar(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
@ -977,10 +977,10 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
for (auto *var : *x) {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue());
insertVar(var, getDummyVoidValue());
} else {
llvm::Value *storage = B->CreateAlloca(llvmType);
vars.insert(var, storage);
insertVar(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);

View File

@ -15,17 +15,22 @@ namespace ir {
class LLVMVisitor : public util::ConstVisitor {
private:
template <typename V> using CacheBase = std::unordered_map<id_t, V *>;
template <typename V>
using CacheBase = std::unordered_map<id_t, std::pair<V *, int64_t>>;
template <typename K, typename V> class Cache : public CacheBase<V> {
public:
using CacheBase<V>::CacheBase;
V *operator[](const K *key) {
std::pair<V *, int64_t> get(const K *key) {
auto it = CacheBase<V>::find(key->getId());
return (it != CacheBase<V>::end()) ? it->second : nullptr;
return (it != CacheBase<V>::end()) ? it->second : std::make_pair(nullptr, 0);
}
void insert(const K *key, V *value) { CacheBase<V>::emplace(key->getId(), value); }
V *operator[](const K *key) { return get(key).first; }
void insert(const K *key, V *value, int64_t id) {
CacheBase<V>::emplace(key->getId(), std::make_pair(value, id));
}
};
struct CoroData {
@ -113,6 +118,8 @@ private:
std::unique_ptr<llvm::LLVMContext> context;
/// Module we are compiling
std::unique_ptr<llvm::Module> M;
/// Module ID
int64_t moduleId;
/// LLVM IR builder used for constructing LLVM IR
std::unique_ptr<llvm::IRBuilder<>> B;
/// Current function we are compiling
@ -180,12 +187,19 @@ private:
void runLLVMPipeline();
llvm::Value *getVar(const Var *var);
void insertVar(const Var *var, llvm::Value *x) { vars.insert(var, x, moduleId); }
llvm::Function *getFunc(const Func *func);
void insertFunc(const Func *func, llvm::Function *x) {
funcs.insert(func, x, moduleId);
}
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); }
public:
static std::string getNameForFunction(const Func *x) {
if (auto *externalFunc = cast<ExternalFunc>(x)) {
auto *bodiedFunc = cast<BodiedFunc>(x);
if (bodiedFunc && bodiedFunc->isJIT()) {
return x->getName();
} else if (isA<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
@ -244,6 +258,10 @@ public:
/// @param var the global variable (or function) to register
void registerGlobal(const Var *var);
/// Returns the default LLVM linkage type for the module.
/// @return LLVM linkage type
llvm::GlobalValue::LinkageTypes getDefaultLinkage();
/// Returns a new LLVM module initialized for the host
/// architecture.
/// @param context LLVM context used for creating module

View File

@ -35,7 +35,7 @@ void CloneVisitor::visit(const BodiedFunc *v) {
if (v->getBody())
res->setBody(clone(v->getBody()));
res->setBuiltin(v->isBuiltin());
res->setJIT(v->isJIT());
result = res;
}

View File

@ -47,7 +47,7 @@ public:
result = compareFuncs(x, y) &&
std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto *x, auto *y) { return process(x, y); }) &&
process(x->getBody(), y->getBody()) && x->isBuiltin() == y->isBuiltin();
process(x->getBody(), y->getBody()) && x->isJIT() == y->isJIT();
}
VISIT(ExternalFunc);
void handle(const ExternalFunc *x, const ExternalFunc *y) {

View File

@ -32,8 +32,8 @@ TEST_F(SIRCoreTest, FuncRealizationAndVarInsertionEraseAndIterators) {
TEST_F(SIRCoreTest, BodiedFuncQueryAndReplace) {
auto *fn = module->Nr<BodiedFunc>();
fn->realize(module->unsafeGetDummyFuncType(), {});
fn->setBuiltin();
ASSERT_TRUE(fn->isBuiltin());
fn->setJIT();
ASSERT_TRUE(fn->isJIT());
auto *body = fn->getBody();
ASSERT_FALSE(body);
@ -63,7 +63,7 @@ TEST_F(SIRCoreTest, BodiedFuncCloning) {
auto *fn = module->Nr<BodiedFunc>("fn");
fn->realize(module->unsafeGetDummyFuncType(), {});
fn->setBuiltin();
fn->setJIT();
fn->setBody(module->Nr<SeriesFlow>());
ASSERT_TRUE(util::match(fn, cv->clone(fn)));
}

View File

@ -25,8 +25,8 @@ TEST_F(SIRCoreTest, MatchingEquivalentFunc) {
auto *second = module->Nr<BodiedFunc>();
second->realize(module->unsafeGetDummyFuncType(), {});
first->setBuiltin();
second->setBuiltin();
first->setJIT();
second->setJIT();
ASSERT_TRUE(util::match(first, second));
}
@ -58,7 +58,7 @@ TEST_F(SIRCoreTest, MatchingNonEquivalentFunc) {
auto *second = module->Nr<BodiedFunc>();
second->realize(module->unsafeGetDummyFuncType(), {});
first->setBuiltin();
first->setJIT();
ASSERT_FALSE(util::match(first, second));
}