diff --git a/codon/jit/engine.cpp b/codon/jit/engine.cpp index 613e4935..4f71ef17 100644 --- a/codon/jit/engine.cpp +++ b/codon/jit/engine.cpp @@ -138,12 +138,14 @@ JIT::JIT(ir::Module *module) void JIT::init() { module->accept(*llvisitor); - auto module = llvisitor->takeModule(); + auto pair = llvisitor->takeModule(); + auto rt = engine->getMainJITDylib().createResourceTracker(); llvm::cantFail( - engine->addModule({std::move(module), std::make_unique()})); + engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))})); auto func = llvm::cantFail(engine->lookup("main")); auto *main = (MainFunc *)func.getAddress(); (*main)(0, nullptr); + llvm::cantFail(rt->remove()); } void JIT::run(const ir::Func *input, const std::vector &newGlobals) { @@ -152,12 +154,14 @@ void JIT::run(const ir::Func *input, const std::vector &newGlobals) { for (auto *var : newGlobals) llvisitor->registerGlobal(var); input->accept(*llvisitor); - auto module = llvisitor->takeModule(); + auto pair = llvisitor->takeModule(); + auto rt = engine->getMainJITDylib().createResourceTracker(); llvm::cantFail( - engine->addModule({std::move(module), std::make_unique()})); + engine->addModule({std::move(std::get<1>(pair)), std::move(std::get<0>(pair))})); auto func = llvm::cantFail(engine->lookup(name)); auto *repl = (InputFunc *)func.getAddress(); (*repl)(); + llvm::cantFail(rt->remove()); } } // namespace jit diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index ab442123..30b64008 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -28,9 +28,9 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { } LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags) - : util::ConstVisitor(), context(), builder(context), module(), func(nullptr), - block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), - db(debug, jit, flags), plugins(nullptr) { + : util::ConstVisitor(), context(std::make_unique()), module(), + builder(*context), func(nullptr), block(nullptr), value(nullptr), vars(), funcs(), + coro(), loops(), trycatch(), db(debug, jit, flags), plugins(nullptr) { llvm::InitializeAllTargets(); llvm::InitializeAllTargetMCs(); llvm::InitializeAllAsmPrinters(); @@ -173,7 +173,8 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) { return f; } -std::unique_ptr LLVMVisitor::makeModule(const SrcInfo *src) { +std::unique_ptr LLVMVisitor::makeModule(llvm::LLVMContext &context, + const SrcInfo *src) { auto module = std::make_unique("codon", context); module->setTargetTriple( llvm::EngineBuilder().selectTarget()->getTargetTriple().str()); @@ -199,17 +200,20 @@ std::unique_ptr LLVMVisitor::makeModule(const SrcInfo *src) { return module; } -std::unique_ptr LLVMVisitor::takeModule(const SrcInfo *src) { +std::pair, std::unique_ptr> +LLVMVisitor::takeModule(const SrcInfo *src) { + auto currentContext = std::move(context); auto currentModule = std::move(module); - module = makeModule(src); - return currentModule; + context = std::make_unique(); + module = makeModule(*context, src); + return {std::move(currentContext), std::move(currentModule)}; } void LLVMVisitor::setDebugInfoForNode(const Node *x) { if (x && func) { auto *srcInfo = getSrcInfo(x); builder.SetCurrentDebugLocation(llvm::DILocation::get( - context, srcInfo->line, srcInfo->col, func->getSubprogram())); + *context, srcInfo->line, srcInfo->col, func->getSubprogram())); } else { builder.SetCurrentDebugLocation(llvm::DebugLoc()); } @@ -466,7 +470,7 @@ llvm::Value *LLVMVisitor::call(llvm::FunctionCallee callee, if (trycatch.empty()) { return builder.CreateCall(callee, args); } else { - auto *normalBlock = llvm::BasicBlock::Create(context, "invoke.normal", func); + auto *normalBlock = llvm::BasicBlock::Create(*context, "invoke.normal", func); auto *unwindBlock = trycatch.back().exceptionBlock; auto *result = builder.CreateInvoke(callee, normalBlock, unwindBlock, args); block = normalBlock; @@ -512,7 +516,8 @@ LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatchBeforeLoop() { */ void LLVMVisitor::visit(const Module *x) { - module = makeModule(getSrcInfo(x)); + // initialize module + module = makeModule(*context, getSrcInfo(x)); // args variable seqassert(x->getArgVar()->isGlobal(), "arg var is not global"); @@ -538,9 +543,9 @@ void LLVMVisitor::visit(const Module *x) { // build canonical main function auto *strType = - llvm::StructType::get(context, {builder.getInt64Ty(), builder.getInt8PtrTy()}); + llvm::StructType::get(*context, {builder.getInt64Ty(), builder.getInt8PtrTy()}); auto *arrType = - llvm::StructType::get(context, {builder.getInt64Ty(), strType->getPointerTo()}); + llvm::StructType::get(*context, {builder.getInt64Ty(), strType->getPointerTo()}); auto *initFunc = llvm::cast( module->getOrInsertFunction("seq_init", builder.getVoidTy(), builder.getInt32Ty()) @@ -567,10 +572,10 @@ void LLVMVisitor::visit(const Module *x) { // The following generates code to put program arguments in an array, i.e.: // for (int i = 0; i < argc; i++) // array[i] = {strlen(argv[i]), argv[i]} - auto *entryBlock = llvm::BasicBlock::Create(context, "entry", canonicalMainFunc); - auto *loopBlock = llvm::BasicBlock::Create(context, "loop", canonicalMainFunc); - auto *bodyBlock = llvm::BasicBlock::Create(context, "body", canonicalMainFunc); - auto *exitBlock = llvm::BasicBlock::Create(context, "exit", canonicalMainFunc); + auto *entryBlock = llvm::BasicBlock::Create(*context, "entry", canonicalMainFunc); + auto *loopBlock = llvm::BasicBlock::Create(*context, "loop", canonicalMainFunc); + auto *bodyBlock = llvm::BasicBlock::Create(*context, "body", canonicalMainFunc); + auto *exitBlock = llvm::BasicBlock::Create(*context, "exit", canonicalMainFunc); builder.SetInsertPoint(entryBlock); auto allocFunc = makeAllocFunc(/*atomic=*/false); @@ -617,9 +622,9 @@ void LLVMVisitor::visit(const Module *x) { proxyMain->setLinkage(llvm::GlobalValue::PrivateLinkage); proxyMain->setPersonalityFn( llvm::cast(makePersonalityFunc().getCallee())); - auto *proxyBlockEntry = llvm::BasicBlock::Create(context, "entry", proxyMain); - auto *proxyBlockMain = llvm::BasicBlock::Create(context, "main", proxyMain); - auto *proxyBlockExit = llvm::BasicBlock::Create(context, "exit", proxyMain); + auto *proxyBlockEntry = llvm::BasicBlock::Create(*context, "entry", proxyMain); + auto *proxyBlockMain = llvm::BasicBlock::Create(*context, "main", proxyMain); + auto *proxyBlockExit = llvm::BasicBlock::Create(*context, "exit", proxyMain); builder.SetInsertPoint(proxyBlockEntry); llvm::Value *shouldExit = builder.getFalse(); @@ -629,8 +634,8 @@ void LLVMVisitor::visit(const Module *x) { builder.CreateRetVoid(); // invoke real main - auto *normal = llvm::BasicBlock::Create(context, "normal", proxyMain); - auto *unwind = llvm::BasicBlock::Create(context, "unwind", proxyMain); + auto *normal = llvm::BasicBlock::Create(*context, "normal", proxyMain); + auto *unwind = llvm::BasicBlock::Create(*context, "unwind", proxyMain); builder.SetInsertPoint(proxyBlockMain); builder.CreateInvoke(realMain, normal, unwind); @@ -702,10 +707,10 @@ void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) { llvm::FunctionCallee coroSuspend = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::coro_suspend); llvm::Value *suspendResult = - builder.CreateCall(coroSuspend, {llvm::ConstantTokenNone::get(context), + builder.CreateCall(coroSuspend, {llvm::ConstantTokenNone::get(*context), builder.getInt1(finalYield)}); - block = llvm::BasicBlock::Create(context, "yield.new", func); + block = llvm::BasicBlock::Create(*context, "yield.new", func); llvm::SwitchInst *inst = builder.CreateSwitch(suspendResult, coro.suspend, 2); inst->addCase(builder.getInt8(0), block); @@ -764,7 +769,7 @@ void LLVMVisitor::visit(const InternalFunc *x) { for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { args.push_back(it); } - block = llvm::BasicBlock::Create(context, "entry", func); + block = llvm::BasicBlock::Create(*context, "entry", func); builder.SetInsertPoint(block); llvm::Value *result = nullptr; @@ -908,7 +913,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); seqassert(buf, "could not create buffer"); std::unique_ptr sub = - llvm::parseIR(buf->getMemBufferRef(), err, context); + llvm::parseIR(buf->getMemBufferRef(), err, *context); if (!sub) { std::string bufStr; llvm::raw_string_ostream buf(bufStr); @@ -949,7 +954,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { auto *funcType = cast(x->getType()); seqassert(funcType, "{} is not a function type", *x->getType()); auto *returnType = funcType->getReturnType(); - auto *entryBlock = llvm::BasicBlock::Create(context, "entry", func); + auto *entryBlock = llvm::BasicBlock::Create(*context, "entry", func); builder.SetInsertPoint(entryBlock); builder.SetCurrentDebugLocation(llvm::DebugLoc()); @@ -974,7 +979,8 @@ void LLVMVisitor::visit(const BodiedFunc *x) { getDIType(var->getType()), db.debug); db.builder->insertDeclare( storage, debugVar, db.builder->createExpression(), - llvm::DILocation::get(context, srcInfo->line, srcInfo->col, scope), entryBlock); + llvm::DILocation::get(*context, srcInfo->line, srcInfo->col, scope), + entryBlock); ++argIter; ++argIdx; @@ -997,12 +1003,12 @@ void LLVMVisitor::visit(const BodiedFunc *x) { getDIType(var->getType()), db.debug); db.builder->insertDeclare( storage, debugVar, db.builder->createExpression(), - llvm::DILocation::get(context, srcInfo->line, srcInfo->col, scope), + llvm::DILocation::get(*context, srcInfo->line, srcInfo->col, scope), entryBlock); } } - auto *startBlock = llvm::BasicBlock::Create(context, "start", func); + auto *startBlock = llvm::BasicBlock::Create(*context, "start", func); if (x->isGenerator()) { auto *generatorType = cast(returnType); @@ -1021,11 +1027,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) { llvm::FunctionCallee coroFree = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::coro_free); - coro.cleanup = llvm::BasicBlock::Create(context, "coro.cleanup", func); - coro.suspend = llvm::BasicBlock::Create(context, "coro.suspend", func); - coro.exit = llvm::BasicBlock::Create(context, "coro.exit", func); - auto *allocBlock = llvm::BasicBlock::Create(context, "coro.alloc", func); - auto *freeBlock = llvm::BasicBlock::Create(context, "coro.free", func); + coro.cleanup = llvm::BasicBlock::Create(*context, "coro.cleanup", func); + coro.suspend = llvm::BasicBlock::Create(*context, "coro.suspend", func); + coro.exit = llvm::BasicBlock::Create(*context, "coro.exit", func); + auto *allocBlock = llvm::BasicBlock::Create(*context, "coro.alloc", func); + auto *freeBlock = llvm::BasicBlock::Create(*context, "coro.free", func); // coro ID and promise llvm::Value *id = nullptr; @@ -1152,7 +1158,7 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { for (const auto &field : *x) { body.push_back(getLLVMType(field.getType())); } - return llvm::StructType::get(context, body); + return llvm::StructType::get(*context, body); } if (auto *x = cast(t)) { @@ -1283,7 +1289,7 @@ llvm::DIType *LLVMVisitor::getDITypeHelper( argTypes.push_back(getDITypeHelper(argType, cache)); } return db.builder->createPointerType( - db.builder->createSubroutineType(llvm::MDTuple::get(context, argTypes)), + db.builder->createSubroutineType(llvm::MDTuple::get(*context, argTypes)), layout.getTypeAllocSizeInBits(type)); } @@ -1383,7 +1389,7 @@ void LLVMVisitor::visit(const StringConst *x) { auto *strVar = new llvm::GlobalVariable( *module, llvm::ArrayType::get(builder.getInt8Ty(), s.length() + 1), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, - llvm::ConstantDataArray::getString(context, s), "str_literal"); + llvm::ConstantDataArray::getString(*context, s), "str_literal"); auto *strType = llvm::StructType::get(builder.getInt64Ty(), builder.getInt8PtrTy()); llvm::Value *ptr = builder.CreateBitCast(strVar, builder.getInt8PtrTy()); llvm::Value *len = builder.getInt64(s.length()); @@ -1408,9 +1414,9 @@ void LLVMVisitor::visit(const SeriesFlow *x) { } void LLVMVisitor::visit(const IfFlow *x) { - auto *trueBlock = llvm::BasicBlock::Create(context, "if.true", func); - auto *falseBlock = llvm::BasicBlock::Create(context, "if.false", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "if.exit", func); + auto *trueBlock = llvm::BasicBlock::Create(*context, "if.true", func); + auto *falseBlock = llvm::BasicBlock::Create(*context, "if.false", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "if.exit", func); process(x->getCond()); llvm::Value *cond = value; @@ -1436,9 +1442,9 @@ void LLVMVisitor::visit(const IfFlow *x) { } void LLVMVisitor::visit(const WhileFlow *x) { - auto *condBlock = llvm::BasicBlock::Create(context, "while.cond", func); - auto *bodyBlock = llvm::BasicBlock::Create(context, "while.body", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "while.exit", func); + auto *condBlock = llvm::BasicBlock::Create(*context, "while.cond", func); + auto *bodyBlock = llvm::BasicBlock::Create(*context, "while.body", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "while.exit", func); builder.SetInsertPoint(block); builder.CreateBr(condBlock); @@ -1467,10 +1473,10 @@ void LLVMVisitor::visit(const ForFlow *x) { llvm::Value *loopVar = getVar(x->getVar()); seqassert(loopVar, "{} loop variable not found", *x); - auto *condBlock = llvm::BasicBlock::Create(context, "for.cond", func); - auto *bodyBlock = llvm::BasicBlock::Create(context, "for.body", func); - auto *cleanupBlock = llvm::BasicBlock::Create(context, "for.cleanup", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "for.exit", func); + auto *condBlock = llvm::BasicBlock::Create(*context, "for.cond", func); + auto *bodyBlock = llvm::BasicBlock::Create(*context, "for.body", func); + auto *cleanupBlock = llvm::BasicBlock::Create(*context, "for.cleanup", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "for.exit", func); // LLVM coroutine intrinsics // https://prereleases.llvm.org/6.0.0/rc3/docs/Coroutines.html @@ -1526,10 +1532,10 @@ void LLVMVisitor::visit(const ImperativeForFlow *x) { seqassert(loopVar, "{} loop variable not found", *x); seqassert(x->getStep() != 0, "step cannot be 0"); - auto *condBlock = llvm::BasicBlock::Create(context, "imp_for.cond", func); - auto *bodyBlock = llvm::BasicBlock::Create(context, "imp_for.body", func); - auto *updateBlock = llvm::BasicBlock::Create(context, "imp_for.update", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "imp_for.exit", func); + auto *condBlock = llvm::BasicBlock::Create(*context, "imp_for.cond", func); + auto *bodyBlock = llvm::BasicBlock::Create(*context, "imp_for.body", func); + auto *updateBlock = llvm::BasicBlock::Create(*context, "imp_for.update", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "imp_for.exit", func); process(x->getStart()); builder.SetInsertPoint(block); @@ -1586,20 +1592,20 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { const bool isRoot = trycatch.empty(); const bool supportBreakAndContinue = !loops.empty(); builder.SetInsertPoint(block); - auto *entryBlock = llvm::BasicBlock::Create(context, "trycatch.entry", func); + auto *entryBlock = llvm::BasicBlock::Create(*context, "trycatch.entry", func); builder.CreateBr(entryBlock); TryCatchData tc; - tc.exceptionBlock = llvm::BasicBlock::Create(context, "trycatch.exception", func); + tc.exceptionBlock = llvm::BasicBlock::Create(*context, "trycatch.exception", func); tc.exceptionRouteBlock = - llvm::BasicBlock::Create(context, "trycatch.exception_route", func); - tc.finallyBlock = llvm::BasicBlock::Create(context, "trycatch.finally", func); + llvm::BasicBlock::Create(*context, "trycatch.exception_route", func); + tc.finallyBlock = llvm::BasicBlock::Create(*context, "trycatch.finally", func); auto *externalExcBlock = - llvm::BasicBlock::Create(context, "trycatch.exception_external", func); + llvm::BasicBlock::Create(*context, "trycatch.exception_external", func); auto *unwindResumeBlock = - llvm::BasicBlock::Create(context, "trycatch.unwind_resume", func); - auto *endBlock = llvm::BasicBlock::Create(context, "trycatch.end", func); + llvm::BasicBlock::Create(*context, "trycatch.unwind_resume", func); + auto *endBlock = llvm::BasicBlock::Create(*context, "trycatch.end", func); builder.SetInsertPoint(func->getEntryBlock().getTerminator()); auto *excStateNotThrown = builder.getInt8(TryCatchData::State::NOT_THROWN); @@ -1646,9 +1652,9 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { llvm::Value *depthRead = builder.CreateLoad(tc.delegateDepth); llvm::Value *delegate = builder.CreateICmpSGT(depthRead, builder.getInt64(0)); auto *finallyNormal = - llvm::BasicBlock::Create(context, "trycatch.finally.normal", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.normal", func); auto *finallyDelegate = - llvm::BasicBlock::Create(context, "trycatch.finally.delegate", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.delegate", func); builder.CreateCondBr(delegate, finallyDelegate, finallyNormal); builder.SetInsertPoint(finallyDelegate); @@ -1670,7 +1676,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { if (isRoot) { auto *finallyReturn = - llvm::BasicBlock::Create(context, "trycatch.finally.return", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.return", func); theSwitch->addCase(excStateReturn, finallyReturn); builder.SetInsertPoint(finallyReturn); if (coro.exit) { @@ -1689,13 +1695,13 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { auto prevSeq = isRoot ? -1 : trycatch.back().sequenceNumber; auto *finallyBreak = - llvm::BasicBlock::Create(context, "trycatch.finally.break", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.break", func); auto *finallyBreakDone = - llvm::BasicBlock::Create(context, "trycatch.finally.break.done", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.break.done", func); auto *finallyContinue = - llvm::BasicBlock::Create(context, "trycatch.finally.continue", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.continue", func); auto *finallyContinueDone = - llvm::BasicBlock::Create(context, "trycatch.finally.continue.done", func); + llvm::BasicBlock::Create(*context, "trycatch.finally.continue.done", func); builder.SetInsertPoint(finallyBreak); auto *breakSwitch = @@ -1740,7 +1746,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { llvm::BasicBlock *catchAll = nullptr; for (auto *c : catches) { - auto *catchBlock = llvm::BasicBlock::Create(context, "trycatch.catch", func); + auto *catchBlock = llvm::BasicBlock::Create(*context, "trycatch.catch", func); tc.catchTypes.push_back(c->getType()); tc.handlers.push_back(catchBlock); @@ -1784,7 +1790,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { if (!it->catchTypes[i] && !catchAll) { // catch-all is in parent; set finally depth catchAll = - llvm::BasicBlock::Create(context, "trycatch.fdepth_catchall", func); + llvm::BasicBlock::Create(*context, "trycatch.fdepth_catchall", func); builder.SetInsertPoint(catchAll); builder.CreateStore(builder.getInt64(depth), tc.delegateDepth); builder.CreateBr(it->handlers[i]); @@ -1846,7 +1852,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { llvm::Value *objPtr = builder.CreateExtractValue(loadedExc, 1); // set depth when catch-all entered - auto *defaultRouteBlock = llvm::BasicBlock::Create(context, "trycatch.fdepth", func); + auto *defaultRouteBlock = llvm::BasicBlock::Create(*context, "trycatch.fdepth", func); builder.SetInsertPoint(defaultRouteBlock); if (catchAll) builder.CreateStore(builder.getInt64(catchAllDepth), tc.delegateDepth); @@ -1858,7 +1864,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { builder.CreateSwitch(objType, defaultRouteBlock, (unsigned)handlersFull.size()); for (unsigned i = 0; i < handlersFull.size(); i++) { // set finally depth - auto *depthSet = llvm::BasicBlock::Create(context, "trycatch.fdepth", func); + auto *depthSet = llvm::BasicBlock::Create(*context, "trycatch.fdepth", func); builder.SetInsertPoint(depthSet); builder.CreateStore(builder.getInt64(depths[i]), tc.delegateDepth); builder.CreateBr((i < tc.handlers.size()) ? handlersFull[i] : tc.finallyBlock); @@ -1932,10 +1938,10 @@ void LLVMVisitor::codegenPipeline( seqassert(generatorType, "{} is not a generator type", *prevStage->getOutputType()); auto *baseType = getLLVMType(generatorType->getBase()); - auto *condBlock = llvm::BasicBlock::Create(context, "pipeline.cond", func); - auto *bodyBlock = llvm::BasicBlock::Create(context, "pipeline.body", func); - auto *cleanupBlock = llvm::BasicBlock::Create(context, "pipeline.cleanup", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "pipeline.exit", func); + auto *condBlock = llvm::BasicBlock::Create(*context, "pipeline.cond", func); + auto *bodyBlock = llvm::BasicBlock::Create(*context, "pipeline.body", func); + auto *cleanupBlock = llvm::BasicBlock::Create(*context, "pipeline.cleanup", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "pipeline.exit", func); // LLVM coroutine intrinsics // https://prereleases.llvm.org/6.0.0/rc3/docs/Coroutines.html @@ -2081,11 +2087,11 @@ void LLVMVisitor::visit(const YieldInInstr *x) { if (x->isSuspending()) { llvm::FunctionCallee coroSuspend = llvm::Intrinsic::getDeclaration(module.get(), llvm::Intrinsic::coro_suspend); - llvm::Value *tok = llvm::ConstantTokenNone::get(context); + llvm::Value *tok = llvm::ConstantTokenNone::get(*context); llvm::Value *final = builder.getFalse(); llvm::Value *susp = builder.CreateCall(coroSuspend, {tok, final}); - block = llvm::BasicBlock::Create(context, "yieldin.new", func); + block = llvm::BasicBlock::Create(*context, "yieldin.new", func); llvm::SwitchInst *inst = builder.CreateSwitch(susp, coro.suspend, 2); inst->addCase(builder.getInt8(0), block); inst->addCase(builder.getInt8(1), coro.cleanup); @@ -2108,9 +2114,9 @@ void LLVMVisitor::visit(const StackAllocInstr *x) { } void LLVMVisitor::visit(const TernaryInstr *x) { - auto *trueBlock = llvm::BasicBlock::Create(context, "ternary.true", func); - auto *falseBlock = llvm::BasicBlock::Create(context, "ternary.false", func); - auto *exitBlock = llvm::BasicBlock::Create(context, "ternary.exit", func); + auto *trueBlock = llvm::BasicBlock::Create(*context, "ternary.true", func); + auto *falseBlock = llvm::BasicBlock::Create(*context, "ternary.false", func); + auto *exitBlock = llvm::BasicBlock::Create(*context, "ternary.exit", func); llvm::Type *valueType = getLLVMType(x->getType()); process(x->getCond()); @@ -2158,7 +2164,7 @@ void LLVMVisitor::visit(const BreakInstr *x) { builder.CreateBr(tc->finallyBlock); } - block = llvm::BasicBlock::Create(context, "break.new", func); + block = llvm::BasicBlock::Create(*context, "break.new", func); } void LLVMVisitor::visit(const ContinueInstr *x) { @@ -2176,7 +2182,7 @@ void LLVMVisitor::visit(const ContinueInstr *x) { builder.CreateBr(tc->finallyBlock); } - block = llvm::BasicBlock::Create(context, "continue.new", func); + block = llvm::BasicBlock::Create(*context, "continue.new", func); } void LLVMVisitor::visit(const ReturnInstr *x) { @@ -2209,7 +2215,7 @@ void LLVMVisitor::visit(const ReturnInstr *x) { } } } - block = llvm::BasicBlock::Create(context, "return.new", func); + block = llvm::BasicBlock::Create(*context, "return.new", func); } void LLVMVisitor::visit(const YieldInstr *x) { @@ -2228,7 +2234,7 @@ void LLVMVisitor::visit(const YieldInstr *x) { } else { builder.CreateBr(coro.exit); } - block = llvm::BasicBlock::Create(context, "yield.new", func); + block = llvm::BasicBlock::Create(*context, "yield.new", func); } else { if (x->getValue()) { process(x->getValue()); diff --git a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index 74f59a87..a4c10e98 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -7,6 +7,7 @@ #include #include +#include #include namespace codon { @@ -109,11 +110,11 @@ private: }; /// LLVM context used for compilation - llvm::LLVMContext context; - /// LLVM IR builder used for constructing LLVM IR - llvm::IRBuilder<> builder; + std::unique_ptr context; /// Module we are compiling std::unique_ptr module; + /// LLVM IR builder used for constructing LLVM IR + llvm::IRBuilder<> builder; /// Current function we are compiling llvm::Function *func; /// Current basic block we are compiling @@ -180,7 +181,7 @@ private: llvm::Value *getVar(const Var *var); llvm::Function *getFunc(const Func *func); - llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(context); } + llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(*context); } public: static std::string getNameForFunction(const Func *x) { @@ -210,9 +211,14 @@ public: } } - LLVMVisitor(bool debug = false, bool jit = false, const std::string &flags = ""); + /// Constructs an LLVM visitor. + /// @param debug whether to compile in debug mode + /// @param jit whether to compile in JIT mode + /// @param flags command-line flags to be included in debug info + explicit LLVMVisitor(bool debug = false, bool jit = false, + const std::string &flags = ""); - llvm::LLVMContext &getContext() { return context; } + llvm::LLVMContext &getContext() { return *context; } llvm::IRBuilder<> &getBuilder() { return builder; } llvm::Module *getModule() { return module.get(); } llvm::FunctionCallee getFunc() { return func; } @@ -236,16 +242,19 @@ public: /// Returns a new LLVM module initialized for the host /// architecture. + /// @param context LLVM context used for creating module /// @param src source information for the new module /// @return a new module - std::unique_ptr makeModule(const SrcInfo *src = nullptr); + std::unique_ptr makeModule(llvm::LLVMContext &context, + const SrcInfo *src = nullptr); - /// Returns the current LLVM module and replaces it with a - /// new, fresh one. References to variables or functions + /// Returns the current LLVM context/module and replaces them + /// with new, fresh ones. References to variables or functions /// from the old module will be included as "external". /// @param src source information for the new module - /// @return the current module, replaced internally - std::unique_ptr takeModule(const SrcInfo *src = nullptr); + /// @return the current context/module, replaced internally + std::pair, std::unique_ptr> + takeModule(const SrcInfo *src = nullptr); /// Sets current debug info based on a given node. /// @param node the node whose debug info to use