# Copyright (C) 2022-2025 Exaloop Inc.

import util
import zmath

from .ndarray import ndarray
from .routines import asarray, broadcast_to, empty_like

# Utility


def _cast(x, dtype: type):
    """Convert ``x`` to ``dtype`` by delegating to ``util.cast``."""
    return util.cast(x, dtype)


def _coerce(dtype1: type, dtype2: type):
    """Return the result of ``util.coerce`` for the two given types."""
    return util.coerce(dtype1, dtype2)


def _count(shape):
    """Return ``util.count(shape)`` for the given shape tuple."""
    return util.count(shape)


def _contig_match(arrs):
    """
    Check whether every array-like member of ``arrs`` shares one layout.

    Members without a ``_contig`` attribute (e.g. scalars) are ignored.
    Returns ``True`` when nothing remains after filtering, ``False`` as soon
    as a member's ``ndim`` or ``shape`` differs from the first kept member,
    and otherwise ``True`` iff all members report the first ``_contig`` flag
    or all report the second.
    """

    # Resolved recursively on the static tuple length so that members without
    # `_contig` are dropped at compile time.
    def keep_arrays(arrs):
        if staticlen(arrs) == 0:
            return ()
        else:
            if hasattr(arrs[0], "_contig"):
                return (arrs[0],) + keep_arrays(arrs[1:])
            else:
                return keep_arrays(arrs[1:])

    arrs = keep_arrays(arrs)

    if staticlen(arrs) == 0:
        return True

    all_cc = True
    all_fc = True

    for i in staticrange(staticlen(arrs)):
        a = arrs[i]
        cc, fc = a._contig
        if a.ndim != arrs[0].ndim:
            return False
        else:
            # Shapes are only compared once the ranks are known to agree.
            if a.shape != arrs[0].shape:
                return False
        all_cc = all_cc and cc
        all_fc = all_fc and fc

    return all_cc or all_fc


def _ptrset(p: Ptr[T], x: T, T: type):
    """Store ``x`` at the location ``p`` points to."""
    p[0] = x


def _loop_alloc(arrays, func, extra, dtype: type):
    """
    Run ``ndarray._loop`` over ``arrays`` with a single allocated output of
    ``dtype`` and return that output (element 0 of the loop's result).
    """
    return ndarray._loop(arrays, func, alloc=Tuple[dtype], extra=extra)[0]


def _loop_basic(arrays, func, extra):
    """Run ``ndarray._loop`` over ``arrays`` without allocating outputs."""
    ndarray._loop(arrays, func, extra=extra)


def _broadcast(sh1, sh2):
    """
    Compute the broadcast of two shape tuples.

    Shapes are aligned on their trailing dimensions; the leading dimensions
    of the longer shape are kept unchanged. An empty shape broadcasts to the
    other shape.

    Raises:
        ValueError: if two aligned dimensions differ and neither is 1.
    """

    def bc_one(sh1, sh2, i: Static[int]):
        a = sh1[i]
        b = sh2[i]
        if a == 1 or b == 1 or a == b:
            return max(a, b)
        else:
            raise ValueError(f"operands could not be broadcast together with shapes {sh1} {sh2}")

    def bc_same(sh1, sh2):
        return tuple(bc_one(sh1, sh2, i) for i in staticrange(staticlen(sh1)))

    N1: Static[int] = staticlen(sh1)
    N2: Static[int] = staticlen(sh2)

    if N1 == 0:
        return sh2
    elif N2 == 0:
        return sh1
    elif N1 > N2:
        return sh1[:-N2] + bc_same(sh1[-N2:], sh2)
    elif N1 < N2:
        return sh2[:-N1] + bc_same(sh1, sh2[-N1:])
    else:
        return bc_same(sh1, sh2)


def _matmul_shape(x1, x2):
    """
    Compute the result shape of a matrix product of shapes ``x1`` and ``x2``.

    An empty shape returns the other shape unchanged. A 1-D first operand is
    treated as a row (a leading 1 is prepended) and a 1-D second operand as a
    column (a trailing 1 is appended); the temporary dimension is left out of
    the result. Leading "batch" dimensions are combined with ``_broadcast``.

    Raises:
        ValueError: if the inner dimensions do not match, or if the batch
            dimensions cannot be broadcast (raised by ``_broadcast``).
    """
    x1d: Static[int] = staticlen(x1)
    x2d: Static[int] = staticlen(x2)

    if x1d == 0:
        return x2
    if x2d == 0:
        return x1

    if x1d == 1:
        y1 = (1,) + x1
    else:
        y1 = x1

    if x2d == 1:
        y2 = x2 + (1,)
    else:
        y2 = x2

    base1s = y1[:-2]
    base2s = y2[:-2]
    mat1s = y1[-2:]
    mat2s = y2[-2:]

    # (m x k) @ (k x n)
    m = mat1s[0]
    k = mat1s[1]
    n = mat2s[1]

    if k != mat2s[0]:
        raise ValueError("matmul: last dimension of first argument does not "
                         "match second-to-last dimension of second argument")

    ans_base = _broadcast(base1s, base2s)

    if x1d == 1 and x2d == 1:
        return ans_base
    elif x1d == 1:
        return ans_base + (n,)
    elif x2d == 1:
        return ans_base + (m,)
    else:
        return ans_base + (m, n)


def _create(like, shape, dtype: type):
    """Create an array via ``empty_like(like, shape=shape, dtype=dtype)``."""
    return empty_like(like, shape=shape, dtype=dtype)


def _shape(x):
    """Return ``x.shape`` if ``x`` has one, otherwise the empty tuple."""
    if hasattr(x, "shape"):
        return x.shape
    else:
        return ()


def _free(x):
    """Release ``x.data`` through ``util.free``."""
    util.free(x.data)


def _apply_vectorized_loop_unary(arr, out, func: Static[str]):
    """
    Apply the vectorized loop ``func`` to ``arr``, writing into ``out``.

    Both arrays must have at least one dimension and ``arr`` may not have
    more dimensions than ``out``; violating this is a compile-time error.

    A 1-D ``out`` is handled with one call using the arrays' own strides.
    Otherwise ``arr`` is broadcast to ``out``'s shape. If the layouts then
    match (``arr._contig_match(out)``), one call covers all ``out.size``
    elements; if not, the loop is invoked once per position of the remaining
    axes, running along the axis of extent > 1 on which ``arr`` has the
    smallest absolute stride.

    The unary loop has no second input, so a null pointer and stride 0 are
    passed in the second-operand slots of ``util.call_vectorized_loop``.
    """
    if arr.ndim == 0 or out.ndim == 0 or arr.ndim > out.ndim:
        compile_error("[internal error] bad array dims for vectorized loop")

    if out.ndim == 1:
        util.call_vectorized_loop(arr.data, arr.strides[0], Ptr[arr.dtype](),
                                  0, out.data, out.strides[0], out.size, func)
        return

    # Iterate over the OUTPUT shape (as the binary variant does). Using
    # arr.shape here made the broadcast a no-op and indexed `out` with the
    # input's extents whenever the input needed broadcasting.
    shape = out.shape
    arr = broadcast_to(arr, shape)

    if arr._contig_match(out):
        # NOTE(review): the output item size is used as the stride for both
        # operands, which presumes equal item sizes -- confirm with callers.
        s = util.sizeof(out.dtype)
        util.call_vectorized_loop(arr.data, s, Ptr[arr.dtype](), 0, out.data,
                                  s, out.size, func)
    else:
        # Find smallest stride to use in vectorized loop
        arr_strides = arr.strides
        out_strides = out.strides
        n = 0
        si = 0
        so = 0
        loop_axis = -1

        for i in staticrange(arr.ndim):
            if shape[i] > 1 and (loop_axis == -1 or abs(arr_strides[i]) < abs(si)):
                n = shape[i]
                si = arr_strides[i]
                so = out_strides[i]
                loop_axis = i

        # No axis has extent > 1: fall back to axis 0.
        if loop_axis == -1:
            n = shape[0]
            si = arr_strides[0]
            so = out_strides[0]
            loop_axis = 0

        # Visit every index of the other axes; the loop axis starts at 0.
        for idx in util.multirange(util.tuple_delete(shape, loop_axis)):
            idx1 = util.tuple_insert(idx, loop_axis, 0)
            p = arr._ptr(idx1)
            q = out._ptr(idx1)
            util.call_vectorized_loop(p, si, Ptr[arr.dtype](), 0, q, so, n, func)


def _apply_vectorized_loop_binary(arr1, arr2, out, func: Static[str]):
    """
    Apply the binary vectorized loop ``func`` to ``arr1`` and ``arr2``,
    writing into ``out``.

    At most one input may be 0-dimensional, ``out`` must have at least one
    dimension, and neither input may have more dimensions than ``out``;
    violating this is a compile-time error. A 0-dimensional input is given
    stride 0 in the 1-D path.

    A 1-D ``out`` is handled with one call. Otherwise both inputs are
    broadcast to ``out.shape``. If both layouts match ``out``, one call
    covers all ``out.size`` elements; if not, the loop is invoked once per
    position of the remaining axes, running along the axis of extent > 1 on
    which ``arr1`` has the smallest absolute stride.
    """
    if (arr1.ndim == 0 and arr2.ndim == 0) or out.ndim == 0 or arr1.ndim > out.ndim or arr2.ndim > out.ndim:
        compile_error("[internal error] bad array dims for vectorized loop")

    if arr1.ndim == 0:
        st1 = 0
    else:
        st1 = arr1.strides[0]

    if arr2.ndim == 0:
        st2 = 0
    else:
        st2 = arr2.strides[0]

    if out.ndim == 1:
        util.call_vectorized_loop(arr1.data, st1, arr2.data, st2, out.data,
                                  out.strides[0], out.size, func)
        return

    shape = out.shape
    arr1 = broadcast_to(arr1, shape)
    arr2 = broadcast_to(arr2, shape)

    if arr1._contig_match(out) and arr2._contig_match(out):
        # NOTE(review): the output item size is used as the stride for all
        # operands, which presumes equal item sizes -- confirm with callers.
        s = util.sizeof(out.dtype)
        util.call_vectorized_loop(arr1.data, s, arr2.data, s, out.data, s,
                                  out.size, func)
    else:
        # Find smallest stride to use in vectorized loop
        arr1_strides = arr1.strides
        arr2_strides = arr2.strides
        out_strides = out.strides
        n = 0
        si1 = 0
        si2 = 0
        so = 0
        loop_axis = -1

        for i in staticrange(arr1.ndim):
            if shape[i] > 1 and (loop_axis == -1 or abs(arr1_strides[i]) < abs(si1)):
                n = shape[i]
                si1 = arr1_strides[i]
                si2 = arr2_strides[i]
                so = out_strides[i]
                loop_axis = i

        # No axis has extent > 1: fall back to axis 0.
        if loop_axis == -1:
            n = shape[0]
            si1 = arr1_strides[0]
            si2 = arr2_strides[0]
            so = out_strides[0]
            loop_axis = 0

        # Visit every index of the other axes; the loop axis starts at 0.
        for idx in util.multirange(util.tuple_delete(shape, loop_axis)):
            idx1 = util.tuple_insert(idx, loop_axis, 0)
            p1 = arr1._ptr(idx1)
            p2 = arr2._ptr(idx1)
            q = out._ptr(idx1)
            util.call_vectorized_loop(p1, si1, p2, si2, q, so, n, func)


# Operations
#
# Thin element-wise wrappers around the corresponding operators.

@inline
def _pos(x):
    return +x


@inline
def _neg(x):
    return -x


@inline
def _invert(x):
    return ~x


@inline
def _abs(x):
    return abs(x)


@inline
def _transpose(x):
    return x.T


@inline
def _add(x, y):
    return x + y


@inline
def _sub(x, y):
    return x - y


@inline
def _mul(x, y):
    return x * y


@inline
def _matmul(x, y):
    return x @ y


@inline
def _true_div(x, y):
    return x / y


# Floor division: when both operands are fixed-width `Int` types this goes
# through `util.pydiv`; every other combination uses the `//` operator.
@inline
def _floor_div(x, y):
    X = type(x)
    Y = type(y)
    if isinstance(X, Int) and isinstance(Y, Int):
        return util.pydiv(x, y)
    else:
        return x // y


# Modulo: `util.pymod` for a pair of `Int` types, `util.pyfmod` for a pair of
# identical floating-point types, and the `%` operator otherwise.
@inline
def _mod(x, y):
    X = type(x)
    Y = type(y)
    if isinstance(X, Int) and isinstance(Y, Int):
        return util.pymod(x, y)
    elif ((X is float and Y is float) or
          (X is float32 and Y is float32) or
          (X is float16 and Y is float16)):
        return util.pyfmod(x, y)
    else:
        return x % y
# fmod: `util.cmod_int` for a pair of `Int` types, `util.cmod` for a pair of
# identical floating-point types, and the `%` operator otherwise.
@inline
def _fmod(x, y):
    X = type(x)
    Y = type(y)
    if isinstance(X, Int) and isinstance(Y, Int):
        return util.cmod_int(x, y)
    elif ((X is float and Y is float) or
          (X is float32 and Y is float32) or
          (X is float16 and Y is float16)):
        return util.cmod(x, y)
    else:
        return x % y


@inline
def _pow(x, y):
    return x ** y


@inline
def _lshift(x, y):
    return x << y


@inline
def _rshift(x, y):
    return x >> y


@inline
def _and(x, y):
    return x & y


@inline
def _or(x, y):
    return x | y


@inline
def _xor(x, y):
    return x ^ y


@inline
def _eq(x, y):
    return x == y


@inline
def _ne(x, y):
    return x != y


@inline
def _lt(x, y):
    return x < y


@inline
def _le(x, y):
    return x <= y


@inline
def _gt(x, y):
    return x > y


@inline
def _ge(x, y):
    return x >= y


def _apply(x, f, f_complex=None):
    """
    Dispatch a unary math function on the type of ``x``.

    Complex inputs go to ``f_complex`` when one is supplied; floating-point
    inputs go to ``f`` directly; anything else (including complex inputs when
    no ``f_complex`` is given) is first passed through ``util.to_float``.
    """
    if f_complex is not None and (isinstance(x, complex) or isinstance(x, complex64)):
        return f_complex(x)
    elif isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return f(x)
    else:
        return f(util.to_float(x))


def _apply2(x, y, f, f_complex=None):
    """
    Binary counterpart of ``_apply``.

    ``x`` and ``y`` must have the same type (otherwise a compile-time error
    is raised). Dispatch is based on the type of ``x`` only.
    """
    if type(x) is not type(y):
        compile_error("type mismatch in util")

    if f_complex is not None and (isinstance(x, complex) or isinstance(x, complex64)):
        return f_complex(x, y)
    elif isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return f(x, y)
    else:
        return f(util.to_float(x), util.to_float(y))


# Element-wise math kernels. Each forwards to `_apply` / `_apply2` with the
# real-valued implementation from `util` and, where available, the complex
# implementation from `zmath`.

def _fabs(x):
    return _apply(x, util.fabs)


def _rint(x):
    """Apply ``util.rint``; complex inputs are handled per component."""

    def rint_complex(x):
        C = type(x)
        return C(util.rint(x.real), util.rint(x.imag))

    return _apply(x, util.rint, rint_complex)


def _exp(x):
    return _apply(x, util.exp, zmath.exp)


def _exp2(x):
    return _apply(x, util.exp2, zmath.exp2)


def _expm1(x):
    return _apply(x, util.expm1, zmath.expm1)


def _log(x):
    return _apply(x, util.log, zmath.log)


def _log2(x):
    return _apply(x, util.log2, zmath.log2)


def _log10(x):
    return _apply(x, util.log10, zmath.log10)


def _log1p(x):
    return _apply(x, util.log1p, zmath.log1p)


def _sqrt(x):
    return _apply(x, util.sqrt, zmath.sqrt)


def _cbrt(x):
    return _apply(x, util.cbrt)


def _square(x):
    return x * x


def _sin(x):
    return _apply(x, util.sin, zmath.sin)


def _cos(x):
    return _apply(x, util.cos, zmath.cos)


def _tan(x):
    return _apply(x, util.tan, zmath.tan)


def _arcsin(x):
    return _apply(x, util.asin, zmath.asin)


def _arccos(x):
    return _apply(x, util.acos, zmath.acos)


def _arctan(x):
    return _apply(x, util.atan, zmath.atan)


def _sinh(x):
    return _apply(x, util.sinh, zmath.sinh)


def _cosh(x):
    return _apply(x, util.cosh, zmath.cosh)


def _tanh(x):
    return _apply(x, util.tanh, zmath.tanh)


def _arcsinh(x):
    return _apply(x, util.asinh, zmath.asinh)


def _arccosh(x):
    return _apply(x, util.acosh, zmath.acosh)


def _arctanh(x):
    return _apply(x, util.atanh, zmath.atanh)


def _rad2deg(x):
    """Multiply ``util.to_float(x)`` by ``180 / pi``, keeping its float type."""
    r2d = 180.0 / util.PI
    x = util.to_float(x)
    F = type(x)
    return x * F(r2d)


def _deg2rad(x):
    """Multiply ``util.to_float(x)`` by ``pi / 180``, keeping its float type."""
    d2r = util.PI / 180.0
    x = util.to_float(x)
    F = type(x)
    return x * F(d2r)


def _arctan2(x, y):
    return _apply2(x, y, util.atan2)


def _hypot(x, y):
    return _apply2(x, y, util.hypot)


def _logaddexp(x, y):
    return _apply2(x, y, util.logaddexp)


def _logaddexp2(x, y):
    return _apply2(x, y, util.logaddexp2)


def _isnan(x):
    """
    NaN test: ``util.isnan`` for floats, either component for complex
    values, and ``False`` for every other type.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isnan(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        return util.isnan(x.real) or util.isnan(x.imag)
    else:
        return False


def _isinf(x):
    """
    Infinity test: ``util.isinf`` for floats, either component for complex
    values, and ``False`` for every other type.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isinf(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        return util.isinf(x.real) or util.isinf(x.imag)
    else:
        return False


def _isfinite(x):
    """
    Finiteness test: ``util.isfinite`` for floats, both components for
    complex values, and ``True`` for every other type.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isfinite(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        return util.isfinite(x.real) and util.isfinite(x.imag)
    else:
        return True


def _signbit(x):
    """
    ``util.signbit`` for floats; for other types, whether ``x`` is less than
    its type's default-constructed value.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.signbit(x)
    else:
        T = type(x)
        return x < T()


def _copysign(x, y):
    return _apply2(x, y, util.copysign)


def _nextafter(x, y):
    return _apply2(x, y, util.nextafter)


def _floor(x):
    return _apply(x, util.floor)


def _ceil(x):
    return _apply(x, util.ceil)


def _trunc(x):
    return _apply(x, util.trunc)


def _sign(x):
    """
    Return the sign of ``x``.

    Real values yield -1, 1, or ``x`` itself when it is neither below nor
    above zero. Complex values yield a complex number with zero imaginary
    part whose real part is the sign of ``x.real``, or of ``x.imag`` when
    ``x.real`` is zero; a complex NaN yields ``(nan, 0)``.
    """

    def sign1(x):
        T = type(x)
        if x < T(0):
            return T(-1)
        elif x > T(0):
            return T(1)
        else:
            return x

    if isinstance(x, complex):
        if _isnan(x):
            return complex(util.nan64(), 0.0)
        return complex(sign1(x.real), 0) if x.real else complex(sign1(x.imag), 0)
    elif isinstance(x, complex64):
        if _isnan(x):
            return complex64(util.nan64(), 0.0)
        return complex64(sign1(x.real), 0) if x.real else complex64(sign1(x.imag), 0)
    else:
        return sign1(x)


def _heaviside(x, y):
    """
    Heaviside step function: 0 for ``x < 0``, 1 for ``x > 0``, ``y`` for
    ``x == 0``, and ``x`` itself when no comparison holds (NaN).
    """

    def heaviside(x, y):
        if isinstance(x, float16) and isinstance(y, float16):
            if x < float16(0):
                return float16(0)
            elif x > float16(0):
                return float16(1)
            elif x == float16(0):
                return y
            else:
                return x
        elif isinstance(x, float32) and isinstance(y, float32):
            if x < float32(0):
                return float32(0)
            elif x > float32(0):
                return float32(1)
            elif x == float32(0):
                return y
            else:
                return x
        elif isinstance(x, float) and isinstance(y, float):
            if x < 0:
                return 0.0
            elif x > 0:
                return 1.0
            elif x == 0.0:
                return y
            else:
                return x

    return _apply2(x, y, heaviside)


def _conj(x):
    """Return ``x.conjugate()`` for complex values, ``x`` otherwise."""
    if isinstance(x, complex) or isinstance(x, complex64):
        return x.conjugate()
    else:
        return x


def _gcd(x, y):
    """Greatest common divisor of ``x`` and ``y`` by the Euclidean algorithm."""
    # NOTE: a compile-time check restricting this to integral types
    # (int / Int / UInt / byte, reporting "gcd/lcm can only be used on
    # integral types") was disabled here because it fails with optionals.
    while x:
        z = x
        x = y % x
        y = z
    return y


def _lcm(x, y):
    """Least common multiple via ``_gcd``; 0 when the gcd is 0."""
    gcd = _gcd(x, y)
    return x // gcd * y if gcd else 0


def _reciprocal(x: T, T: type):
    """
    Return ``1 / x`` in the type of ``x``; integral types use floor division.
    """
    if (
        isinstance(x, int) or
        isinstance(x, Int) or
        isinstance(x, UInt) or
        isinstance(x, byte)
    ):
        return T(1) // x
    else:
        return T(1) / x


def _logical_and(x, y):
    return bool(x) and bool(y)


def _logical_or(x, y):
    return bool(x) or bool(y)


def _logical_xor(x, y):
    return bool(x) ^ bool(y)


def _logical_not(x):
    return not bool(x)


def _coerce_types_for_minmax(x, y):
    """
    Bring ``x`` and ``y`` to a common type for min/max comparisons.

    If the two types differ and either one is complex, both are converted to
    ``complex`` (non-complex values are cast to ``float`` first). All other
    cases, including two values of the same complex type, are cast to the
    type given by ``util.coerce``.
    """
    if isinstance(x, complex):
        if isinstance(y, complex64):
            return x, complex(y)
        elif not isinstance(y, complex):
            return x, complex(util.cast(y, float))
    elif isinstance(x, complex64):
        if isinstance(y, complex):
            return complex(x), y
        elif not isinstance(y, complex64):
            return complex(x), complex(util.cast(y, float))

    if isinstance(y, complex):
        if isinstance(x, complex64):
            return complex(x), y
        elif not isinstance(x, complex):
            return complex(util.cast(x, float)), y
    elif isinstance(y, complex64):
        if isinstance(x, complex):
            return x, complex(y)
        elif not isinstance(x, complex64):
            return complex(util.cast(x, float)), complex(y)

    T = type(util.coerce(type(x), type(y)))
    return util.cast(x, T), util.cast(y, T)


def _compare_le(x, y):
    """``x <= y``; complex values compare as ``(real, imag)`` tuples."""
    if isinstance(x, complex) or isinstance(x, complex64):
        return (x.real, x.imag) <= (y.real, y.imag)
    else:
        return x <= y


def _compare_ge(x, y):
    """``x >= y``; complex values compare as ``(real, imag)`` tuples."""
    if isinstance(x, complex) or isinstance(x, complex64):
        return (x.real, x.imag) >= (y.real, y.imag)
    else:
        return x >= y


def _maximum(x, y):
    """Larger of ``x`` and ``y``; a NaN operand is returned if present."""
    x, y = _coerce_types_for_minmax(x, y)
    if _isnan(x):
        return x
    if _isnan(y):
        return y
    return x if _compare_ge(x, y) else y


def _minimum(x, y):
    """Smaller of ``x`` and ``y``; a NaN operand is returned if present."""
    x, y = _coerce_types_for_minmax(x, y)
    if _isnan(x):
        return x
    if _isnan(y):
        return y
    return x if _compare_le(x, y) else y


def _fmax(x, y):
    """Larger of ``x`` and ``y``; a single NaN operand is ignored."""
    x, y = _coerce_types_for_minmax(x, y)
    if _isnan(y):
        return x
    if _isnan(x):
        return y
    return x if _compare_ge(x, y) else y


def _fmin(x, y):
    """Smaller of ``x`` and ``y``; a single NaN operand is ignored."""
    x, y = _coerce_types_for_minmax(x, y)
    if _isnan(y):
        return x
    if _isnan(x):
        return y
    return x if _compare_le(x, y) else y


def _divmod_float(x, y):
    """
    Floating-point divmod returning ``(floordiv, mod)``.

    ``mod`` takes the sign of ``y`` (a zero remainder is returned as a zero
    carrying ``y``'s sign). When ``y`` is zero the raw ``util.cdiv`` and
    ``util.cmod`` results are returned unchanged.
    """
    F = type(x)
    mod = util.cmod(x, y)
    if not y:
        return util.cdiv(x, y), mod

    # (x - mod) should be very nearly an integer multiple of y.
    div = util.cdiv(x - mod, y)

    if mod:
        # Make the remainder's sign agree with the divisor's.
        if (y < F(0)) != (mod < F(0)):
            mod += y
            div -= F(1)
    else:
        mod = util.copysign(F(0), y)

    floordiv = F()
    if div:
        # Snap the quotient to the nearest integral value.
        floordiv = util.floor(div)
        if div - floordiv > F(0.5):
            floordiv += F(1)
    else:
        # Zero quotient: take the sign of the true quotient.
        floordiv = util.copysign(F(0), util.cdiv(x, y))

    return floordiv, mod


def _divmod(x, y):
    """
    Return ``(x // y, x % y)``.

    Matching ``float16`` / ``float32`` pairs, and any pair involving a
    ``float`` (both cast to ``float``), are routed to ``_divmod_float``.
    """
    if isinstance(x, float16) and isinstance(y, float16):
        return _divmod_float(x, y)

    if isinstance(x, float32) and isinstance(y, float32):
        return _divmod_float(x, y)

    if isinstance(x, float) or isinstance(y, float):
        return _divmod_float(util.cast(x, float), util.cast(y, float))

    return (x // y, x % y)


def _modf(x):
    return _apply(x, util.modf)


def _frexp(x):
    """Apply ``util.frexp`` and return its second result as an ``i32``."""

    def frexp(x):
        a, b = util.frexp(x)
        return a, i32(b)

    return _apply(x, frexp)


def _spacing16(h: float16):
    """
    Spacing for ``float16``, computed on the raw 16-bit pattern.

    Returns NaN when all exponent bits are set (inf or NaN) and +inf for the
    bit pattern ``0x7bff``. Other results are assembled directly as bit
    patterns from the exponent field.
    """
    h_u16 = util.bitcast(h, u16)
    h_exp = h_u16 & u16(0x7c00)  # exponent bits
    h_sig = h_u16 & u16(0x03ff)  # significand bits

    if h_exp == u16(0x7c00):
        return util.nan16()
    elif h_u16 == u16(0x7bff):
        return util.inf16()
    elif (h_u16 & u16(0x8000)) and not h_sig:
        # Sign bit set with a zero significand
        if h_exp > u16(0x2c00):
            return util.bitcast(h_exp - u16(0x2c00), float16)
        elif h_exp > u16(0x0400):
            return util.bitcast(u16(1) << ((h_exp >> u16(10)) - u16(2)), float16)
        else:
            return util.bitcast(u16(0x0001), float16)
    elif h_exp > u16(0x2800):
        return util.bitcast(h_exp - u16(0x2800), float16)
    elif h_exp > u16(0x0400):
        return util.bitcast(u16(1) << ((h_exp >> u16(10)) - u16(1)), float16)
    else:
        return util.bitcast(u16(0x0001), float16)


def _spacing(x):
    """
    Return the difference between ``x`` and its neighbouring value in the
    direction away from zero (as produced by ``util.nextafter32`` /
    ``util.nextafter64`` toward an infinity of ``x``'s sign).

    Infinite inputs yield NaN. ``float16`` is delegated to ``_spacing16``;
    non-float inputs are converted with ``util.to_float`` first.
    """
    if isinstance(x, float16):
        return _spacing16(x)
    elif isinstance(x, float32):
        if util.isinf32(x):
            return util.nan32()
        p = util.inf32() if x >= float32(0) else -util.inf32()
        # Step toward `p`, matching the float branch; the target used to be
        # hard-coded to +inf, which is the wrong direction for negative x.
        return util.nextafter32(x, p) - x
    elif isinstance(x, float):
        x = util.cast(x, float)
        if util.isinf64(x):
            return util.nan64()
        p = util.inf64() if x >= 0 else -util.inf64()
        return util.nextafter64(x, p) - x
    else:
        return _spacing(util.to_float(x))