mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Fix JIT engine
This commit is contained in:
parent
f17c8d953c
commit
c2dfcf3e7d
@ -104,13 +104,13 @@ JIT::JIT(ir::Module *module)
|
||||
void JIT::init() {
|
||||
module->accept(*llvisitor);
|
||||
auto pair = llvisitor->takeModule();
|
||||
auto rt = engine->getMainJITDylib().createResourceTracker();
|
||||
//auto rt = engine->getMainJITDylib().createResourceTracker();
|
||||
llvm::cantFail(
|
||||
engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))}));
|
||||
engine->addModule({std::move(pair.second), std::move(pair.first)}));
|
||||
auto func = llvm::cantFail(engine->lookup("main"));
|
||||
auto *main = (MainFunc *)func.getAddress();
|
||||
(*main)(0, nullptr);
|
||||
llvm::cantFail(rt->remove());
|
||||
//llvm::cantFail(rt->remove());
|
||||
}
|
||||
|
||||
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
|
||||
@ -120,13 +120,14 @@ void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
|
||||
llvisitor->registerGlobal(var);
|
||||
input->accept(*llvisitor);
|
||||
auto pair = llvisitor->takeModule();
|
||||
auto rt = engine->getMainJITDylib().createResourceTracker();
|
||||
//auto rt = engine->getMainJITDylib().createResourceTracker();
|
||||
llvm::StripDebugInfo(*pair.second);
|
||||
llvm::cantFail(
|
||||
engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))}));
|
||||
engine->addModule({std::move(pair.second), std::move(pair.first)}));
|
||||
auto func = llvm::cantFail(engine->lookup(name));
|
||||
auto *repl = (InputFunc *)func.getAddress();
|
||||
(*repl)();
|
||||
llvm::cantFail(rt->remove());
|
||||
//llvm::cantFail(rt->remove());
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
@ -31,6 +31,7 @@ ir::Func *TranslateVisitor::apply(std::shared_ptr<Cache> cache, StmtPtr stmts) {
|
||||
auto irType = cache->module->unsafeGetFuncType(
|
||||
fnName, cache->classes["void"].realizations["void"]->ir, {}, false);
|
||||
main->realize(irType, {});
|
||||
main->setJIT();
|
||||
} else {
|
||||
main = cast<ir::BodiedFunc>(cache->module->getMainFunc());
|
||||
char buf[PATH_MAX + 1];
|
||||
|
@ -220,8 +220,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
|
||||
r->ir = ctx->cache->module->Nr<ir::ExternalFunc>(type->realizedName());
|
||||
} else {
|
||||
r->ir = ctx->cache->module->Nr<ir::BodiedFunc>(type->realizedName());
|
||||
if (ast->attributes.has(Attr::ForceRealize))
|
||||
ir::cast<ir::BodiedFunc>(r->ir)->setBuiltin();
|
||||
}
|
||||
|
||||
auto parent = type->funcParent;
|
||||
|
@ -88,8 +88,8 @@ private:
|
||||
std::list<Var *> symbols;
|
||||
/// the function body
|
||||
Value *body = nullptr;
|
||||
/// whether the function is builtin
|
||||
bool builtin = false;
|
||||
/// whether the function is a JIT input
|
||||
bool jit = false;
|
||||
|
||||
public:
|
||||
static const char NodeId;
|
||||
@ -136,11 +136,11 @@ public:
|
||||
/// @param b the new body
|
||||
void setBody(Flow *b) { body = b; }
|
||||
|
||||
/// @return true if the function is builtin
|
||||
bool isBuiltin() const { return builtin; }
|
||||
/// Changes the function's builtin status.
|
||||
/// @param v true if builtin, false otherwise
|
||||
void setBuiltin(bool v = true) { builtin = v; }
|
||||
/// @return true if the function is a JIT input
|
||||
bool isJIT() const { return jit; }
|
||||
/// Changes the function's JIT input status.
|
||||
/// @param v true if JIT input, false otherwise
|
||||
void setJIT(bool v = true) { jit = v; }
|
||||
|
||||
protected:
|
||||
std::vector<Value *> doGetUsedValues() const override {
|
||||
|
@ -29,8 +29,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
|
||||
|
||||
LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
|
||||
: util::ConstVisitor(), context(std::make_unique<llvm::LLVMContext>()), M(),
|
||||
B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr), block(nullptr),
|
||||
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
|
||||
moduleId(0), B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr),
|
||||
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
|
||||
db(debug, jit, flags), plugins(nullptr) {
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
@ -75,22 +75,27 @@ LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
|
||||
llvm::initializeTypePromotionPass(registry);
|
||||
}
|
||||
|
||||
llvm::GlobalValue::LinkageTypes LLVMVisitor::getDefaultLinkage() {
|
||||
return db.jit ? llvm::GlobalValue::ExternalLinkage
|
||||
: llvm::GlobalValue::PrivateLinkage;
|
||||
}
|
||||
|
||||
void LLVMVisitor::registerGlobal(const Var *var) {
|
||||
if (!var->isGlobal())
|
||||
return;
|
||||
|
||||
if (auto *f = cast<Func>(var)) {
|
||||
makeLLVMFunction(f);
|
||||
funcs.insert(f, func);
|
||||
insertFunc(f, func);
|
||||
} else {
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
if (llvmType->isVoidTy()) {
|
||||
vars.insert(var, getDummyVoidValue());
|
||||
insertVar(var, getDummyVoidValue());
|
||||
} else {
|
||||
auto *storage = new llvm::GlobalVariable(
|
||||
*M, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage,
|
||||
*M, llvmType, /*isConstant=*/false, getDefaultLinkage(),
|
||||
llvm::Constant::getNullValue(llvmType), var->getName());
|
||||
vars.insert(var, storage);
|
||||
insertVar(var, storage);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
@ -107,24 +112,18 @@ void LLVMVisitor::registerGlobal(const Var *var) {
|
||||
}
|
||||
|
||||
llvm::Value *LLVMVisitor::getVar(const Var *var) {
|
||||
llvm::Value *val = vars[var];
|
||||
std::pair<llvm::Value *, int64_t> p = vars.get(var);
|
||||
if (db.jit && var->isGlobal()) {
|
||||
if (val) {
|
||||
llvm::Module *m = nullptr;
|
||||
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
|
||||
m = x->getModule();
|
||||
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
|
||||
m = x->getParent();
|
||||
|
||||
if (m != M.get()) {
|
||||
// see if it's in the M already
|
||||
if (auto *val = p.first) {
|
||||
if (p.second != moduleId) {
|
||||
// see if it's in the module already
|
||||
auto name = var->getName();
|
||||
if (auto *global = M->getNamedValue(name))
|
||||
return global;
|
||||
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
auto *storage = new llvm::GlobalVariable(*M, llvmType, /*isConstant=*/false,
|
||||
llvm::GlobalVariable::ExternalLinkage,
|
||||
llvm::GlobalValue::ExternalLinkage,
|
||||
/*Initializer=*/nullptr, name);
|
||||
storage->setExternallyInitialized(true);
|
||||
|
||||
@ -138,7 +137,7 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) {
|
||||
getDIType(var->getType()),
|
||||
/*IsLocalToUnit=*/true);
|
||||
storage->addDebugInfo(debugVar);
|
||||
vars.insert(var, storage);
|
||||
insertVar(var, storage);
|
||||
return storage;
|
||||
}
|
||||
} else {
|
||||
@ -146,15 +145,15 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) {
|
||||
return vars[var];
|
||||
}
|
||||
}
|
||||
return val;
|
||||
return p.first;
|
||||
}
|
||||
|
||||
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
|
||||
llvm::Function *f = funcs[func];
|
||||
std::pair<llvm::Function *, int64_t> p = funcs.get(func);
|
||||
if (db.jit) {
|
||||
if (f) {
|
||||
if (f->getParent() != M.get()) {
|
||||
// see if it's in the M already
|
||||
if (auto *f = p.first) {
|
||||
if (p.second != moduleId) {
|
||||
// see if it's in the module already
|
||||
if (auto *g = M->getFunction(f->getName()))
|
||||
return g;
|
||||
|
||||
@ -162,7 +161,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) {
|
||||
llvm::Function::ExternalLinkage, f->getName(),
|
||||
M.get());
|
||||
g->copyAttributesFrom(f);
|
||||
funcs.insert(func, g);
|
||||
insertFunc(func, g);
|
||||
return g;
|
||||
}
|
||||
} else {
|
||||
@ -170,7 +169,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) {
|
||||
return funcs[func];
|
||||
}
|
||||
}
|
||||
return f;
|
||||
return p.first;
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::Module> LLVMVisitor::makeModule(llvm::LLVMContext &context,
|
||||
@ -205,6 +204,7 @@ LLVMVisitor::takeModule(const SrcInfo *src) {
|
||||
auto currentModule = std::move(M);
|
||||
context = std::make_unique<llvm::LLVMContext>();
|
||||
M = makeModule(*context, src);
|
||||
++moduleId;
|
||||
return {std::move(currentContext), std::move(currentModule)};
|
||||
}
|
||||
|
||||
@ -599,7 +599,7 @@ void LLVMVisitor::visit(const Module *x) {
|
||||
B->CreateBr(loopBlock);
|
||||
|
||||
B->SetInsertPoint(exitBlock);
|
||||
llvm::Value *argStorage = vars[x->getArgVar()];
|
||||
llvm::Value *argStorage = getVar(x->getArgVar());
|
||||
seqassert(argStorage, "argument storage missing");
|
||||
B->CreateStore(arr, argStorage);
|
||||
B->CreateCall(initFunc, B->getInt32(db.debug ? 1 : 0));
|
||||
@ -752,7 +752,7 @@ void LLVMVisitor::visit(const InternalFunc *x) {
|
||||
auto *funcType = cast<FuncType>(x->getType());
|
||||
std::vector<Type *> argTypes(funcType->begin(), funcType->end());
|
||||
|
||||
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
|
||||
func->setLinkage(getDefaultLinkage());
|
||||
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
|
||||
std::vector<llvm::Value *> args;
|
||||
for (auto it = func->arg_begin(); it != func->arg_end(); ++it) {
|
||||
@ -915,7 +915,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
|
||||
seqassert(!fail, "linking failed");
|
||||
func = M->getFunction(getNameForFunction(x));
|
||||
seqassert(func, "function not linked in");
|
||||
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
|
||||
func->setLinkage(getDefaultLinkage());
|
||||
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
|
||||
}
|
||||
|
||||
@ -929,7 +929,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
||||
if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) {
|
||||
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
|
||||
} else {
|
||||
func->setLinkage(llvm::GlobalValue::PrivateLinkage);
|
||||
func->setLinkage(getDefaultLinkage());
|
||||
}
|
||||
if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) {
|
||||
func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline);
|
||||
@ -956,7 +956,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
||||
const Var *var = *varIter;
|
||||
llvm::Value *storage = B->CreateAlloca(getLLVMType(var->getType()));
|
||||
B->CreateStore(argIter, storage);
|
||||
vars.insert(var, storage);
|
||||
insertVar(var, storage);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
@ -977,10 +977,10 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
||||
for (auto *var : *x) {
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
if (llvmType->isVoidTy()) {
|
||||
vars.insert(var, getDummyVoidValue());
|
||||
insertVar(var, getDummyVoidValue());
|
||||
} else {
|
||||
llvm::Value *storage = B->CreateAlloca(llvmType);
|
||||
vars.insert(var, storage);
|
||||
insertVar(var, storage);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
|
@ -15,17 +15,22 @@ namespace ir {
|
||||
|
||||
class LLVMVisitor : public util::ConstVisitor {
|
||||
private:
|
||||
template <typename V> using CacheBase = std::unordered_map<id_t, V *>;
|
||||
template <typename V>
|
||||
using CacheBase = std::unordered_map<id_t, std::pair<V *, int64_t>>;
|
||||
template <typename K, typename V> class Cache : public CacheBase<V> {
|
||||
public:
|
||||
using CacheBase<V>::CacheBase;
|
||||
|
||||
V *operator[](const K *key) {
|
||||
std::pair<V *, int64_t> get(const K *key) {
|
||||
auto it = CacheBase<V>::find(key->getId());
|
||||
return (it != CacheBase<V>::end()) ? it->second : nullptr;
|
||||
return (it != CacheBase<V>::end()) ? it->second : std::make_pair(nullptr, 0);
|
||||
}
|
||||
|
||||
void insert(const K *key, V *value) { CacheBase<V>::emplace(key->getId(), value); }
|
||||
V *operator[](const K *key) { return get(key).first; }
|
||||
|
||||
void insert(const K *key, V *value, int64_t id) {
|
||||
CacheBase<V>::emplace(key->getId(), std::make_pair(value, id));
|
||||
}
|
||||
};
|
||||
|
||||
struct CoroData {
|
||||
@ -113,6 +118,8 @@ private:
|
||||
std::unique_ptr<llvm::LLVMContext> context;
|
||||
/// Module we are compiling
|
||||
std::unique_ptr<llvm::Module> M;
|
||||
/// Module ID
|
||||
int64_t moduleId;
|
||||
/// LLVM IR builder used for constructing LLVM IR
|
||||
std::unique_ptr<llvm::IRBuilder<>> B;
|
||||
/// Current function we are compiling
|
||||
@ -180,12 +187,19 @@ private:
|
||||
void runLLVMPipeline();
|
||||
|
||||
llvm::Value *getVar(const Var *var);
|
||||
void insertVar(const Var *var, llvm::Value *x) { vars.insert(var, x, moduleId); }
|
||||
llvm::Function *getFunc(const Func *func);
|
||||
void insertFunc(const Func *func, llvm::Function *x) {
|
||||
funcs.insert(func, x, moduleId);
|
||||
}
|
||||
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); }
|
||||
|
||||
public:
|
||||
static std::string getNameForFunction(const Func *x) {
|
||||
if (auto *externalFunc = cast<ExternalFunc>(x)) {
|
||||
auto *bodiedFunc = cast<BodiedFunc>(x);
|
||||
if (bodiedFunc && bodiedFunc->isJIT()) {
|
||||
return x->getName();
|
||||
} else if (isA<ExternalFunc>(x)) {
|
||||
return x->getUnmangledName();
|
||||
} else {
|
||||
return x->referenceString();
|
||||
@ -244,6 +258,10 @@ public:
|
||||
/// @param var the global variable (or function) to register
|
||||
void registerGlobal(const Var *var);
|
||||
|
||||
/// Returns the default LLVM linkage type for the module.
|
||||
/// @return LLVM linkage type
|
||||
llvm::GlobalValue::LinkageTypes getDefaultLinkage();
|
||||
|
||||
/// Returns a new LLVM module initialized for the host
|
||||
/// architecture.
|
||||
/// @param context LLVM context used for creating module
|
||||
|
@ -35,7 +35,7 @@ void CloneVisitor::visit(const BodiedFunc *v) {
|
||||
|
||||
if (v->getBody())
|
||||
res->setBody(clone(v->getBody()));
|
||||
res->setBuiltin(v->isBuiltin());
|
||||
res->setJIT(v->isJIT());
|
||||
result = res;
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ public:
|
||||
result = compareFuncs(x, y) &&
|
||||
std::equal(x->begin(), x->end(), y->begin(), y->end(),
|
||||
[this](auto *x, auto *y) { return process(x, y); }) &&
|
||||
process(x->getBody(), y->getBody()) && x->isBuiltin() == y->isBuiltin();
|
||||
process(x->getBody(), y->getBody()) && x->isJIT() == y->isJIT();
|
||||
}
|
||||
VISIT(ExternalFunc);
|
||||
void handle(const ExternalFunc *x, const ExternalFunc *y) {
|
||||
|
@ -32,8 +32,8 @@ TEST_F(SIRCoreTest, FuncRealizationAndVarInsertionEraseAndIterators) {
|
||||
TEST_F(SIRCoreTest, BodiedFuncQueryAndReplace) {
|
||||
auto *fn = module->Nr<BodiedFunc>();
|
||||
fn->realize(module->unsafeGetDummyFuncType(), {});
|
||||
fn->setBuiltin();
|
||||
ASSERT_TRUE(fn->isBuiltin());
|
||||
fn->setJIT();
|
||||
ASSERT_TRUE(fn->isJIT());
|
||||
|
||||
auto *body = fn->getBody();
|
||||
ASSERT_FALSE(body);
|
||||
@ -63,7 +63,7 @@ TEST_F(SIRCoreTest, BodiedFuncCloning) {
|
||||
auto *fn = module->Nr<BodiedFunc>("fn");
|
||||
fn->realize(module->unsafeGetDummyFuncType(), {});
|
||||
|
||||
fn->setBuiltin();
|
||||
fn->setJIT();
|
||||
fn->setBody(module->Nr<SeriesFlow>());
|
||||
ASSERT_TRUE(util::match(fn, cv->clone(fn)));
|
||||
}
|
||||
|
@ -25,8 +25,8 @@ TEST_F(SIRCoreTest, MatchingEquivalentFunc) {
|
||||
auto *second = module->Nr<BodiedFunc>();
|
||||
second->realize(module->unsafeGetDummyFuncType(), {});
|
||||
|
||||
first->setBuiltin();
|
||||
second->setBuiltin();
|
||||
first->setJIT();
|
||||
second->setJIT();
|
||||
|
||||
ASSERT_TRUE(util::match(first, second));
|
||||
}
|
||||
@ -58,7 +58,7 @@ TEST_F(SIRCoreTest, MatchingNonEquivalentFunc) {
|
||||
auto *second = module->Nr<BodiedFunc>();
|
||||
second->realize(module->unsafeGetDummyFuncType(), {});
|
||||
|
||||
first->setBuiltin();
|
||||
first->setJIT();
|
||||
|
||||
ASSERT_FALSE(util::match(first, second));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user