@codon.jit fixes (#401)

* Fix #223

* Fix #188

* Fix #188

* Fix stray import

* Save pyvars

---------

Co-authored-by: A. R. Shajii <ars@ars.me>
pull/420/head
Ibrahim Numanagić 2023-06-09 12:38:49 -07:00 committed by GitHub
parent d1a8d1a79b
commit e95f778df1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 8 deletions

View File

@ -89,9 +89,6 @@ llvm::Error JIT::compile(const ir::Func *input) {
llvm::Expected<ir::Func *> JIT::compile(const std::string &code,
const std::string &file, int line) {
auto *cache = compiler->getCache();
ast::StmtPtr node = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code,
/*startLine=*/line);
auto sctx = cache->imports[MAIN_IMPORT].ctx;
auto preamble = std::make_shared<std::vector<ast::StmtPtr>>();
@ -101,6 +98,8 @@ llvm::Expected<ir::Func *> JIT::compile(const std::string &code,
ast::TypeContext bType = *(cache->typeCtx);
ast::TranslateContext bTranslate = *(cache->codegenCtx);
try {
ast::StmtPtr node = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code,
/*startLine=*/line);
auto *e = node->getSuite() ? node->getSuite()->lastInBlock() : &node;
if (e)
if (auto ex = const_cast<ast::ExprStmt *>((*e)->getExpr())) {
@ -250,8 +249,9 @@ std::string buildPythonWrapper(const std::string &name, const std::string &wrapn
wrap << "a" << i;
}
for (unsigned i = 0; i < pyVars.size(); i++) {
wrap << ", "
<< "py" << i;
if (i > 0 || types.size() > 0)
wrap << ", ";
wrap << "py" << i;
}
wrap << ").__to_py__()\n";

View File

@ -145,6 +145,10 @@ This also allows imported Python modules to be accessed by Codon. All `pyvars`
are passed as Python objects. Note that JIT'd functions can call each other
by default.
{% hint style="info" %}
`pyvars` takes in variable names as strings, not the variables themselves.
{% endhint %}
# Debugging
`@codon.jit` takes an optional `debug` parameter that can be used to print debug

View File

@ -8,7 +8,6 @@ import os
import functools
import itertools
import ast
import shutil
import astunparse
from pathlib import Path
@ -130,9 +129,13 @@ def _obj_to_str(obj, **kwargs) -> str:
lines = inspect.getsourcelines(obj)[0]
extra_spaces = lines[0].find("@")
obj_str = "".join(l[extra_spaces:] for l in lines[1:])
if kwargs.get("pyvars", None):
pyvars = kwargs.get("pyvars", None)
if pyvars:
for i in pyvars:
if not isinstance(i, str):
raise ValueError("pyvars only takes string literals")
node = ast.fix_missing_locations(
RewriteFunctionArgs(kwargs["pyvars"]).visit(ast.parse(obj_str))
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
)
obj_str = astunparse.unparse(node)
else: