Fix JIT; Fix #136

typecheck-v2
Ibrahim Numanagić 2024-11-05 12:31:44 -08:00
parent b7768ea688
commit a2c5219570
5 changed files with 59 additions and 24 deletions

View File

@ -27,6 +27,12 @@ types::Type *Instr::doGetType() const { return getModule()->getNoneType(); }
const char AssignInstr::NodeId = 0;
AssignInstr::AssignInstr(Var *lhs, Value *rhs, std::string name)
: AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {
if (!lhs->getType())
LOG("->");
}
int AssignInstr::doReplaceUsedValue(id_t id, Value *newValue) {
if (rhs->getId() == id) {
rhs = newValue;

View File

@ -42,8 +42,7 @@ public:
/// @param rhs the right-hand side
/// @param field the field being set, may be empty
/// @param name the instruction's name
AssignInstr(Var *lhs, Value *rhs, std::string name = "")
: AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {}
AssignInstr(Var *lhs, Value *rhs, std::string name = "");
/// @return the left-hand side
Var *getLhs() { return lhs; }

View File

@ -93,8 +93,10 @@ llvm::Expected<ir::Func *> JIT::compile(const std::string &code,
ast::Stmt *node = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code,
/*startLine=*/line);
ast::Stmt **e = &node;
while (auto se = ast::cast<ast::SuiteStmt>(*e))
while (auto se = ast::cast<ast::SuiteStmt>(*e)) {
if (se->empty()) break;
e = &se->back();
}
if (e)
if (auto ex = ast::cast<ast::ExprStmt>(*e)) {
*e = cache->N<ast::ExprStmt>(cache->N<ast::CallExpr>(

View File

@ -56,9 +56,12 @@ ir::Func *TranslateVisitor::apply(Cache *cache, Stmt *stmts) {
void TranslateVisitor::translateStmts(Stmt *stmts) {
for (auto &[name, g] : ctx->cache->globals)
if (/*g.first &&*/ !g.second) {
ir::types::Type *vt = nullptr;
if (auto t = ctx->cache->typeCtx->forceFind(name)->getType())
vt = getType(t);
g.second = name == VAR_ARGV ? ctx->cache->codegenCtx->getModule()->getArgVar()
: ctx->cache->codegenCtx->getModule()->N<ir::Var>(
SrcInfo(), nullptr, true, false, name);
SrcInfo(), vt, true, false, name);
ctx->cache->codegenCtx->add(TranslateItem::Var, name, g.second);
}
TranslateVisitor(ctx->cache->codegenCtx).transform(stmts);
@ -432,20 +435,25 @@ void TranslateVisitor::visit(AssignStmt *stmt) {
return;
auto lei = cast<IdExpr>(stmt->getLhs());
if (stmt->isUpdate()) {
seqassert(lei, "expected IdExpr, got {}", *(stmt->getLhs()));
auto val = ctx->find(lei->getValue());
seqassert(val && val->getVar(), "{} is not a variable", lei->getValue());
result = make<ir::AssignInstr>(stmt, val->getVar(), transform(stmt->getRhs()));
return;
}
seqassert(lei, "expected IdExpr, got {}", *(stmt->getLhs()));
auto var = lei->getValue();
auto isGlobal = in(ctx->cache->globals, var);
ir::Var *v = nullptr;
if (stmt->isUpdate()) {
auto val = ctx->find(lei->getValue());
seqassert(val && val->getVar(), "{} is not a variable", lei->getValue());
v = val->getVar();
if (!v->getType()) {
v->setSrcInfo(stmt->getSrcInfo());
v->setType(getType(stmt->getRhs()->getType()));
}
result = make<ir::AssignInstr>(stmt, v, transform(stmt->getRhs()));
return;
}
if (!stmt->getLhs()->getType()->isInstantiated() ||
(stmt->getLhs()->getType()->is(TYPE_TYPE))) {
// LOG("{} {}", getSrcInfo(), stmt->toString(0));
@ -473,8 +481,9 @@ void TranslateVisitor::visit(AssignStmt *stmt) {
return;
}
if (stmt->getRhs())
if (stmt->getRhs()) {
result = make<ir::AssignInstr>(stmt, v, transform(stmt->getRhs()));
}
}
void TranslateVisitor::visit(AssignMemberStmt *stmt) {

View File

@ -33,6 +33,8 @@ if "CODON_PATH" not in os.environ:
"Cannot locate Codon. Please install Codon or set CODON_PATH."
)
debug_override = int(os.environ.get("CODON_JIT_DEBUG", 0))
pod_conversions = {
type(None): "pyobj",
int: "int",
@ -83,8 +85,8 @@ def _codon_type(arg, **kwargs):
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
return "{}[{}]".format(s, j)
debug = kwargs.get("debug", None)
if debug:
debug = kwargs.get("debug", 0)
if debug > 0:
msg = "cannot convert " + t.__name__
if msg not in _error_msgs:
print("[python]", msg, file=sys.stderr)
@ -104,7 +106,9 @@ def _reset_jit():
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
"setup_decorator()\n"
)
_jit.execute(init_code, "", 0, False)
if debug_override == 2:
print(f"[jit_debug] execute:\n{init_code}", file=sys.stderr)
_jit.execute(init_code, "", 0, int(debug_override > 0))
return _jit
@ -185,7 +189,9 @@ def convert(t):
name, ", ".join("a{}".format(i) for i in range(len(slots)))
)
_jit.execute(code, "", 0, False)
if debug_override == 2:
print(f"[jit_debug] execute:\n{code}", file=sys.stderr)
_jit.execute(code, "", 0, int(debug_override > 0))
custom_conversions[t] = name
return t
@ -196,38 +202,45 @@ def _jit_register_fn(f, pyvars, debug):
fn, fl = "<internal>", 1
if hasattr(f, "__code__"):
fn, fl = f.__code__.co_filename, f.__code__.co_firstlineno
_jit.execute(obj_str, fn, fl, 1 if debug else 0)
if debug == 2:
print(f"[jit_debug] execute:\n{obj_str}", file=sys.stderr)
_jit.execute(obj_str, fn, fl, int(debug > 0))
return obj_name
except JITError:
_reset_jit()
raise
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
def _jit_callback_fn(obj_name, module, debug=0, sample_size=5, pyvars=None, *args, **kwargs):
try:
args = (*args, *kwargs.values())
types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug:
if debug > 0:
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
return _jit.run_wrapper(
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
obj_name, list(types), module, list(pyvars), args, int(debug > 0)
)
except JITError:
_reset_jit()
raise
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
def _jit_str_fn(fstr, debug=0, sample_size=5, pyvars=None):
obj_name = _jit_register_fn(fstr, pyvars, debug)
def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
return wrapped
def jit(fn=None, debug=None, sample_size=5, pyvars=None):
def jit(fn=None, debug=0, sample_size=5, pyvars=None):
if debug is None:
debug = 0
if not pyvars:
pyvars = []
if not isinstance(pyvars, list):
raise ArgumentError("pyvars must be a list")
if debug_override:
debug = debug_override
if fn and isinstance(fn, str):
return _jit_str_fn(fn, debug, sample_size, pyvars)
@ -240,8 +253,14 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
return _decorate(fn) if fn else _decorate
def execute(code, debug=False):
def execute(code, debug=0):
if debug is None:
debug = 0
if debug_override:
debug = debug_override
try:
if debug == 2:
print(f"[jit_debug] execute:\n{code}", file=sys.stderr)
_jit.execute(code, "<internal>", 0, int(debug))
except JITError:
_reset_jit()