1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Update JIT

This commit is contained in:
A. R. Shajii 2021-11-01 19:10:33 -04:00
parent 5503e71ace
commit ddc47d453c
6 changed files with 175 additions and 53 deletions

View File

@ -1,10 +1,13 @@
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include "codon/compiler/compiler.h"
#include "codon/compiler/jit.h"
#include "codon/parser/parser.h"
#include "codon/util/common.h"
#include "llvm/Support/CommandLine.h"
@ -49,7 +52,21 @@ enum OptMode { Debug, Release };
int docMode(const std::vector<const char *> &args, const std::string &argv0) {
llvm::cl::ParseCommandLineOptions(args.size(), args.data());
codon::generateDocstr(argv0);
std::vector<std::string> files;
for (std::string line; std::getline(std::cin, line);)
files.push_back(line);
auto compiler = std::make_unique<codon::Compiler>(args[0]);
std::string result;
if (auto err = compiler->docgen(files, &result)) {
for (auto &msg : err.messages) {
codon::compilationError(msg.msg, msg.file, msg.line, msg.col,
/*terminate=*/false);
}
return EXIT_FAILURE;
}
fmt::print("{}\n", result);
return EXIT_SUCCESS;
}
@ -148,7 +165,27 @@ int runMode(const std::vector<const char *> &args) {
return EXIT_SUCCESS;
}
int jitMode(const std::string &argv0) { return codon::jitLoop(argv0); }
int jitMode(const std::vector<const char *> &args) {
codon::jit::JIT jit(args[0]);
if (auto err = jit.init()) {
codon::compilationError("failed to initialize JIT: " + err.getMessage());
}
fmt::print(">>> Codon JIT v{} <<<\n", CODON_VERSION);
std::string code;
for (std::string line; std::getline(std::cin, line);) {
if (line != "#%%") {
code += line + "\n";
} else {
jit.exec(code);
code = "";
fmt::print("\n\n[done]\n\n");
fflush(stdout);
}
}
if (!code.empty())
jit.exec(code);
return EXIT_SUCCESS;
}
int buildMode(const std::vector<const char *> &args) {
llvm::cl::list<std::string> libs(
@ -260,7 +297,8 @@ int main(int argc, const char **argv) {
return docMode(args, oldArgv0);
}
if (mode == "jit") {
return jitMode(args[0]);
args[0] = argv0.data();
return jitMode(args);
}
return otherMode({argv, argv + argc});
}

View File

@ -13,11 +13,12 @@
namespace codon {
Compiler::Compiler(const std::string &argv0, bool debug,
const std::vector<std::string> &disabledPasses)
: debug(debug), input(), plm(std::make_unique<PluginManager>()),
const std::vector<std::string> &disabledPasses, bool isTest)
: argv0(argv0), debug(debug), input(), plm(std::make_unique<PluginManager>()),
cache(std::make_unique<ast::Cache>(argv0)),
module(std::make_unique<ir::Module>()),
pm(std::make_unique<ir::transform::PassManager>(debug, disabledPasses)),
pm(std::make_unique<ir::transform::PassManager>(debug && !isTest,
disabledPasses)),
llvisitor(std::make_unique<ir::LLVMVisitor>(debug)) {
cache->module = module.get();
module->setCache(cache.get());
@ -75,7 +76,7 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code,
t4.log();
} catch (const exc::ParserException &e) {
auto result = Compiler::ParserError::failure();
for (int i = 0; i < e.messages.size(); i++) {
for (unsigned i = 0; i < e.messages.size(); i++) {
if (!e.messages[i].empty())
result.messages.push_back({e.messages[i], e.locations[i].file,
e.locations[i].line, e.locations[i].col});
@ -105,4 +106,22 @@ void Compiler::compile() {
llvisitor->visit(module.get());
}
Compiler::ParserError Compiler::docgen(const std::vector<std::string> &files,
std::string *output) {
try {
auto j = ast::DocVisitor::apply(argv0, files);
if (output)
*output = j->toString();
} catch (exc::ParserException &e) {
auto result = Compiler::ParserError::failure();
for (unsigned i = 0; i < e.messages.size(); i++) {
if (!e.messages[i].empty())
result.messages.push_back({e.messages[i], e.locations[i].file,
e.locations[i].line, e.locations[i].col});
}
return result;
}
return Compiler::ParserError::success();
}
} // namespace codon

View File

@ -34,6 +34,7 @@ public:
};
private:
std::string argv0;
bool debug;
std::string input;
std::unique_ptr<PluginManager> plm;
@ -48,7 +49,7 @@ private:
public:
Compiler(const std::string &argv0, bool debug = false,
const std::vector<std::string> &disabledPasses = {});
const std::vector<std::string> &disabledPasses = {}, bool isTest = false);
std::string getInput() const { return input; }
PluginManager *getPluginManager() const { return plm.get(); }
@ -66,6 +67,7 @@ public:
int testFlags = 0,
const std::unordered_map<std::string, std::string> &defines = {});
void compile();
ParserError docgen(const std::vector<std::string> &files, std::string *output);
};
} // namespace codon

View File

@ -1,5 +1,11 @@
#include "jit.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/doc/doc.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include "codon/runtime/lib.h"
namespace codon {
@ -8,44 +14,54 @@ namespace {
typedef int MainFunc(int, char **);
typedef void InputFunc();
Status statusFromError(llvm::Error err) {
return Status(Status::Code::LLVM_ERROR, llvm::toString(std::move(err)));
Error fromLLVMError(llvm::Error err) {
return Error(Error::Code::LLVM_ERROR, llvm::toString(std::move(err)));
}
const std::string JIT_FILENAME = "<jit>";
} // namespace
const Status Status::OK = Status(Status::Code::SUCCESS);
const Error Error::NONE = Error(Error::Code::SUCCESS);
JIT::JIT(const std::string &argv0)
: cache(std::make_shared<ast::Cache>(argv0)), module(nullptr),
pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)),
plm(std::make_unique<PluginManager>()),
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true, /*jit=*/true)) {
: compiler(std::make_unique<Compiler>(argv0, /*debug=*/true)) {
if (auto e = Engine::create()) {
engine = std::move(e.get());
} else {
engine = {};
seqassert(false, "JIT engine creation error");
}
llvisitor->setPluginManager(plm.get());
}
Status JIT::init() {
Error JIT::init() {
auto *cache = compiler->getCache();
auto *module = compiler->getModule();
auto *llvisitor = compiler->getLLVMVisitor();
auto transformed = ast::SimplifyVisitor::apply(
cache, std::make_shared<ast::SuiteStmt>(), JIT_FILENAME, {});
auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed));
ast::TranslateVisitor::apply(cache, std::move(typechecked));
cache->isJit = true; // we still need main(), so set isJit after it has been set
module->setSrcInfo({JIT_FILENAME, 0, 0, 0});
module->accept(*llvisitor);
auto pair = llvisitor->takeModule();
if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)}))
return statusFromError(std::move(err));
return fromLLVMError(std::move(err));
auto func = engine->lookup("main");
if (auto err = func.takeError())
return statusFromError(std::move(err));
return fromLLVMError(std::move(err));
auto *main = (MainFunc *)func->getAddress();
(*main)(0, nullptr);
return Status::OK;
return Error::NONE;
}
Status JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
Error JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
auto *llvisitor = compiler->getLLVMVisitor();
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
llvisitor->registerGlobal(input);
for (auto *var : newGlobals) {
@ -60,20 +76,65 @@ Status JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals)
llvm::StripDebugInfo(*pair.first); // TODO: needed?
if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)}))
return statusFromError(std::move(err));
return fromLLVMError(std::move(err));
auto func = engine->lookup(name);
if (auto err = func.takeError())
return statusFromError(std::move(err));
return fromLLVMError(std::move(err));
auto *repl = (InputFunc *)func->getAddress();
try {
(*repl)();
} catch (const seq_jit_error &err) {
return Status(Status::Code::RUNTIME_ERROR, err.what(), err.getType(),
SrcInfo(err.getFile(), err.getLine(), err.getCol(), /*len=*/0));
return Error(Error::Code::RUNTIME_ERROR, err.what(), err.getType(),
SrcInfo(err.getFile(), err.getLine(), err.getCol(), /*len=*/0));
}
return Status::OK;
return Error::NONE;
}
Error JIT::exec(const std::string &code) {
auto *cache = compiler->getCache();
ast::StmtPtr node = ast::parseCode(cache, JIT_FILENAME, code, /*startLine=*/0);
auto sctx = cache->imports[MAIN_IMPORT].ctx;
auto preamble = std::make_shared<ast::SimplifyVisitor::Preamble>();
auto s = ast::SimplifyVisitor(sctx, preamble).transform(node);
auto simplified = std::make_shared<ast::SuiteStmt>();
for (auto &s : preamble->globals)
simplified->stmts.push_back(s);
for (auto &s : preamble->functions)
simplified->stmts.push_back(s);
simplified->stmts.push_back(s);
// TODO: unroll on errors...
auto typechecked = ast::TypecheckVisitor::apply(cache, simplified);
std::vector<std::string> globalNames;
for (auto &g : cache->globals) {
if (!g.second)
globalNames.push_back(g.first);
}
// add newly realized functions
std::vector<ast::StmtPtr> v;
std::vector<ir::Func **> frs;
for (auto &p : cache->pendingRealizations) {
v.push_back(cache->functions[p.first].ast);
frs.push_back(&cache->functions[p.first].realizations[p.second]->ir);
}
v.push_back(typechecked);
auto func = ast::TranslateVisitor::apply(cache, std::make_shared<ast::SuiteStmt>(v));
cache->jitCell++;
std::vector<ir::Var *> globalVars;
for (auto &g : globalNames) {
seqassert(cache->globals[g], "JIT global {} not set", g);
globalVars.push_back(cache->globals[g]);
std::cout << g << std::endl;
}
for (auto &i : frs) {
seqassert(*i, "JIT fn not set");
globalVars.push_back(*i);
}
return run(func, globalVars);
}
} // namespace jit

View File

@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include "codon/compiler/compiler.h"
#include "codon/compiler/engine.h"
#include "codon/parser/cache.h"
#include "codon/sir/llvm/llvisitor.h"
@ -13,7 +14,7 @@
namespace codon {
namespace jit {
class Status {
class Error {
public:
enum Code {
SUCCESS = 0,
@ -29,33 +30,30 @@ private:
SrcInfo src;
public:
explicit Status(Code code = Code::SUCCESS, const std::string &message = "",
const std::string &type = "", const SrcInfo &src = {})
explicit Error(Code code = Code::SUCCESS, const std::string &message = "",
const std::string &type = "", const SrcInfo &src = {})
: code(code), message(message), type(type), src(src) {}
operator bool() const { return code == Code::SUCCESS; }
operator bool() const { return code != Code::SUCCESS; }
Code getCode() const { return code; }
std::string getType() const { return type; }
std::string getMessage() const { return message; }
SrcInfo getSrcInfo() const { return src; }
static const Status OK;
static const Error NONE;
};
class JIT {
private:
std::shared_ptr<ast::Cache> cache;
ir::Module *module;
std::unique_ptr<ir::transform::PassManager> pm;
std::unique_ptr<PluginManager> plm;
std::unique_ptr<ir::LLVMVisitor> llvisitor;
std::unique_ptr<Compiler> compiler;
std::unique_ptr<Engine> engine;
public:
explicit JIT(const std::string &argv0);
Status init();
Status run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals = {});
Error init();
Error run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals = {});
Error exec(const std::string &code);
};
} // namespace jit

View File

@ -1,5 +1,6 @@
#include <algorithm>
#include <dirent.h>
#include <cstdio>
#include <fcntl.h>
#include <fstream>
#include <gc.h>
@ -13,14 +14,12 @@
#include <vector>
#include "codon/parser/common.h"
#include "codon/parser/parser.h"
#include "codon/sir/llvm/llvisitor.h"
#include "codon/sir/transform/manager.h"
#include "codon/sir/transform/pass.h"
#include "codon/compiler/compiler.h"
#include "codon/sir/util/inlining.h"
#include "codon/sir/util/irtools.h"
#include "codon/sir/util/outlining.h"
#include "codon/util/common.h"
#include "gtest/gtest.h"
using namespace codon;
@ -178,22 +177,27 @@ public:
close(out_pipe[1]);
auto file = getFilename(get<0>(GetParam()));
bool debug = get<1>(GetParam());
auto code = get<3>(GetParam());
auto startLine = get<4>(GetParam());
auto *module = parse(argv0, file, code, !code.empty(),
/* isTest */ 1 + get<5>(GetParam()), startLine);
if (!module)
int testFlags = 1 + get<5>(GetParam());
auto compiler = std::make_unique<Compiler>(argv0, debug, /*disabledPasses=*/std::vector<std::string>{}, /*isTest=*/true);
if (auto err = code.empty() ? compiler->parseFile(file, testFlags) : compiler->parseCode(file, code, startLine, testFlags)) {
for (auto &msg : err.messages) {
getLogger().level = 0;
printf("%s\n", msg.msg.c_str());
}
fflush(stdout);
exit(EXIT_FAILURE);
}
ir::transform::PassManager pm;
pm.registerPass(std::make_unique<TestOutliner>());
pm.registerPass(std::make_unique<TestInliner>());
pm.run(module);
ir::LLVMVisitor visitor(/*debug=*/get<1>(GetParam()));
visitor.visit(module);
visitor.run({file});
auto *pm = compiler->getPassManager();
pm->registerPass(std::make_unique<TestOutliner>());
pm->registerPass(std::make_unique<TestInliner>());
compiler->compile();
compiler->getLLVMVisitor()->run({file});
fflush(stdout);
exit(EXIT_SUCCESS);
} else {