1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

JIT argument order fix (#639)

* Fix argument ordering in JIT

* Format

* Update JIT tests

* Fix JIT test
This commit is contained in:
A. R. Shajii 2025-03-18 10:45:34 -04:00 committed by GitHub
parent b3f6c12d57
commit 93fb3d53e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 59 deletions

View File

@ -1,5 +1,7 @@
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
__all__ = ["jit", "convert", "JITError"]
__all__ = [
"jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit"
]
from .decorator import jit, convert, execute, JITError
from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit

View File

@ -23,16 +23,14 @@ if "CODON_PATH" not in os.environ:
if codon_lib_path:
codon_path.append(Path(codon_lib_path).parent / "stdlib")
codon_path.append(
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
)
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib")
for path in codon_path:
if path.exists():
os.environ["CODON_PATH"] = str(path.resolve())
break
else:
raise RuntimeError(
"Cannot locate Codon. Please install Codon or set CODON_PATH."
)
"Cannot locate Codon. Please install Codon or set CODON_PATH.")
pod_conversions = {
type(None): "pyobj",
@ -61,7 +59,6 @@ pod_conversions = {
custom_conversions = {}
_error_msgs = set()
def _common_type(t, debug, sample_size):
sub, is_optional = None, False
for i in itertools.islice(t, sample_size):
@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
sub = "Optional[{}]".format(sub)
return sub if sub else "pyobj"
def _codon_type(arg, **kwargs):
t = type(arg)
@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
if issubclass(t, set):
return "Set[{}]".format(_common_type(arg, **kwargs))
if issubclass(t, dict):
return "Dict[{},{}]".format(
_common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs)
)
return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs),
_common_type(arg.values(), **kwargs))
if issubclass(t, tuple):
return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg))
return "Tuple[{}]".format(",".join(
_codon_type(a, **kwargs) for a in arg))
if issubclass(t, np.ndarray):
if arg.dtype == np.bool_:
dtype = "bool"
@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
s = custom_conversions.get(t, "")
if s:
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
j = ",".join(
_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
return "{}[{}]".format(s, j)
debug = kwargs.get("debug", None)
@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
_error_msgs.add(msg)
return "pyobj"
def _codon_types(args, **kwargs):
return tuple(_codon_type(arg, **kwargs) for arg in args)
def _reset_jit():
global _jit
_jit = JITWrapper()
init_code = (
"from internal.python import "
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
"setup_decorator()\n"
"import numpy as np\n"
"import numpy.pybridge\n"
)
init_code = ("from internal.python import "
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
"setup_decorator()\n"
"import numpy as np\n"
"import numpy.pybridge\n")
_jit.execute(init_code, "", 0, False)
return _jit
_jit = _reset_jit()
class RewriteFunctionArgs(ast.NodeTransformer):
def __init__(self, args):
self.args = args
@ -176,7 +167,6 @@ class RewriteFunctionArgs(ast.NodeTransformer):
node.args.args.append(ast.arg(arg=a, annotation=None))
return node
def _obj_to_str(obj, **kwargs) -> str:
if inspect.isclass(obj):
lines = inspect.getsourcelines(obj)[0]
@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
obj_name = obj.__name__
elif callable(obj) or isinstance(obj, str):
is_str = isinstance(obj, str)
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0]
if not is_str: lines = lines[1:]
lines = [i + '\n' for i in obj.split('\n')
] if is_str else inspect.getsourcelines(obj)[0]
if not is_str:
lines = lines[1:]
obj_str = textwrap.dedent(''.join(lines))
pyvars = kwargs.get("pyvars", None)
@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
if not isinstance(i, str):
raise ValueError("pyvars only takes string literals")
node = ast.fix_missing_locations(
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
)
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)))
obj_str = astunparse.unparse(node)
if is_str:
try:
@ -206,28 +197,23 @@ def _obj_to_str(obj, **kwargs) -> str:
else:
obj_name = obj.__name__
else:
raise TypeError("Function or class expected, got " + type(obj).__name__)
raise TypeError("Function or class expected, got " +
type(obj).__name__)
return obj_name, obj_str.replace("_@par", "@par")
def _parse_decorated(obj, **kwargs):
return _obj_to_str(obj, **kwargs)
return _obj_to_str(obj, **kwargs)
def convert(t):
if not hasattr(t, "__slots__"):
raise JITError("class '{}' does not have '__slots__' attribute".format(str(t)))
raise JITError("class '{}' does not have '__slots__' attribute".format(
str(t)))
name = t.__name__
slots = t.__slots__
code = (
"@tuple\n"
"class "
+ name
+ "["
+ ",".join("T{}".format(i) for i in range(len(slots)))
+ "]:\n"
)
code = ("@tuple\n"
"class " + name + "[" +
",".join("T{}".format(i) for i in range(len(slots))) + "]:\n")
for i, slot in enumerate(slots):
code += " {}: T{}\n".format(slot, i)
@ -235,17 +221,14 @@ def convert(t):
code += " def __from_py__(p: cobj):\n"
for i, slot in enumerate(slots):
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format(
i, i, slot
)
i, i, slot)
code += " return {}({})\n".format(
name, ", ".join("a{}".format(i) for i in range(len(slots)))
)
name, ", ".join("a{}".format(i) for i in range(len(slots))))
_jit.execute(code, "", 0, False)
custom_conversions[t] = name
return t
def _jit_register_fn(f, pyvars, debug):
try:
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
_reset_jit()
raise
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
try:
def _jit_callback_fn(fn,
obj_name,
module,
debug=None,
sample_size=5,
pyvars=None,
*args,
**kwargs):
if fn is not None:
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
args = tuple(bound_args.arguments[param] for param in sig.parameters)
else:
args = (*args, *kwargs.values())
try:
types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug:
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
return _jit.run_wrapper(
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
)
print("[python] {}({})".format(obj_name, list(types)),
file=sys.stderr)
return _jit.run_wrapper(obj_name, list(types), module, list(pyvars),
args, 1 if debug else 0)
except JITError:
_reset_jit()
raise
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
obj_name = _jit_register_fn(fstr, pyvars, debug)
def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
return wrapped
def wrapped(*args, **kwargs):
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
pyvars, *args, **kwargs)
return wrapped
def jit(fn=None, debug=None, sample_size=5, pyvars=None):
if not pyvars:
pyvars = []
if not isinstance(pyvars, list):
raise ArgumentError("pyvars must be a list")
@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
def _decorate(f):
obj_name = _jit_register_fn(f, pyvars, debug)
@functools.wraps(f)
def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs)
return wrapped
return _decorate(fn) if fn else _decorate
return _jit_callback_fn(f, obj_name, f.__module__, debug,
sample_size, pyvars, *args, **kwargs)
return wrapped
return _decorate(fn) if fn else _decorate
def execute(code, debug=False):
try:

View File

@ -181,3 +181,16 @@ def test_ndarray():
assert np.datetime_data(y.dtype) == ('s', 2)
test_ndarray()
@codon.jit
def e(x=2, y=99):
return 2*x + y
def test_arg_order():
assert e(1, 2) == 4
assert e(1) == 101
assert e(y=10, x=1) == 12
assert e(x=1) == 101
assert e() == 103
test_arg_order()