pull/6/head
A. R. Shajii 2021-10-25 20:02:16 -04:00
parent f04e60b29a
commit 8003220455
4 changed files with 81 additions and 63 deletions

View File

@ -13,11 +13,6 @@ set(CMAKE_CXX_FLAGS_DEBUG "-g -fno-limit-debug-info")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
include_directories(.)
if(ASAN)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")
set(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")
endif()
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
find_package(LLVM REQUIRED)
@ -277,6 +272,10 @@ set(CODON_CPPFILES
codon/util/fmt/format.cpp)
add_library(codonc SHARED ${CODON_HPPFILES})
target_sources(codonc PRIVATE ${CODON_CPPFILES} codon_rules.cpp omp_rules.cpp)
if(ASAN)
target_compile_options(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address")
target_link_libraries(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address")
endif()
if(CMAKE_BUILD_TYPE MATCHES Debug)
set_source_files_properties(codon_rules.cpp codon/parser/peg/peg.cpp PROPERTIES COMPILE_FLAGS "-O2")
endif()

View File

@ -104,30 +104,33 @@ JIT::JIT(ir::Module *module)
void JIT::init() {
module->accept(*llvisitor);
auto pair = llvisitor->takeModule();
//auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::cantFail(
engine->addModule({std::move(pair.second), std::move(pair.first)}));
// auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::cantFail(engine->addModule({std::move(pair.second), std::move(pair.first)}));
auto func = llvm::cantFail(engine->lookup("main"));
auto *main = (MainFunc *)func.getAddress();
(*main)(0, nullptr);
//llvm::cantFail(rt->remove());
// llvm::cantFail(rt->remove());
}
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
llvisitor->registerGlobal(input);
for (auto *var : newGlobals)
for (auto *var : newGlobals) {
llvisitor->registerGlobal(var);
}
for (auto *var : newGlobals) {
if (auto *func = ir::cast<ir::Func>(var))
func->accept(*llvisitor);
}
input->accept(*llvisitor);
auto pair = llvisitor->takeModule();
//auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::StripDebugInfo(*pair.second);
llvm::cantFail(
engine->addModule({std::move(pair.second), std::move(pair.first)}));
// auto rt = engine->getMainJITDylib().createResourceTracker();
llvm::StripDebugInfo(*pair.second); // TODO: needed?
llvm::cantFail(engine->addModule({std::move(pair.second), std::move(pair.first)}));
auto func = llvm::cantFail(engine->lookup(name));
auto *repl = (InputFunc *)func.getAddress();
(*repl)();
//llvm::cantFail(rt->remove());
// llvm::cantFail(rt->remove());
}
} // namespace jit

View File

@ -29,8 +29,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
: util::ConstVisitor(), context(std::make_unique<llvm::LLVMContext>()), M(),
moduleId(0), B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr),
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr), block(nullptr),
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
db(debug, jit, flags), plugins(nullptr) {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
@ -112,10 +112,10 @@ void LLVMVisitor::registerGlobal(const Var *var) {
}
llvm::Value *LLVMVisitor::getVar(const Var *var) {
std::pair<llvm::Value *, int64_t> p = vars.get(var);
auto it = vars.find(var->getId());
if (db.jit && var->isGlobal()) {
if (auto *val = p.first) {
if (p.second != moduleId) {
if (it != vars.end()) {
if (!it->second) { // if value is null, it's from another module
// see if it's in the module already
auto name = var->getName();
if (auto *global = M->getNamedValue(name))
@ -142,34 +142,42 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) {
}
} else {
registerGlobal(var);
return vars[var];
return it->second;
}
}
return p.first;
return (it != vars.end()) ? it->second : nullptr;
}
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
std::pair<llvm::Function *, int64_t> p = funcs.get(func);
auto it = funcs.find(func->getId());
if (db.jit) {
if (auto *f = p.first) {
if (p.second != moduleId) {
if (it != funcs.end()) {
if (!it->second) { // if value is null, it's from another module
// see if it's in the module already
if (auto *g = M->getFunction(f->getName()))
const std::string name = getNameForFunction(func);
if (auto *g = M->getFunction(name))
return g;
auto *g = llvm::Function::Create(f->getFunctionType(),
llvm::Function::ExternalLinkage, f->getName(),
M.get());
g->copyAttributesFrom(f);
auto *funcType = cast<types::FuncType>(func->getType());
llvm::Type *returnType = getLLVMType(funcType->getReturnType());
std::vector<llvm::Type *> argTypes;
for (const auto &argType : *funcType) {
argTypes.push_back(getLLVMType(argType));
}
auto *llvmFuncType =
llvm::FunctionType::get(returnType, argTypes, funcType->isVariadic());
auto *g = llvm::Function::Create(llvmFuncType, llvm::Function::ExternalLinkage,
name, M.get());
insertFunc(func, g);
return g;
}
} else {
registerGlobal(func);
return funcs[func];
return it->second;
}
}
return p.first;
return (it != funcs.end()) ? it->second : nullptr;
}
std::unique_ptr<llvm::Module> LLVMVisitor::makeModule(llvm::LLVMContext &context,
@ -202,9 +210,24 @@ std::pair<std::unique_ptr<llvm::LLVMContext>, std::unique_ptr<llvm::Module>>
LLVMVisitor::takeModule(const SrcInfo *src) {
auto currentContext = std::move(context);
auto currentModule = std::move(M);
// reset all LLVM fields/data -- they are owned by the context
B = {};
func = nullptr;
block = nullptr;
value = nullptr;
for (auto &it : vars)
it.second = nullptr;
for (auto &it : funcs)
it.second = nullptr;
coro.reset();
loops.clear();
trycatch.clear();
db.reset();
context = std::make_unique<llvm::LLVMContext>();
M = makeModule(*context, src);
++moduleId;
return {std::move(currentContext), std::move(currentModule)};
}

View File

@ -15,24 +15,6 @@ namespace ir {
class LLVMVisitor : public util::ConstVisitor {
private:
template <typename V>
using CacheBase = std::unordered_map<id_t, std::pair<V *, int64_t>>;
template <typename K, typename V> class Cache : public CacheBase<V> {
public:
using CacheBase<V>::CacheBase;
std::pair<V *, int64_t> get(const K *key) {
auto it = CacheBase<V>::find(key->getId());
return (it != CacheBase<V>::end()) ? it->second : std::make_pair(nullptr, 0);
}
V *operator[](const K *key) { return get(key).first; }
void insert(const K *key, V *value, int64_t id) {
CacheBase<V>::emplace(key->getId(), std::make_pair(value, id));
}
};
struct CoroData {
/// Coroutine promise (where yielded values are stored)
llvm::Value *promise;
@ -44,6 +26,8 @@ private:
llvm::BasicBlock *suspend;
/// Coroutine exit block
llvm::BasicBlock *exit;
void reset() { promise = handle = cleanup = suspend = exit = nullptr; }
};
struct NestableData {
@ -63,6 +47,8 @@ private:
LoopData(llvm::BasicBlock *breakBlock, llvm::BasicBlock *continueBlock, id_t loopId)
: NestableData(), breakBlock(breakBlock), continueBlock(continueBlock),
loopId(loopId) {}
void reset() { breakBlock = continueBlock = nullptr; }
};
struct TryCatchData : NestableData {
@ -94,6 +80,13 @@ private:
finallyBlock(nullptr), catchTypes(), handlers(), excFlag(nullptr),
catchStore(nullptr), delegateDepth(nullptr), retStore(nullptr),
loopSequence(nullptr) {}
void reset() {
exceptionBlock = exceptionRouteBlock = finallyBlock = nullptr;
catchTypes.clear();
handlers.clear();
excFlag = catchStore = delegateDepth = loopSequence = nullptr;
}
};
struct DebugInfo {
@ -112,14 +105,17 @@ private:
: builder(), unit(nullptr), debug(debug), jit(jit), flags(flags) {}
llvm::DIFile *getFile(const std::string &path);
void reset() {
builder = {};
unit = nullptr;
}
};
/// LLVM context used for compilation
std::unique_ptr<llvm::LLVMContext> context;
/// Module we are compiling
std::unique_ptr<llvm::Module> M;
/// Module ID
int64_t moduleId;
/// LLVM IR builder used for constructing LLVM IR
std::unique_ptr<llvm::IRBuilder<>> B;
/// Current function we are compiling
@ -129,9 +125,9 @@ private:
/// Last compiled value
llvm::Value *value;
/// LLVM values corresponding to IR variables
Cache<Var, llvm::Value> vars;
std::unordered_map<id_t, llvm::Value *> vars;
/// LLVM functions corresponding to IR functions
Cache<Func, llvm::Function> funcs;
std::unordered_map<id_t, llvm::Function *> funcs;
/// Coroutine data, if current function is a coroutine
CoroData coro;
/// Loop data stack, containing break/continue blocks
@ -187,19 +183,16 @@ private:
void runLLVMPipeline();
llvm::Value *getVar(const Var *var);
void insertVar(const Var *var, llvm::Value *x) { vars.insert(var, x, moduleId); }
void insertVar(const Var *var, llvm::Value *x) { vars.emplace(var->getId(), x); }
llvm::Function *getFunc(const Func *func);
void insertFunc(const Func *func, llvm::Function *x) {
funcs.insert(func, x, moduleId);
funcs.emplace(func->getId(), x);
}
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); }
public:
static std::string getNameForFunction(const Func *x) {
auto *bodiedFunc = cast<BodiedFunc>(x);
if (bodiedFunc && bodiedFunc->isJIT()) {
return x->getName();
} else if (isA<ExternalFunc>(x)) {
if (isA<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
@ -242,8 +235,8 @@ public:
llvm::FunctionCallee getFunc() { return func; }
llvm::BasicBlock *getBlock() { return block; }
llvm::Value *getValue() { return value; }
Cache<Var, llvm::Value> &getVars() { return vars; }
Cache<Func, llvm::Function> &getFuncs() { return funcs; }
std::unordered_map<id_t, llvm::Value *> &getVars() { return vars; }
std::unordered_map<id_t, llvm::Function *> &getFuncs() { return funcs; }
CoroData &getCoro() { return coro; }
std::vector<LoopData> &getLoops() { return loops; }
std::vector<TryCatchData> &getTryCatch() { return trycatch; }