mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
127 lines
3.9 KiB
Python
127 lines
3.9 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from .ndarray import ndarray
|
|
from .routines import asarray, empty, expand_dims
|
|
from .ufunc import UnaryUFunc, UnaryUFunc2, BinaryUFunc
|
|
import util
|
|
|
|
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
    """Apply ``func1d`` to every 1-D slice of ``arr`` taken along ``axis``.

    Parameters
    ----------
    func1d
        Callable invoked as ``func1d(slice_1d, *args, **kwargs)``. Its result
        is passed through ``asarray``.
    axis : int
        Axis along which the 1-D slices are taken; normalized with
        ``util.normalize_axis_index``.
    arr
        Input data, converted with ``asarray``.
    *args, **kwargs
        Forwarded unchanged to every ``func1d`` call.

    Returns
    -------
    ndarray
        Array whose shape is ``arr.shape`` with the ``axis`` dimension replaced
        by the shape of ``func1d``'s result. Its dtype is the dtype of the
        first result.

    Raises
    ------
    ValueError
        If any dimension other than ``axis`` has length 0, or if ``func1d``
        returns results of differing shapes.
    """
    a = asarray(arr)
    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    s = a.strides[axis]

    # Iteration space: every dimension of `a` except `axis`.
    a_iter_space = util.tuple_delete(a.shape, axis)

    # The output shape/dtype is discovered by calling `func1d` on the first
    # slice, so at least one slice must exist. Without this guard the first
    # slice would be read from an empty buffer.
    has_slice = False
    for _ in util.multirange(a_iter_space):
        has_slice = True
        break
    if not has_slice:
        raise ValueError("Cannot apply_along_axis when any iteration dimensions are 0")

    # Evaluate the first slice up front to learn the result shape and dtype.
    first_slice = ndarray((n,), (s,), a.data)
    f0 = asarray(func1d(first_slice, *args, **kwargs))

    new_shape = util.tuple_insert_tuple(a_iter_space, axis, f0.shape)
    ans = empty(new_shape, f0.dtype)
    first = True

    for idx0 in util.multirange(a_iter_space):
        if first:
            # Reuse the result already computed above instead of calling
            # `func1d` a second time on the first slice.
            fx = f0
            first = False
        else:
            # Index of the first element of this slice within `a`.
            idx = util.tuple_insert(idx0, axis, 0)
            next_slice = ndarray((n,), (s,), a._ptr(idx))
            fx = asarray(func1d(next_slice, *args, **kwargs))

            if fx.shape != f0.shape:
                raise ValueError("function is returning arrays of different shape")

        if fx.ndim > 0:
            # Copy the result element by element into the output, with the
            # result's dimensions spliced in at position `axis`.
            for idx1 in util.multirange(fx.shape):
                idx2 = util.tuple_insert_tuple(idx0, axis, idx1)
                p = fx._ptr(idx1)
                q = ans._ptr(idx2)
                q[0] = p[0]
        else:
            # Scalar result: the output has exactly the iteration-space shape.
            ans._ptr(idx0)[0] = fx.item()

    return ans
|
|
|
|
def apply_over_axes(func, a, axes):
    """Apply ``func`` successively over each axis in ``axes``.

    ``func`` is called as ``func(val, axis)`` for every axis in turn, and the
    result of one call becomes the input of the next. A result with the same
    number of dimensions as its input is used as-is; a result with exactly one
    fewer dimension has the axis re-inserted with ``expand_dims``.

    Parameters
    ----------
    func
        Callable taking ``(array, axis)``.
    a
        Input data, converted with ``asarray``.
    axes
        A single axis or a sequence of axes. Negative values are offset by the
        current number of dimensions.

    Returns
    -------
    The array produced by the last ``func`` call (after any ``expand_dims``),
    or ``a`` converted to ``func``'s result dtype when ``axes`` is empty.
    """
    a = asarray(a)
    # Probe call: `func` is invoked once on axis 0 solely to obtain the dtype
    # of its result. Its return value is otherwise discarded.
    dtype = func(a, 0).dtype

    # NOTE(review): `val` is only defined under this `staticlen` guard; this
    # looks like a compile-time device rather than a runtime check - confirm
    # before restructuring.
    if staticlen(dtype) >= 0:
        val = a.astype(dtype)
    # A 0-dimensional `axes` (a single axis) is wrapped in a 1-tuple so the
    # loop below can treat both forms uniformly.
    if asarray(axes).ndim == 0:
        ax = (axes,)
    else:
        ax = axes

    for axis in ax:
        # Negative axes count from the end of the current array.
        if axis < 0:
            axis += val.ndim
        b = func(val, axis)

        if isinstance(b, ndarray):
            if b.ndim == val.ndim:
                # Dimensionality preserved: feed the result straight back in.
                val = b
            elif b.ndim == val.ndim - 1:
                # One dimension was dropped: restore it at `axis` so later
                # axis numbers still refer to the same dimensions.
                val = expand_dims(b, axis)
            else:
                compile_error("function is not returning an array of the correct shape")
        else:
            compile_error("function is not returning ndarray")

    return val
|
|
|
|
def frompyfunc(func, nin: Static[int], nout: Static[int], identity):
    """Wrap ``func`` in a ufunc object selected by its input/output counts.

    Supported ``(nin, nout)`` combinations:

    - ``(1, 1)`` -> ``UnaryUFunc``
    - ``(1, 2)`` -> ``UnaryUFunc2``
    - ``(2, 1)`` -> ``BinaryUFunc``

    Any other combination is rejected through ``compile_error``.

    Parameters
    ----------
    func
        Function to wrap; its ``__name__`` with the suffix ``' (vectorized)'``
        is passed to the wrapper as its name.
    nin : Static[int]
        Number of inputs ``func`` takes.
    nout : Static[int]
        Number of outputs ``func`` produces.
    identity
        Passed to ``BinaryUFunc`` only; not used by the unary wrappers.

    Returns
    -------
    The constructed ``UnaryUFunc``, ``UnaryUFunc2`` or ``BinaryUFunc``.
    """
    if nin == 1 and nout == 1:
        return UnaryUFunc(func, func.__name__ + ' (vectorized)')
    elif nin == 1 and nout == 2:
        return UnaryUFunc2(func, func.__name__ + ' (vectorized)')
    elif nin == 2 and nout == 1:
        # Only the binary wrapper receives `identity`.
        return BinaryUFunc(func, func.__name__ + ' (vectorized)', identity)
    else:
        compile_error("given combination of 'nin' and 'nout' is not supported")
|
|
|
|
def vectorize(pyfunc):
    """Build a single-output ufunc from ``pyfunc`` via ``frompyfunc``.

    The input count is taken from the length of ``S.fn_args(pyfunc)``, the
    output count is fixed at 1, and ``identity`` is passed as ``None``.

    Parameters
    ----------
    pyfunc
        Function to vectorize.

    Returns
    -------
    Whatever ``frompyfunc`` returns for the derived ``nin`` and ``nout=1``.
    """
    import internal.static as S
    nout: Static[int] = 1
    # NOTE(review): presumably `S.fn_args` yields one entry per parameter of
    # `pyfunc`, making this its arity - confirm against `internal.static`.
    nin: Static[int] = staticlen(S.fn_args(pyfunc))
    return frompyfunc(pyfunc, nin=nin, nout=nout, identity=None)
|
|
|
|
def piecewise(x, condlist: List, funclist: List, *args, **kw):
    """Evaluate a piecewise-defined function element by element.

    For each element of ``x`` the first condition in ``condlist`` that is true
    at that position selects the entry of ``funclist`` with the same index.
    A callable entry is invoked as ``f(element, *args, **kw)``; any other
    entry is used as a constant. The value is cast to ``x``'s dtype.

    Parameters
    ----------
    x
        Input domain, converted with ``asarray``.
    condlist : List
        Conditions, each with the same shape as ``x``.
    funclist : List
        Functions or constants. Must hold either ``len(condlist)`` entries or
        ``len(condlist) + 1``; in the latter case the extra last entry is the
        default used where no condition is true.
    *args, **kw
        Forwarded to every callable entry of ``funclist``.

    Returns
    -------
    ndarray
        Array with the shape and dtype of ``x``. Elements matched by no
        condition are 0 when ``funclist`` carries no default entry.

    Raises
    ------
    ValueError
        If the length of ``funclist`` is neither ``len(condlist)`` nor
        ``len(condlist) + 1``, or if a condition's shape differs from
        ``x``'s shape.
    """
    x = asarray(x)
    ans = empty(x.shape, x.dtype)

    n_condlist = len(condlist)
    n_funclist = len(funclist)

    if n_condlist != n_funclist and n_condlist + 1 != n_funclist:
        raise ValueError(f"with {n_condlist} conditions, either {n_condlist} or {n_condlist + 1} functions are expected")

    # Convert each condition once up front rather than once per element.
    conds = [asarray(cond) for cond in condlist]

    for acond in conds:
        if acond.ndim != x.ndim:
            compile_error("condition dimension does not match input domain dimension")
        if acond.shape != x.shape:
            raise ValueError("condition shape does not match input domain shape")

    for idx in util.multirange(x.shape):
        e = x._ptr(idx)[0]

        # Index of the first condition that holds here; equals n_condlist
        # when none does.
        i = 0
        for acond in conds:
            if acond._ptr(idx)[0]:
                break
            i += 1

        if i < n_funclist:
            # Either a condition matched, or none did and funclist has one
            # extra trailing entry (index n_condlist) acting as the default.
            f = funclist[i]
            if hasattr(f, "__call__"):
                e = util.cast(f(e, *args, **kw), x.dtype)
            else:
                e = util.cast(f, x.dtype)
        else:
            # No condition matched and no default was supplied: the element
            # is 0 instead of indexing past the end of funclist.
            e = util.cast(0, x.dtype)

        ans._ptr(idx)[0] = e

    return ans
|