mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
JIT argument order fix (#639)
* Fix argument ordering in JIT * Format * Update JIT tests * Fix JIT test
This commit is contained in:
parent
b3f6c12d57
commit
93fb3d53e3
@ -1,5 +1,7 @@
|
||||
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
__all__ = ["jit", "convert", "JITError"]
|
||||
__all__ = [
|
||||
"jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit"
|
||||
]
|
||||
|
||||
from .decorator import jit, convert, execute, JITError
|
||||
from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit
|
||||
|
@ -23,16 +23,14 @@ if "CODON_PATH" not in os.environ:
|
||||
if codon_lib_path:
|
||||
codon_path.append(Path(codon_lib_path).parent / "stdlib")
|
||||
codon_path.append(
|
||||
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
|
||||
)
|
||||
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib")
|
||||
for path in codon_path:
|
||||
if path.exists():
|
||||
os.environ["CODON_PATH"] = str(path.resolve())
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot locate Codon. Please install Codon or set CODON_PATH."
|
||||
)
|
||||
"Cannot locate Codon. Please install Codon or set CODON_PATH.")
|
||||
|
||||
pod_conversions = {
|
||||
type(None): "pyobj",
|
||||
@ -61,7 +59,6 @@ pod_conversions = {
|
||||
custom_conversions = {}
|
||||
_error_msgs = set()
|
||||
|
||||
|
||||
def _common_type(t, debug, sample_size):
|
||||
sub, is_optional = None, False
|
||||
for i in itertools.islice(t, sample_size):
|
||||
@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
|
||||
sub = "Optional[{}]".format(sub)
|
||||
return sub if sub else "pyobj"
|
||||
|
||||
|
||||
def _codon_type(arg, **kwargs):
|
||||
t = type(arg)
|
||||
|
||||
@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
|
||||
if issubclass(t, set):
|
||||
return "Set[{}]".format(_common_type(arg, **kwargs))
|
||||
if issubclass(t, dict):
|
||||
return "Dict[{},{}]".format(
|
||||
_common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs)
|
||||
)
|
||||
return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs),
|
||||
_common_type(arg.values(), **kwargs))
|
||||
if issubclass(t, tuple):
|
||||
return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg))
|
||||
return "Tuple[{}]".format(",".join(
|
||||
_codon_type(a, **kwargs) for a in arg))
|
||||
if issubclass(t, np.ndarray):
|
||||
if arg.dtype == np.bool_:
|
||||
dtype = "bool"
|
||||
@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
|
||||
|
||||
s = custom_conversions.get(t, "")
|
||||
if s:
|
||||
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
|
||||
j = ",".join(
|
||||
_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
|
||||
return "{}[{}]".format(s, j)
|
||||
|
||||
debug = kwargs.get("debug", None)
|
||||
@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
|
||||
_error_msgs.add(msg)
|
||||
return "pyobj"
|
||||
|
||||
|
||||
def _codon_types(args, **kwargs):
|
||||
return tuple(_codon_type(arg, **kwargs) for arg in args)
|
||||
|
||||
|
||||
def _reset_jit():
|
||||
global _jit
|
||||
_jit = JITWrapper()
|
||||
init_code = (
|
||||
"from internal.python import "
|
||||
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
|
||||
"setup_decorator()\n"
|
||||
"import numpy as np\n"
|
||||
"import numpy.pybridge\n"
|
||||
)
|
||||
init_code = ("from internal.python import "
|
||||
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
|
||||
"setup_decorator()\n"
|
||||
"import numpy as np\n"
|
||||
"import numpy.pybridge\n")
|
||||
_jit.execute(init_code, "", 0, False)
|
||||
return _jit
|
||||
|
||||
|
||||
_jit = _reset_jit()
|
||||
|
||||
|
||||
class RewriteFunctionArgs(ast.NodeTransformer):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
@ -176,7 +167,6 @@ class RewriteFunctionArgs(ast.NodeTransformer):
|
||||
node.args.args.append(ast.arg(arg=a, annotation=None))
|
||||
return node
|
||||
|
||||
|
||||
def _obj_to_str(obj, **kwargs) -> str:
|
||||
if inspect.isclass(obj):
|
||||
lines = inspect.getsourcelines(obj)[0]
|
||||
@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
|
||||
obj_name = obj.__name__
|
||||
elif callable(obj) or isinstance(obj, str):
|
||||
is_str = isinstance(obj, str)
|
||||
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0]
|
||||
if not is_str: lines = lines[1:]
|
||||
lines = [i + '\n' for i in obj.split('\n')
|
||||
] if is_str else inspect.getsourcelines(obj)[0]
|
||||
if not is_str:
|
||||
lines = lines[1:]
|
||||
obj_str = textwrap.dedent(''.join(lines))
|
||||
|
||||
pyvars = kwargs.get("pyvars", None)
|
||||
@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
|
||||
if not isinstance(i, str):
|
||||
raise ValueError("pyvars only takes string literals")
|
||||
node = ast.fix_missing_locations(
|
||||
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
|
||||
)
|
||||
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)))
|
||||
obj_str = astunparse.unparse(node)
|
||||
if is_str:
|
||||
try:
|
||||
@ -206,28 +197,23 @@ def _obj_to_str(obj, **kwargs) -> str:
|
||||
else:
|
||||
obj_name = obj.__name__
|
||||
else:
|
||||
raise TypeError("Function or class expected, got " + type(obj).__name__)
|
||||
raise TypeError("Function or class expected, got " +
|
||||
type(obj).__name__)
|
||||
return obj_name, obj_str.replace("_@par", "@par")
|
||||
|
||||
|
||||
def _parse_decorated(obj, **kwargs):
|
||||
return _obj_to_str(obj, **kwargs)
|
||||
|
||||
return _obj_to_str(obj, **kwargs)
|
||||
|
||||
def convert(t):
|
||||
if not hasattr(t, "__slots__"):
|
||||
raise JITError("class '{}' does not have '__slots__' attribute".format(str(t)))
|
||||
raise JITError("class '{}' does not have '__slots__' attribute".format(
|
||||
str(t)))
|
||||
|
||||
name = t.__name__
|
||||
slots = t.__slots__
|
||||
code = (
|
||||
"@tuple\n"
|
||||
"class "
|
||||
+ name
|
||||
+ "["
|
||||
+ ",".join("T{}".format(i) for i in range(len(slots)))
|
||||
+ "]:\n"
|
||||
)
|
||||
code = ("@tuple\n"
|
||||
"class " + name + "[" +
|
||||
",".join("T{}".format(i) for i in range(len(slots))) + "]:\n")
|
||||
for i, slot in enumerate(slots):
|
||||
code += " {}: T{}\n".format(slot, i)
|
||||
|
||||
@ -235,17 +221,14 @@ def convert(t):
|
||||
code += " def __from_py__(p: cobj):\n"
|
||||
for i, slot in enumerate(slots):
|
||||
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format(
|
||||
i, i, slot
|
||||
)
|
||||
i, i, slot)
|
||||
code += " return {}({})\n".format(
|
||||
name, ", ".join("a{}".format(i) for i in range(len(slots)))
|
||||
)
|
||||
name, ", ".join("a{}".format(i) for i in range(len(slots))))
|
||||
|
||||
_jit.execute(code, "", 0, False)
|
||||
custom_conversions[t] = name
|
||||
return t
|
||||
|
||||
|
||||
def _jit_register_fn(f, pyvars, debug):
|
||||
try:
|
||||
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
|
||||
@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
|
||||
_reset_jit()
|
||||
raise
|
||||
|
||||
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
|
||||
try:
|
||||
def _jit_callback_fn(fn,
|
||||
obj_name,
|
||||
module,
|
||||
debug=None,
|
||||
sample_size=5,
|
||||
pyvars=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
if fn is not None:
|
||||
sig = inspect.signature(fn)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
args = tuple(bound_args.arguments[param] for param in sig.parameters)
|
||||
else:
|
||||
args = (*args, *kwargs.values())
|
||||
|
||||
try:
|
||||
types = _codon_types(args, debug=debug, sample_size=sample_size)
|
||||
if debug:
|
||||
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
|
||||
return _jit.run_wrapper(
|
||||
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
|
||||
)
|
||||
print("[python] {}({})".format(obj_name, list(types)),
|
||||
file=sys.stderr)
|
||||
return _jit.run_wrapper(obj_name, list(types), module, list(pyvars),
|
||||
args, 1 if debug else 0)
|
||||
except JITError:
|
||||
_reset_jit()
|
||||
raise
|
||||
|
||||
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
|
||||
obj_name = _jit_register_fn(fstr, pyvars, debug)
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
|
||||
pyvars, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
||||
if not pyvars:
|
||||
pyvars = []
|
||||
|
||||
if not isinstance(pyvars, list):
|
||||
raise ArgumentError("pyvars must be a list")
|
||||
|
||||
@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
||||
|
||||
def _decorate(f):
|
||||
obj_name = _jit_register_fn(f, pyvars, debug)
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs)
|
||||
return wrapped
|
||||
return _decorate(fn) if fn else _decorate
|
||||
return _jit_callback_fn(f, obj_name, f.__module__, debug,
|
||||
sample_size, pyvars, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return _decorate(fn) if fn else _decorate
|
||||
|
||||
def execute(code, debug=False):
|
||||
try:
|
||||
|
@ -181,3 +181,16 @@ def test_ndarray():
|
||||
assert np.datetime_data(y.dtype) == ('s', 2)
|
||||
|
||||
test_ndarray()
|
||||
|
||||
@codon.jit
|
||||
def e(x=2, y=99):
|
||||
return 2*x + y
|
||||
|
||||
def test_arg_order():
|
||||
assert e(1, 2) == 4
|
||||
assert e(1) == 101
|
||||
assert e(y=10, x=1) == 12
|
||||
assert e(x=1) == 101
|
||||
assert e() == 103
|
||||
|
||||
test_arg_order()
|
||||
|
Loading…
x
Reference in New Issue
Block a user