From ddc47d453ca80104fc451ef63930a195ddda5224 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Mon, 1 Nov 2021 19:10:33 -0400 Subject: [PATCH] Update JIT --- codon/app/main.cpp | 44 +++++++++++++++-- codon/compiler/compiler.cpp | 27 +++++++++-- codon/compiler/compiler.h | 4 +- codon/compiler/jit.cpp | 97 ++++++++++++++++++++++++++++++------- codon/compiler/jit.h | 22 ++++----- test/main.cpp | 34 +++++++------ 6 files changed, 175 insertions(+), 53 deletions(-) diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 8d6525a2..8019fcc5 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -1,10 +1,13 @@ #include +#include #include +#include #include #include #include #include "codon/compiler/compiler.h" +#include "codon/compiler/jit.h" #include "codon/parser/parser.h" #include "codon/util/common.h" #include "llvm/Support/CommandLine.h" @@ -49,7 +52,21 @@ enum OptMode { Debug, Release }; int docMode(const std::vector &args, const std::string &argv0) { llvm::cl::ParseCommandLineOptions(args.size(), args.data()); - codon::generateDocstr(argv0); + std::vector files; + for (std::string line; std::getline(std::cin, line);) + files.push_back(line); + + auto compiler = std::make_unique(args[0]); + std::string result; + if (auto err = compiler->docgen(files, &result)) { + for (auto &msg : err.messages) { + codon::compilationError(msg.msg, msg.file, msg.line, msg.col, + /*terminate=*/false); + } + return EXIT_FAILURE; + } + + fmt::print("{}\n", result); return EXIT_SUCCESS; } @@ -148,7 +165,27 @@ int runMode(const std::vector &args) { return EXIT_SUCCESS; } -int jitMode(const std::string &argv0) { return codon::jitLoop(argv0); } +int jitMode(const std::vector &args) { + codon::jit::JIT jit(args[0]); + if (auto err = jit.init()) { + codon::compilationError("failed to initialize JIT: " + err.getMessage()); + } + fmt::print(">>> Codon JIT v{} <<<\n", CODON_VERSION); + std::string code; + for (std::string line; std::getline(std::cin, line);) { + if (line != "#%%") { + code += line + "\n"; + } else { + jit.exec(code); + code = ""; + fmt::print("\n\n[done]\n\n"); + fflush(stdout); + } + } + if (!code.empty()) + jit.exec(code); + return EXIT_SUCCESS; +} int buildMode(const std::vector &args) { llvm::cl::list libs( @@ -260,7 +297,8 @@ int main(int argc, const char **argv) { return docMode(args, oldArgv0); } if (mode == "jit") { - return jitMode(args[0]); + args[0] = argv0.data(); + return jitMode(args); } return otherMode({argv, argv + argc}); } diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 4b2a0899..440a7b7b 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -13,11 +13,12 @@ namespace codon { Compiler::Compiler(const std::string &argv0, bool debug, - const std::vector &disabledPasses) - : debug(debug), input(), plm(std::make_unique()), + const std::vector &disabledPasses, bool isTest) + : argv0(argv0), debug(debug), input(), plm(std::make_unique()), cache(std::make_unique(argv0)), module(std::make_unique()), - pm(std::make_unique(debug, disabledPasses)), + pm(std::make_unique(debug && !isTest, + disabledPasses)), llvisitor(std::make_unique(debug)) { cache->module = module.get(); module->setCache(cache.get()); @@ -75,7 +76,7 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, t4.log(); } catch (const exc::ParserException &e) { auto result = Compiler::ParserError::failure(); - for (int i = 0; i < e.messages.size(); i++) { + for (unsigned i = 0; i < e.messages.size(); i++) { if (!e.messages[i].empty()) result.messages.push_back({e.messages[i], e.locations[i].file, e.locations[i].line, e.locations[i].col}); @@ -105,4 +106,22 @@ void Compiler::compile() { llvisitor->visit(module.get()); } +Compiler::ParserError Compiler::docgen(const std::vector &files, + std::string *output) { + try { + auto j = ast::DocVisitor::apply(argv0, files); + if (output) + *output = j->toString(); + } catch (exc::ParserException &e) { + auto result = Compiler::ParserError::failure(); + for (unsigned i = 0; i < e.messages.size(); i++) { + if (!e.messages[i].empty()) + result.messages.push_back({e.messages[i], e.locations[i].file, + e.locations[i].line, e.locations[i].col}); + } + return result; + } + return Compiler::ParserError::success(); +} + } // namespace codon diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index 5ca68dea..94e20737 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -34,6 +34,7 @@ public: }; private: + std::string argv0; bool debug; std::string input; std::unique_ptr plm; @@ -48,7 +49,7 @@ private: public: Compiler(const std::string &argv0, bool debug = false, - const std::vector &disabledPasses = {}); + const std::vector &disabledPasses = {}, bool isTest = false); std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); } @@ -66,6 +67,7 @@ public: int testFlags = 0, const std::unordered_map &defines = {}); void compile(); + ParserError docgen(const std::vector &files, std::string *output); }; } // namespace codon diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index 5e9696a8..0632b439 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -1,5 +1,11 @@ #include "jit.h" +#include "codon/parser/peg/peg.h" +#include "codon/parser/visitors/doc/doc.h" +#include "codon/parser/visitors/format/format.h" +#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/visitors/translate/translate.h" +#include "codon/parser/visitors/typecheck/typecheck.h" #include "codon/runtime/lib.h" namespace codon { @@ -8,44 +14,54 @@ namespace { typedef int MainFunc(int, char **); typedef void InputFunc(); -Status statusFromError(llvm::Error err) { - return Status(Status::Code::LLVM_ERROR, llvm::toString(std::move(err))); +Error fromLLVMError(llvm::Error err) { + return Error(Error::Code::LLVM_ERROR, llvm::toString(std::move(err))); } + +const std::string JIT_FILENAME = ""; } // namespace -const Status Status::OK = Status(Status::Code::SUCCESS); +const Error Error::NONE = Error(Error::Code::SUCCESS); JIT::JIT(const std::string &argv0) - : cache(std::make_shared(argv0)), module(nullptr), - pm(std::make_unique(/*debug=*/true)), - plm(std::make_unique()), - llvisitor(std::make_unique(/*debug=*/true, /*jit=*/true)) { + : compiler(std::make_unique(argv0, /*debug=*/true)) { if (auto e = Engine::create()) { engine = std::move(e.get()); } else { engine = {}; seqassert(false, "JIT engine creation error"); } - llvisitor->setPluginManager(plm.get()); } -Status JIT::init() { +Error JIT::init() { + auto *cache = compiler->getCache(); + auto *module = compiler->getModule(); + auto *llvisitor = compiler->getLLVMVisitor(); + + auto transformed = ast::SimplifyVisitor::apply( + cache, std::make_shared(), JIT_FILENAME, {}); + auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed)); + ast::TranslateVisitor::apply(cache, std::move(typechecked)); + cache->isJit = true; // we still need main(), so set isJit after it has been set + module->setSrcInfo({JIT_FILENAME, 0, 0, 0}); + module->accept(*llvisitor); auto pair = llvisitor->takeModule(); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) - return statusFromError(std::move(err)); + return fromLLVMError(std::move(err)); auto func = engine->lookup("main"); if (auto err = func.takeError()) - return statusFromError(std::move(err)); + return fromLLVMError(std::move(err)); auto *main = (MainFunc *)func->getAddress(); (*main)(0, nullptr); - return Status::OK; + return Error::NONE; } -Status JIT::run(const ir::Func *input, const std::vector &newGlobals) { +Error JIT::run(const ir::Func *input, const std::vector &newGlobals) { + auto *llvisitor = compiler->getLLVMVisitor(); const std::string name = ir::LLVMVisitor::getNameForFunction(input); llvisitor->registerGlobal(input); for (auto *var : newGlobals) { @@ -60,20 +76,65 @@ Status JIT::run(const ir::Func *input, const std::vector &newGlobals) llvm::StripDebugInfo(*pair.first); // TODO: needed? if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) - return statusFromError(std::move(err)); + return fromLLVMError(std::move(err)); auto func = engine->lookup(name); if (auto err = func.takeError()) - return statusFromError(std::move(err)); + return fromLLVMError(std::move(err)); auto *repl = (InputFunc *)func->getAddress(); try { (*repl)(); } catch (const seq_jit_error &err) { - return Status(Status::Code::RUNTIME_ERROR, err.what(), err.getType(), - SrcInfo(err.getFile(), err.getLine(), err.getCol(), /*len=*/0)); + return Error(Error::Code::RUNTIME_ERROR, err.what(), err.getType(), + SrcInfo(err.getFile(), err.getLine(), err.getCol(), /*len=*/0)); } - return Status::OK; + return Error::NONE; +} + +Error JIT::exec(const std::string &code) { + auto *cache = compiler->getCache(); + ast::StmtPtr node = ast::parseCode(cache, JIT_FILENAME, code, /*startLine=*/0); + + auto sctx = cache->imports[MAIN_IMPORT].ctx; + auto preamble = std::make_shared(); + auto s = ast::SimplifyVisitor(sctx, preamble).transform(node); + auto simplified = std::make_shared(); + for (auto &s : preamble->globals) + simplified->stmts.push_back(s); + for (auto &s : preamble->functions) + simplified->stmts.push_back(s); + simplified->stmts.push_back(s); + // TODO: unroll on errors... + + auto typechecked = ast::TypecheckVisitor::apply(cache, simplified); + std::vector globalNames; + for (auto &g : cache->globals) { + if (!g.second) + globalNames.push_back(g.first); + } + // add newly realized functions + std::vector v; + std::vector frs; + for (auto &p : cache->pendingRealizations) { + v.push_back(cache->functions[p.first].ast); + frs.push_back(&cache->functions[p.first].realizations[p.second]->ir); + } + v.push_back(typechecked); + auto func = ast::TranslateVisitor::apply(cache, std::make_shared(v)); + cache->jitCell++; + + std::vector globalVars; + for (auto &g : globalNames) { + seqassert(cache->globals[g], "JIT global {} not set", g); + globalVars.push_back(cache->globals[g]); + std::cout << g << std::endl; + } + for (auto &i : frs) { + seqassert(*i, "JIT fn not set"); + globalVars.push_back(*i); + } + return run(func, globalVars); } } // namespace jit diff --git a/codon/compiler/jit.h b/codon/compiler/jit.h index 65607e4b..4b34bf6a 100644 --- a/codon/compiler/jit.h +++ b/codon/compiler/jit.h @@ -4,6 +4,7 @@ #include #include +#include "codon/compiler/compiler.h" #include "codon/compiler/engine.h" #include "codon/parser/cache.h" #include "codon/sir/llvm/llvisitor.h" @@ -13,7 +14,7 @@ namespace codon { namespace jit { -class Status { +class Error { public: enum Code { SUCCESS = 0, @@ -29,33 +30,30 @@ private: SrcInfo src; public: - explicit Status(Code code = Code::SUCCESS, const std::string &message = "", - const std::string &type = "", const SrcInfo &src = {}) + explicit Error(Code code = Code::SUCCESS, const std::string &message = "", + const std::string &type = "", const SrcInfo &src = {}) : code(code), message(message), type(type), src(src) {} - operator bool() const { return code == Code::SUCCESS; } + operator bool() const { return code != Code::SUCCESS; } Code getCode() const { return code; } std::string getType() const { return type; } std::string getMessage() const { return message; } SrcInfo getSrcInfo() const { return src; } - static const Status OK; + static const Error NONE; }; class JIT { private: - std::shared_ptr cache; - ir::Module *module; - std::unique_ptr pm; - std::unique_ptr plm; - std::unique_ptr llvisitor; + std::unique_ptr compiler; std::unique_ptr engine; public: explicit JIT(const std::string &argv0); - Status init(); - Status run(const ir::Func *input, const std::vector &newGlobals = {}); + Error init(); + Error run(const ir::Func *input, const std::vector &newGlobals = {}); + Error exec(const std::string &code); }; } // namespace jit diff --git a/test/main.cpp b/test/main.cpp index d757503c..4ba0e6c8 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -13,14 +14,12 @@ #include #include "codon/parser/common.h" -#include "codon/parser/parser.h" -#include "codon/sir/llvm/llvisitor.h" -#include "codon/sir/transform/manager.h" -#include "codon/sir/transform/pass.h" +#include "codon/compiler/compiler.h" #include "codon/sir/util/inlining.h" #include "codon/sir/util/irtools.h" #include "codon/sir/util/outlining.h" #include "codon/util/common.h" + #include "gtest/gtest.h" using namespace codon; @@ -178,22 +177,27 @@ public: close(out_pipe[1]); auto file = getFilename(get<0>(GetParam())); + bool debug = get<1>(GetParam()); auto code = get<3>(GetParam()); auto startLine = get<4>(GetParam()); - auto *module = parse(argv0, file, code, !code.empty(), - /* isTest */ 1 + get<5>(GetParam()), startLine); - if (!module) + int testFlags = 1 + get<5>(GetParam()); + + auto compiler = std::make_unique(argv0, debug, /*disabledPasses=*/std::vector{}, /*isTest=*/true); + if (auto err = code.empty() ? compiler->parseFile(file, testFlags) : compiler->parseCode(file, code, startLine, testFlags)) { + for (auto &msg : err.messages) { + getLogger().level = 0; + printf("%s\n", msg.msg.c_str()); + } + fflush(stdout); exit(EXIT_FAILURE); + } - ir::transform::PassManager pm; - pm.registerPass(std::make_unique()); - pm.registerPass(std::make_unique()); - pm.run(module); - - ir::LLVMVisitor visitor(/*debug=*/get<1>(GetParam())); - visitor.visit(module); - visitor.run({file}); + auto *pm = compiler->getPassManager(); + pm->registerPass(std::make_unique()); + pm->registerPass(std::make_unique()); + compiler->compile(); + compiler->getLLVMVisitor()->run({file}); fflush(stdout); exit(EXIT_SUCCESS); } else {