diff --git a/jit/codon/__init__.py b/jit/codon/__init__.py index 568e4bf9..527e51c5 100644 --- a/jit/codon/__init__.py +++ b/jit/codon/__init__.py @@ -1,5 +1,7 @@ # Copyright (C) 2022-2025 Exaloop Inc. -__all__ = ["jit", "convert", "JITError"] +__all__ = [ + "jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit" +] -from .decorator import jit, convert, execute, JITError +from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit diff --git a/jit/codon/decorator.py b/jit/codon/decorator.py index 60e34fc9..dc0d1cc2 100644 --- a/jit/codon/decorator.py +++ b/jit/codon/decorator.py @@ -23,16 +23,14 @@ if "CODON_PATH" not in os.environ: if codon_lib_path: codon_path.append(Path(codon_lib_path).parent / "stdlib") codon_path.append( - Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib" - ) + Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib") for path in codon_path: if path.exists(): os.environ["CODON_PATH"] = str(path.resolve()) break else: raise RuntimeError( - "Cannot locate Codon. Please install Codon or set CODON_PATH." - ) + "Cannot locate Codon. Please install Codon or set CODON_PATH.") pod_conversions = { type(None): "pyobj", @@ -61,7 +59,6 @@ pod_conversions = { custom_conversions = {} _error_msgs = set() - def _common_type(t, debug, sample_size): sub, is_optional = None, False for i in itertools.islice(t, sample_size): @@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size): sub = "Optional[{}]".format(sub) return sub if sub else "pyobj" - def _codon_type(arg, **kwargs): t = type(arg) @@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs): if issubclass(t, set): return "Set[{}]".format(_common_type(arg, **kwargs)) if issubclass(t, dict): - return "Dict[{},{}]".format( - _common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs) - ) + return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs), + _common_type(arg.values(), **kwargs)) if issubclass(t, tuple): - return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg)) + return "Tuple[{}]".format(",".join( + _codon_type(a, **kwargs) for a in arg)) if issubclass(t, np.ndarray): if arg.dtype == np.bool_: dtype = "bool" @@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs): s = custom_conversions.get(t, "") if s: - j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__) + j = ",".join( + _codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__) return "{}[{}]".format(s, j) debug = kwargs.get("debug", None) @@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs): _error_msgs.add(msg) return "pyobj" - def _codon_types(args, **kwargs): return tuple(_codon_type(arg, **kwargs) for arg in args) - def _reset_jit(): global _jit _jit = JITWrapper() - init_code = ( - "from internal.python import " - "setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n" - "setup_decorator()\n" - "import numpy as np\n" - "import numpy.pybridge\n" - ) + init_code = ("from internal.python import " + "setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n" + "setup_decorator()\n" + "import numpy as np\n" + "import numpy.pybridge\n") _jit.execute(init_code, "", 0, False) return _jit - _jit = _reset_jit() - class RewriteFunctionArgs(ast.NodeTransformer): def __init__(self, args): self.args = args @@ -176,7 +167,6 @@ class RewriteFunctionArgs(ast.NodeTransformer): node.args.args.append(ast.arg(arg=a, annotation=None)) return node - def _obj_to_str(obj, **kwargs) -> str: if inspect.isclass(obj): lines = inspect.getsourcelines(obj)[0] @@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str: obj_name = obj.__name__ elif callable(obj) or isinstance(obj, str): is_str = isinstance(obj, str) - lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0] - if not is_str: lines = lines[1:] + lines = [i + '\n' for i in obj.split('\n') + ] if is_str else inspect.getsourcelines(obj)[0] + if not is_str: + lines = lines[1:] obj_str = textwrap.dedent(''.join(lines)) pyvars = kwargs.get("pyvars", None) @@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str: if not isinstance(i, str): raise ValueError("pyvars only takes string literals") node = ast.fix_missing_locations( - RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)) - ) + RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))) obj_str = astunparse.unparse(node) if is_str: try: @@ -206,28 +197,23 @@ def _obj_to_str(obj, **kwargs) -> str: else: obj_name = obj.__name__ else: - raise TypeError("Function or class expected, got " + type(obj).__name__) + raise TypeError("Function or class expected, got " + + type(obj).__name__) return obj_name, obj_str.replace("_@par", "@par") - def _parse_decorated(obj, **kwargs): - return _obj_to_str(obj, **kwargs) - + return _obj_to_str(obj, **kwargs) def convert(t): if not hasattr(t, "__slots__"): - raise JITError("class '{}' does not have '__slots__' attribute".format(str(t))) + raise JITError("class '{}' does not have '__slots__' attribute".format( + str(t))) name = t.__name__ slots = t.__slots__ - code = ( - "@tuple\n" - "class " - + name - + "[" - + ",".join("T{}".format(i) for i in range(len(slots))) - + "]:\n" - ) + code = ("@tuple\n" + "class " + name + "[" + + ",".join("T{}".format(i) for i in range(len(slots))) + "]:\n") for i, slot in enumerate(slots): code += " {}: T{}\n".format(slot, i) @@ -235,17 +221,14 @@ def convert(t): code += " def __from_py__(p: cobj):\n" for i, slot in enumerate(slots): code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format( - i, i, slot - ) + i, i, slot) code += " return {}({})\n".format( - name, ", ".join("a{}".format(i) for i in range(len(slots))) - ) + name, ", ".join("a{}".format(i) for i in range(len(slots)))) _jit.execute(code, "", 0, False) custom_conversions[t] = name return t - def _jit_register_fn(f, pyvars, debug): try: obj_name, obj_str = _parse_decorated(f, pyvars=pyvars) @@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug): _reset_jit() raise -def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs): - try: +def _jit_callback_fn(fn, + obj_name, + module, + debug=None, + sample_size=5, + pyvars=None, + *args, + **kwargs): + if fn is not None: + sig = inspect.signature(fn) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + args = tuple(bound_args.arguments[param] for param in sig.parameters) + else: args = (*args, *kwargs.values()) + + try: types = _codon_types(args, debug=debug, sample_size=sample_size) if debug: - print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr) - return _jit.run_wrapper( - obj_name, list(types), module, list(pyvars), args, 1 if debug else 0 - ) + print("[python] {}({})".format(obj_name, list(types)), + file=sys.stderr) + return _jit.run_wrapper(obj_name, list(types), module, list(pyvars), + args, 1 if debug else 0) except JITError: _reset_jit() raise def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None): obj_name = _jit_register_fn(fstr, pyvars, debug) - def wrapped(*args, **kwargs): - return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs) - return wrapped + def wrapped(*args, **kwargs): + return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size, + pyvars, *args, **kwargs) + + return wrapped def jit(fn=None, debug=None, sample_size=5, pyvars=None): if not pyvars: pyvars = [] + if not isinstance(pyvars, list): raise ArgumentError("pyvars must be a list") @@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None): def _decorate(f): obj_name = _jit_register_fn(f, pyvars, debug) + @functools.wraps(f) def wrapped(*args, **kwargs): - return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs) - return wrapped - return _decorate(fn) if fn else _decorate + return _jit_callback_fn(f, obj_name, f.__module__, debug, + sample_size, pyvars, *args, **kwargs) + return wrapped + + return _decorate(fn) if fn else _decorate def execute(code, debug=False): try: diff --git a/test/python/cython_jit.py b/test/python/cython_jit.py index 7cac6b15..83682737 100644 --- a/test/python/cython_jit.py +++ b/test/python/cython_jit.py @@ -181,3 +181,16 @@ def test_ndarray(): assert np.datetime_data(y.dtype) == ('s', 2) test_ndarray() + +@codon.jit +def e(x=2, y=99): + return 2*x + y + +def test_arg_order(): + assert e(1, 2) == 4 + assert e(1) == 101 + assert e(y=10, x=1) == 12 + assert e(x=1) == 101 + assert e() == 103 + +test_arg_order()