1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

More engine updates

This commit is contained in:
A. R. Shajii 2021-10-22 17:57:23 -04:00
parent 910c880666
commit 62e5577a1e
5 changed files with 166 additions and 126 deletions

View File

@ -120,10 +120,13 @@ public:
} }
}; };
typedef int MainFunc(int, char **);
typedef void InputFunc();
JIT::JIT(ir::Module *module) JIT::JIT(ir::Module *module)
: module(module), pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)), : module(module), pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)),
plm(std::make_unique<PluginManager>()), plm(std::make_unique<PluginManager>()),
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true)) { llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true, /*jit=*/true)) {
if (auto e = Engine::create()) { if (auto e = Engine::create()) {
engine = std::move(e.get()); engine = std::move(e.get());
} else { } else {
@ -133,5 +136,29 @@ JIT::JIT(ir::Module *module)
llvisitor->setPluginManager(plm.get()); llvisitor->setPluginManager(plm.get());
} }
void JIT::init() {
module->accept(*llvisitor);
auto module = llvisitor->takeModule();
llvm::cantFail(
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
auto func = llvm::cantFail(engine->lookup("main"));
auto *main = (MainFunc *)func.getAddress();
(*main)(0, nullptr);
}
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
llvisitor->registerGlobal(input);
for (auto *var : newGlobals)
llvisitor->registerGlobal(var);
input->accept(*llvisitor);
auto module = llvisitor->takeModule();
llvm::cantFail(
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
auto func = llvm::cantFail(engine->lookup(name));
auto *repl = (InputFunc *)func.getAddress();
(*repl)();
}
} // namespace jit } // namespace jit
} // namespace codon } // namespace codon

View File

@ -1,9 +1,11 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include "codon/sir/llvm/llvisitor.h" #include "codon/sir/llvm/llvisitor.h"
#include "codon/sir/transform/manager.h" #include "codon/sir/transform/manager.h"
#include "codon/sir/var.h"
namespace codon { namespace codon {
namespace jit { namespace jit {
@ -21,6 +23,8 @@ private:
public: public:
JIT(ir::Module *module); JIT(ir::Module *module);
ir::Module *getModule() const { return module; } ir::Module *getModule() const { return module; }
void init();
void run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals = {});
}; };
} // namespace jit } // namespace jit

View File

@ -13,39 +13,6 @@
namespace codon { namespace codon {
namespace ir { namespace ir {
namespace {
std::string getNameForFunction(const Func *x) {
if (auto *externalFunc = cast<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
}
}
std::string getDebugNameForVariable(const Var *x) {
std::string name = x->getName();
auto pos = name.find(".");
if (pos != 0 && pos != std::string::npos) {
return name.substr(0, pos);
} else {
return name;
}
}
const SrcInfo *getSrcInfo(const Node *x) {
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
return &srcInfo->info;
} else {
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
return &defaultSrcInfo;
}
}
llvm::Value *getDummyVoidValue(llvm::LLVMContext &context) {
return llvm::ConstantTokenNone::get(context);
}
} // namespace
llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
std::string filename; std::string filename;
std::string directory; std::string directory;
@ -60,10 +27,10 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
return builder->createFile(filename, directory); return builder->createFile(filename, directory);
} }
LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags) LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
: util::ConstVisitor(), context(), builder(context), module(), func(nullptr), : util::ConstVisitor(), context(), builder(context), module(), func(nullptr),
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
db(debug, flags), plugins(nullptr) { db(debug, jit, flags), plugins(nullptr) {
llvm::InitializeAllTargets(); llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs(); llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters(); llvm::InitializeAllAsmPrinters();
@ -107,65 +74,102 @@ LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
llvm::initializeTypePromotionPass(registry); llvm::initializeTypePromotionPass(registry);
} }
void LLVMVisitor::registerGlobal(const Var *var) {
if (!var->isGlobal())
return;
if (auto *f = cast<Func>(var)) {
makeLLVMFunction(f);
funcs.insert(f, func);
} else {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue());
} else {
auto *storage = new llvm::GlobalVariable(
*module, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage,
llvm::Constant::getNullValue(llvmType), var->getName());
vars.insert(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
scope, getDebugNameForVariable(var), var->getName(), file, srcInfo->line,
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
}
}
}
llvm::Value *LLVMVisitor::getVar(const Var *var) { llvm::Value *LLVMVisitor::getVar(const Var *var) {
llvm::Value *val = vars[var]; llvm::Value *val = vars[var];
if (!val) if (db.jit && var->isGlobal()) {
return nullptr; if (val) {
llvm::Module *m = nullptr;
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
m = x->getModule();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
m = x->getParent();
llvm::Module *m = nullptr; if (m != module.get()) {
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val)) // see if it's in the module already
m = x->getModule(); auto name = var->getName();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val)) if (auto *global = module->getNamedValue(name))
m = x->getParent(); return global;
// the following happens when JIT'ing llvm::Type *llvmType = getLLVMType(var->getType());
if (var->isGlobal() && m != module.get()) { auto *storage =
// see if it's in the module already new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
auto name = var->getName(); llvm::GlobalVariable::ExternalLinkage,
if (auto *global = module->getNamedValue(name)) /*Initializer=*/nullptr, name);
return global; storage->setExternallyInitialized(true);
llvm::Type *llvmType = getLLVMType(var->getType()); // debug info
auto *storage = new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false, auto *srcInfo = getSrcInfo(var);
llvm::GlobalVariable::ExternalLinkage, llvm::DIFile *file = db.getFile(srcInfo->file);
/*Initializer=*/nullptr, name); llvm::DIScope *scope = db.unit;
storage->setExternallyInitialized(true); llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
// debug info scope, getDebugNameForVariable(var), name, file, srcInfo->line,
auto *srcInfo = getSrcInfo(var); getDIType(var->getType()),
llvm::DIFile *file = db.getFile(srcInfo->file); /*IsLocalToUnit=*/true);
llvm::DIScope *scope = db.unit; storage->addDebugInfo(debugVar);
llvm::DIGlobalVariableExpression *debugVar = vars.insert(var, storage);
db.builder->createGlobalVariableExpression(scope, getDebugNameForVariable(var), return storage;
name, file, srcInfo->line, }
getDIType(var->getType()), } else {
/*IsLocalToUnit=*/true); registerGlobal(var);
storage->addDebugInfo(debugVar); return vars[var];
return storage; }
} }
// should never have a non-global val from another module
return val; return val;
} }
llvm::Function *LLVMVisitor::getFunc(const Func *func) { llvm::Function *LLVMVisitor::getFunc(const Func *func) {
llvm::Function *f = funcs[func]; llvm::Function *f = funcs[func];
if (!f) if (db.jit) {
return nullptr; if (f) {
if (f->getParent() != module.get()) {
// see if it's in the module already
if (auto *g = module->getFunction(f->getName()))
return g;
// the following happens when JIT'ing auto *g = llvm::Function::Create(f->getFunctionType(),
if (f->getParent() != module.get()) { llvm::Function::ExternalLinkage, f->getName(),
// see if it's in the module already module.get());
if (auto *g = module->getFunction(f->getName())) g->copyAttributesFrom(f);
return g; funcs.insert(func, g);
return g;
auto *g = }
llvm::Function::Create(f->getFunctionType(), llvm::Function::ExternalLinkage, } else {
f->getName(), module.get()); registerGlobal(func);
g->copyAttributesFrom(f); return funcs[func];
return g; }
} }
return f; return f;
} }
@ -511,44 +515,12 @@ void LLVMVisitor::visit(const Module *x) {
module = makeModule(getSrcInfo(x)); module = makeModule(getSrcInfo(x));
// args variable // args variable
const Var *argVar = x->getArgVar(); seqassert(x->getArgVar()->isGlobal(), "arg var is not global");
llvm::Type *argVarType = getLLVMType(argVar->getType()); registerGlobal(x->getArgVar());
auto *argStorage = new llvm::GlobalVariable(
*module, argVarType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
llvm::Constant::getNullValue(argVarType), argVar->getName());
vars.insert(argVar, argStorage);
// set up global variables and initialize functions // set up global variables and initialize functions
for (auto *var : *x) { for (auto *var : *x) {
if (!var->isGlobal()) registerGlobal(var);
continue;
if (auto *f = cast<Func>(var)) {
makeLLVMFunction(f);
funcs.insert(f, func);
} else {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue(context));
} else {
auto *storage = new llvm::GlobalVariable(
*module, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::PrivateLinkage,
llvm::Constant::getNullValue(llvmType), var->getName());
vars.insert(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
scope, getDebugNameForVariable(var), var->getName(), file,
srcInfo->line, getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
}
}
} }
// process functions // process functions
@ -632,6 +604,8 @@ void LLVMVisitor::visit(const Module *x) {
builder.CreateBr(loopBlock); builder.CreateBr(loopBlock);
builder.SetInsertPoint(exitBlock); builder.SetInsertPoint(exitBlock);
llvm::Value *argStorage = vars[x->getArgVar()];
seqassert(argStorage, "argument storage missing");
builder.CreateStore(arr, argStorage); builder.CreateStore(arr, argStorage);
builder.CreateCall(initFunc, builder.getInt32(db.debug ? 1 : 0)); builder.CreateCall(initFunc, builder.getInt32(db.debug ? 1 : 0));
@ -739,7 +713,7 @@ void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) {
} }
void LLVMVisitor::visit(const ExternalFunc *x) { void LLVMVisitor::visit(const ExternalFunc *x) {
func = module->getFunction(getNameForFunction(x)); // inserted during module visit func = module->getFunction(getNameForFunction(x));
coro = {}; coro = {};
seqassert(func, "{} not inserted", *x); seqassert(func, "{} not inserted", *x);
func->setDoesNotThrow(); func->setDoesNotThrow();
@ -775,7 +749,7 @@ bool internalFuncMatches(const std::string &name, const InternalFunc *x) {
void LLVMVisitor::visit(const InternalFunc *x) { void LLVMVisitor::visit(const InternalFunc *x) {
using namespace types; using namespace types;
func = module->getFunction(getNameForFunction(x)); // inserted during module visit func = module->getFunction(getNameForFunction(x));
coro = {}; coro = {};
seqassert(func, "{} not inserted", *x); seqassert(func, "{} not inserted", *x);
setDebugInfoForNode(x); setDebugInfoForNode(x);
@ -953,7 +927,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
} }
void LLVMVisitor::visit(const BodiedFunc *x) { void LLVMVisitor::visit(const BodiedFunc *x) {
func = module->getFunction(getNameForFunction(x)); // inserted during module visit func = module->getFunction(getNameForFunction(x));
coro = {}; coro = {};
seqassert(func, "{} not inserted", *x); seqassert(func, "{} not inserted", *x);
setDebugInfoForNode(x); setDebugInfoForNode(x);
@ -1009,7 +983,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
for (auto *var : *x) { for (auto *var : *x) {
llvm::Type *llvmType = getLLVMType(var->getType()); llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) { if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue(context)); vars.insert(var, getDummyVoidValue());
} else { } else {
llvm::Value *storage = builder.CreateAlloca(llvmType); llvm::Value *storage = builder.CreateAlloca(llvmType);
vars.insert(var, storage); vars.insert(var, storage);
@ -2031,7 +2005,7 @@ void LLVMVisitor::visit(const AssignInstr *x) {
llvm::Value *var = getVar(x->getLhs()); llvm::Value *var = getVar(x->getLhs());
seqassert(var, "could not find {} var", *x); seqassert(var, "could not find {} var", *x);
process(x->getRhs()); process(x->getRhs());
if (var != getDummyVoidValue(context)) { if (var != getDummyVoidValue()) {
builder.SetInsertPoint(block); builder.SetInsertPoint(block);
builder.CreateStore(value, var); builder.CreateStore(value, var);
} }

View File

@ -97,11 +97,13 @@ private:
llvm::DICompileUnit *unit; llvm::DICompileUnit *unit;
/// Whether we are compiling in debug mode /// Whether we are compiling in debug mode
bool debug; bool debug;
/// Whether we are compiling in JIT mode
bool jit;
/// Program command-line flags /// Program command-line flags
std::string flags; std::string flags;
explicit DebugInfo(bool debug, const std::string &flags) DebugInfo(bool debug, bool jit, const std::string &flags)
: builder(), unit(nullptr), debug(debug), flags(flags) {} : builder(), unit(nullptr), debug(debug), jit(jit), flags(flags) {}
llvm::DIFile *getFile(const std::string &path); llvm::DIFile *getFile(const std::string &path);
}; };
@ -178,9 +180,37 @@ private:
llvm::Value *getVar(const Var *var); llvm::Value *getVar(const Var *var);
llvm::Function *getFunc(const Func *func); llvm::Function *getFunc(const Func *func);
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(context); }
public: public:
LLVMVisitor(bool debug = false, const std::string &flags = ""); static std::string getNameForFunction(const Func *x) {
if (auto *externalFunc = cast<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
}
}
static std::string getDebugNameForVariable(const Var *x) {
std::string name = x->getName();
auto pos = name.find(".");
if (pos != 0 && pos != std::string::npos) {
return name.substr(0, pos);
} else {
return name;
}
}
static const SrcInfo *getSrcInfo(const Node *x) {
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
return &srcInfo->info;
} else {
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
return &defaultSrcInfo;
}
}
LLVMVisitor(bool debug = false, bool jit = false, const std::string &flags = "");
llvm::LLVMContext &getContext() { return context; } llvm::LLVMContext &getContext() { return context; }
llvm::IRBuilder<> &getBuilder() { return builder; } llvm::IRBuilder<> &getBuilder() { return builder; }
@ -199,6 +229,11 @@ public:
void setBlock(llvm::BasicBlock *b) { block = b; } void setBlock(llvm::BasicBlock *b) { block = b; }
void setValue(llvm::Value *v) { value = v; } void setValue(llvm::Value *v) { value = v; }
/// Registers a new global variable or function with
/// this visitor.
/// @param var the global variable (or function) to register
void registerGlobal(const Var *var);
/// Returns a new LLVM module initialized for the host /// Returns a new LLVM module initialized for the host
/// architecture. /// architecture.
/// @param src source information for the new module /// @param src source information for the new module

View File

@ -45,7 +45,7 @@ void GlobalDemotionPass::run(Module *M) {
} }
for (auto it : localGlobals) { for (auto it : localGlobals) {
if (!it.second) if (!it.second || it.first->getId() == M->getArgVar()->getId())
continue; continue;
seqassert(it.first->isGlobal(), "var was not global"); seqassert(it.first->isGlobal(), "var was not global");
it.first->setGlobal(false); it.first->setGlobal(false);