mirror of https://github.com/exaloop/codon.git
811 lines
19 KiB
Python
811 lines
19 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
import util
|
|
import zmath
|
|
from .ndarray import ndarray
|
|
from .routines import asarray, broadcast_to, empty_like
|
|
|
|
# Utility
|
|
|
|
def _cast(x, dtype: type):
    """Convert ``x`` to ``dtype`` by delegating to ``util.cast``."""
    converted = util.cast(x, dtype)
    return converted
|
|
|
|
def _coerce(dtype1: type, dtype2: type):
    """Return the result of ``util.coerce`` for the two given types."""
    common = util.coerce(dtype1, dtype2)
    return common
|
|
|
|
def _count(shape):
    """Return ``util.count(shape)`` for the given shape tuple."""
    total = util.count(shape)
    return total
|
|
|
|
def _contig_match(arrs):
    """Check whether all array-like members of ``arrs`` share one layout.

    Elements without a ``_contig`` attribute are ignored.  The result is
    True when there are no array-like elements, or when every one of them
    has the same ``ndim`` and ``shape`` as the first and they are either
    all flagged by the first ``_contig`` entry or all flagged by the second.
    """

    def only_arrays(items):
        # Recursively keep the tuple elements that expose `_contig`.
        if staticlen(items) == 0:
            return ()
        else:
            rest = only_arrays(items[1:])
            if hasattr(items[0], "_contig"):
                return (items[0],) + rest
            else:
                return rest

    candidates = only_arrays(arrs)
    if staticlen(candidates) == 0:
        return True

    every_c = True
    every_f = True

    for k in staticrange(staticlen(candidates)):
        cur = candidates[k]
        c_flag, f_flag = cur._contig

        # Any rank or shape disagreement with the first array is a mismatch.
        if cur.ndim != candidates[0].ndim:
            return False
        else:
            if cur.shape != candidates[0].shape:
                return False

        every_c = every_c and c_flag
        every_f = every_f and f_flag

    return every_c or every_f
|
|
|
|
def _ptrset(p: Ptr[T], x: T, T: type):
    """Store ``x`` into element 0 of the pointer ``p``."""
    first = 0
    p[first] = x
|
|
|
|
def _loop_alloc(arrays, func, extra, dtype: type):
    """Run ``ndarray._loop`` with one allocated result of ``dtype``.

    ``alloc=Tuple[dtype]`` is passed to ``ndarray._loop`` and the first
    element of what it returns is handed back.
    """
    results = ndarray._loop(arrays, func, alloc=Tuple[dtype], extra=extra)
    return results[0]
|
|
|
|
def _loop_basic(arrays, func, extra):
    """Run ``ndarray._loop`` over ``arrays`` without an ``alloc`` argument.

    Nothing is returned; any result of the underlying call is discarded.
    """
    ndarray._loop(arrays, func, extra=extra)
|
|
|
|
def _broadcast(sh1, sh2):
    """Combine two shape tuples following broadcasting rules.

    The shapes are aligned at their trailing dimensions.  Leading
    dimensions of the longer shape are kept as they are; each aligned pair
    must be equal or contain a 1, otherwise ``ValueError`` is raised.  An
    empty shape broadcasts to the other shape unchanged.
    """

    def merge_dim(sh1, sh2, i: Static[int]):
        # Resolve a single aligned dimension pair.
        d1 = sh1[i]
        d2 = sh2[i]
        compatible = d1 == 1 or d2 == 1 or d1 == d2
        if not compatible:
            raise ValueError(f"operands could not be broadcast together with shapes {sh1} {sh2}")
        return max(d1, d2)

    def merge_aligned(lhs, rhs):
        # Both tuples have the same length here.
        return tuple(merge_dim(lhs, rhs, i) for i in staticrange(staticlen(lhs)))

    rank1: Static[int] = staticlen(sh1)
    rank2: Static[int] = staticlen(sh2)

    if rank1 == 0:
        return sh2
    elif rank2 == 0:
        return sh1
    elif rank1 > rank2:
        # sh1 has extra leading dimensions; merge only its tail with sh2.
        return sh1[:-rank2] + merge_aligned(sh1[-rank2:], sh2)
    elif rank1 < rank2:
        # sh2 has extra leading dimensions; merge only its tail with sh1.
        return sh2[:-rank1] + merge_aligned(sh1, sh2[-rank1:])
    else:
        return merge_aligned(sh1, sh2)
|
|
|
|
def _matmul_shape(x1, x2):
    """Compute the result shape of a matrix product from operand shapes.

    ``x1`` and ``x2`` are shape tuples.  An empty shape returns the other
    shape as-is.  A 1-d operand is treated as a row (left) or column
    (right) vector and the added dimension is removed from the result.
    Leading ("batch") dimensions are combined with ``_broadcast``.

    Raises ``ValueError`` when the inner dimensions disagree.
    """
    nd1: Static[int] = staticlen(x1)
    nd2: Static[int] = staticlen(x2)

    if nd1 == 0:
        return x2

    if nd2 == 0:
        return x1

    # Promote vectors to matrices so both operands have at least 2 dims.
    if nd1 == 1:
        lhs = (1,) + x1
    else:
        lhs = x1

    if nd2 == 1:
        rhs = x2 + (1,)
    else:
        rhs = x2

    lhs_batch = lhs[:-2]
    rhs_batch = rhs[:-2]
    lhs_mat = lhs[-2:]
    rhs_mat = rhs[-2:]

    m = lhs_mat[0]
    k = lhs_mat[1]
    n = rhs_mat[1]

    if k != rhs_mat[0]:
        raise ValueError("matmul: last dimension of first argument does not "
                         "match second-to-last dimension of second argument")

    batch = _broadcast(lhs_batch, rhs_batch)
    if nd1 == 1 and nd2 == 1:
        return batch
    elif nd1 == 1:
        return batch + (n,)
    elif nd2 == 1:
        return batch + (m,)
    else:
        return batch + (m, n)
|
|
|
|
def _create(like, shape, dtype: type):
    """Call ``empty_like`` on ``like`` with the given shape and dtype."""
    fresh = empty_like(like, shape=shape, dtype=dtype)
    return fresh
|
|
|
|
def _shape(x):
    """Return ``x.shape`` when present, otherwise the empty tuple."""
    if not hasattr(x, "shape"):
        return ()
    else:
        return x.shape
|
|
|
|
def _free(x):
    """Pass ``x.data`` to ``util.free``."""
    buf = x.data
    util.free(buf)
|
|
|
|
def _apply_vectorized_loop_unary(arr, out, func: Static[str]):
    """Apply the vectorized kernel named ``func`` from ``arr`` into ``out``.

    ``util.call_vectorized_loop`` is called with a default-constructed
    pointer and a stride of 0 in the slot where the binary variant passes
    its second operand.

    Dispatch:
      * 1-d output: one call using the arrays' own strides.
      * matching contiguous layouts: one call over all ``out.size``
        elements, using the item size as stride.
      * otherwise: one call per 1-d slice along the axis whose input
        stride has the smallest magnitude.

    Both arrays must be at least 1-d and ``arr`` may not have more
    dimensions than ``out``; anything else is a compile-time error.
    """
    if arr.ndim == 0 or out.ndim == 0 or arr.ndim > out.ndim:
        compile_error("[internal error] bad array dims for vectorized loop")

    if out.ndim == 1:
        util.call_vectorized_loop(arr.data, arr.strides[0], Ptr[arr.dtype](),
                                  0, out.data, out.strides[0], out.size, func)
        return

    # Broadcast the input to the *output* shape, as the binary variant
    # does.  Using arr.shape here made the broadcast a no-op, so an input
    # with fewer dimensions (or size-1 dimensions) was indexed with its own
    # shape against out's strides.
    shape = out.shape
    arr = broadcast_to(arr, shape)

    if arr._contig_match(out):
        s = util.sizeof(out.dtype)
        util.call_vectorized_loop(arr.data, s, Ptr[arr.dtype](), 0, out.data,
                                  s, out.size, func)
    else:
        # Find smallest stride to use in vectorized loop
        arr_strides = arr.strides
        out_strides = out.strides
        n = 0
        si = 0
        so = 0
        loop_axis = -1

        for i in staticrange(arr.ndim):
            # Axes of extent <= 1 are never picked here.
            if shape[i] > 1 and (loop_axis == -1 or abs(arr_strides[i]) < abs(si)):
                n = shape[i]
                si = arr_strides[i]
                so = out_strides[i]
                loop_axis = i

        # No axis has extent > 1: fall back to axis 0.
        if loop_axis == -1:
            n = shape[0]
            si = arr_strides[0]
            so = out_strides[0]
            loop_axis = 0

        # Visit every index combination of the remaining axes and run the
        # kernel along `loop_axis`, starting at position 0 on that axis.
        for idx in util.multirange(util.tuple_delete(shape, loop_axis)):
            idx1 = util.tuple_insert(idx, loop_axis, 0)
            p = arr._ptr(idx1)
            q = out._ptr(idx1)
            util.call_vectorized_loop(p, si, Ptr[arr.dtype](), 0, q, so, n,
                                      func)
|
|
|
|
def _apply_vectorized_loop_binary(arr1, arr2, out, func: Static[str]):
    """Apply the vectorized kernel named ``func`` to two inputs into ``out``.

    Dispatch:
      * 1-d output: one call using each operand's leading stride (0 for a
        0-d operand).
      * both broadcast inputs match ``out``'s contiguous layout: one call
        over all ``out.size`` elements, using the item size as stride.
      * otherwise: one call per 1-d slice along the axis whose ``arr1``
        stride has the smallest magnitude.

    At most one input may be 0-d, ``out`` must be at least 1-d, and
    neither input may have more dimensions than ``out``; anything else is
    a compile-time error.
    """
    if (arr1.ndim == 0 and arr2.ndim == 0) or out.ndim == 0 or arr1.ndim > out.ndim or arr2.ndim > out.ndim:
        compile_error("[internal error] bad array dims for vectorized loop")

    # A 0-d operand has no strides entry; 0 is used in its place.
    if arr1.ndim == 0:
        lead1 = 0
    else:
        lead1 = arr1.strides[0]

    if arr2.ndim == 0:
        lead2 = 0
    else:
        lead2 = arr2.strides[0]

    if out.ndim == 1:
        util.call_vectorized_loop(arr1.data, lead1, arr2.data, lead2,
                                  out.data, out.strides[0], out.size, func)
        return

    shape = out.shape
    lhs = broadcast_to(arr1, shape)
    rhs = broadcast_to(arr2, shape)

    if lhs._contig_match(out) and rhs._contig_match(out):
        itemsize = util.sizeof(out.dtype)
        util.call_vectorized_loop(lhs.data, itemsize, rhs.data, itemsize,
                                  out.data, itemsize, out.size, func)
    else:
        # Pick the axis whose first-operand stride is smallest in magnitude.
        lhs_strides = lhs.strides
        rhs_strides = rhs.strides
        dst_strides = out.strides
        count = 0
        step1 = 0
        step2 = 0
        step_out = 0
        axis = -1

        for d in staticrange(lhs.ndim):
            if shape[d] > 1 and (axis == -1 or abs(lhs_strides[d]) < abs(step1)):
                count = shape[d]
                step1 = lhs_strides[d]
                step2 = rhs_strides[d]
                step_out = dst_strides[d]
                axis = d

        # No axis has extent > 1: fall back to axis 0.
        if axis == -1:
            count = shape[0]
            step1 = lhs_strides[0]
            step2 = rhs_strides[0]
            step_out = dst_strides[0]
            axis = 0

        # One kernel call per index combination of the remaining axes.
        for outer in util.multirange(util.tuple_delete(shape, axis)):
            start = util.tuple_insert(outer, axis, 0)
            src1 = lhs._ptr(start)
            src2 = rhs._ptr(start)
            dst = out._ptr(start)
            util.call_vectorized_loop(src1, step1, src2, step2, dst, step_out, count, func)
|
|
|
|
# Operations
|
|
|
|
@inline
def _pos(x):
    """Unary plus: evaluate ``+x``."""
    result = +x
    return result
|
|
|
|
@inline
def _neg(x):
    """Unary minus: evaluate ``-x``."""
    negated = -x
    return negated
|
|
|
|
@inline
def _invert(x):
    """Bitwise inversion: evaluate ``~x``."""
    flipped = ~x
    return flipped
|
|
|
|
@inline
def _abs(x):
    """Absolute value via the builtin ``abs``."""
    magnitude = abs(x)
    return magnitude
|
|
|
|
@inline
def _transpose(x):
    """Return the ``T`` attribute of ``x``."""
    transposed = x.T
    return transposed
|
|
|
|
@inline
def _add(x, y):
    """Binary ``+``."""
    total = x + y
    return total
|
|
|
|
@inline
def _sub(x, y):
    """Binary ``-``."""
    difference = x - y
    return difference
|
|
|
|
@inline
def _mul(x, y):
    """Binary ``*``."""
    product = x * y
    return product
|
|
|
|
@inline
def _matmul(x, y):
    """Binary ``@`` (matrix multiplication operator)."""
    product = x @ y
    return product
|
|
|
|
@inline
def _true_div(x, y):
    """Binary ``/``."""
    quotient = x / y
    return quotient
|
|
|
|
@inline
def _floor_div(x, y):
    """Floor division.

    When both operand types are ``Int`` the work goes to ``util.pydiv``;
    every other combination uses the ``//`` operator.
    """
    Tx = type(x)
    Ty = type(y)
    if not (isinstance(Tx, Int) and isinstance(Ty, Int)):
        return x // y
    else:
        return util.pydiv(x, y)
|
|
|
|
@inline
def _mod(x, y):
    """Remainder.

    * both operand types ``Int``: ``util.pymod``
    * both ``float``, both ``float32`` or both ``float16``: ``util.pyfmod``
    * anything else: the ``%`` operator
    """
    Tx = type(x)
    Ty = type(y)
    if isinstance(Tx, Int) and isinstance(Ty, Int):
        return util.pymod(x, y)
    elif ((Tx is float and Ty is float) or
          (Tx is float32 and Ty is float32) or
          (Tx is float16 and Ty is float16)):
        return util.pyfmod(x, y)
    else:
        return x % y
|
|
|
|
@inline
def _fmod(x, y):
    """Remainder, ``fmod`` flavour.

    * both operand types ``Int``: ``util.cmod_int``
    * both ``float``, both ``float32`` or both ``float16``: ``util.cmod``
    * anything else: the ``%`` operator
    """
    Tx = type(x)
    Ty = type(y)
    if isinstance(Tx, Int) and isinstance(Ty, Int):
        return util.cmod_int(x, y)
    elif ((Tx is float and Ty is float) or
          (Tx is float32 and Ty is float32) or
          (Tx is float16 and Ty is float16)):
        return util.cmod(x, y)
    else:
        return x % y
|
|
|
|
@inline
def _pow(x, y):
    """Binary ``**``."""
    power = x ** y
    return power
|
|
|
|
@inline
def _lshift(x, y):
    """Binary ``<<``."""
    shifted = x << y
    return shifted
|
|
|
|
@inline
def _rshift(x, y):
    """Binary ``>>``."""
    shifted = x >> y
    return shifted
|
|
|
|
@inline
def _and(x, y):
    """Binary ``&``."""
    masked = x & y
    return masked
|
|
|
|
@inline
def _or(x, y):
    """Binary ``|``."""
    merged = x | y
    return merged
|
|
|
|
@inline
def _xor(x, y):
    """Binary ``^``."""
    toggled = x ^ y
    return toggled
|
|
|
|
@inline
def _eq(x, y):
    """Comparison ``==``."""
    outcome = x == y
    return outcome
|
|
|
|
@inline
def _ne(x, y):
    """Comparison ``!=``."""
    outcome = x != y
    return outcome
|
|
|
|
@inline
def _lt(x, y):
    """Comparison ``<``."""
    outcome = x < y
    return outcome
|
|
|
|
@inline
def _le(x, y):
    """Comparison ``<=``."""
    outcome = x <= y
    return outcome
|
|
|
|
@inline
def _gt(x, y):
    """Comparison ``>``."""
    outcome = x > y
    return outcome
|
|
|
|
@inline
def _ge(x, y):
    """Comparison ``>=``."""
    outcome = x >= y
    return outcome
|
|
|
|
def _apply(x, f, f_complex = None):
    """Dispatch a unary math routine on the type of ``x``.

    * ``complex`` / ``complex64`` input with ``f_complex`` supplied:
      ``f_complex(x)``
    * ``float`` / ``float32`` / ``float16`` input: ``f(x)``
    * everything else: ``x`` goes through ``util.to_float`` first, then
      ``f`` is applied
    """
    if f_complex is not None and (isinstance(x, complex) or isinstance(x, complex64)):
        return f_complex(x)
    elif isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return f(x)
    else:
        as_float = util.to_float(x)
        return f(as_float)
|
|
|
|
def _apply2(x, y, f, f_complex = None):
    """Dispatch a binary math routine on the (shared) operand type.

    ``x`` and ``y`` must have the same type; otherwise compilation fails.

    * ``complex`` / ``complex64`` with ``f_complex`` supplied:
      ``f_complex(x, y)``
    * ``float`` / ``float32`` / ``float16``: ``f(x, y)``
    * everything else: both operands go through ``util.to_float`` first
    """
    if type(x) is not type(y):
        compile_error("type mismatch in util")

    if f_complex is not None and (isinstance(x, complex) or isinstance(x, complex64)):
        return f_complex(x, y)
    elif isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return f(x, y)
    else:
        lhs = util.to_float(x)
        rhs = util.to_float(y)
        return f(lhs, rhs)
|
|
|
|
def _fabs(x):
    """Apply ``util.fabs`` through ``_apply`` (no complex handler)."""
    value = _apply(x, util.fabs)
    return value
|
|
|
|
def _rint(x):
    """Apply ``util.rint``; complex values are handled per component."""

    def round_parts(z):
        # Rebuild a complex of the same type from the two rounded parts.
        Z = type(z)
        re = util.rint(z.real)
        im = util.rint(z.imag)
        return Z(re, im)

    return _apply(x, util.rint, round_parts)
|
|
|
|
def _exp(x):
    """Apply ``util.exp`` (real) or ``zmath.exp`` (complex) via ``_apply``."""
    value = _apply(x, util.exp, zmath.exp)
    return value
|
|
|
|
def _exp2(x):
    """Apply ``util.exp2`` (real) or ``zmath.exp2`` (complex) via ``_apply``."""
    value = _apply(x, util.exp2, zmath.exp2)
    return value
|
|
|
|
def _expm1(x):
    """Apply ``util.expm1`` (real) or ``zmath.expm1`` (complex) via ``_apply``."""
    value = _apply(x, util.expm1, zmath.expm1)
    return value
|
|
|
|
def _log(x):
    """Apply ``util.log`` (real) or ``zmath.log`` (complex) via ``_apply``."""
    value = _apply(x, util.log, zmath.log)
    return value
|
|
|
|
def _log2(x):
    """Apply ``util.log2`` (real) or ``zmath.log2`` (complex) via ``_apply``."""
    value = _apply(x, util.log2, zmath.log2)
    return value
|
|
|
|
def _log10(x):
    """Apply ``util.log10`` (real) or ``zmath.log10`` (complex) via ``_apply``."""
    value = _apply(x, util.log10, zmath.log10)
    return value
|
|
|
|
def _log1p(x):
    """Apply ``util.log1p`` (real) or ``zmath.log1p`` (complex) via ``_apply``."""
    value = _apply(x, util.log1p, zmath.log1p)
    return value
|
|
|
|
def _sqrt(x):
    """Apply ``util.sqrt`` (real) or ``zmath.sqrt`` (complex) via ``_apply``."""
    value = _apply(x, util.sqrt, zmath.sqrt)
    return value
|
|
|
|
def _cbrt(x):
    """Apply ``util.cbrt`` through ``_apply`` (no complex handler)."""
    value = _apply(x, util.cbrt)
    return value
|
|
|
|
def _square(x):
    """Multiply ``x`` by itself."""
    squared = x * x
    return squared
|
|
|
|
def _sin(x):
    """Apply ``util.sin`` (real) or ``zmath.sin`` (complex) via ``_apply``."""
    value = _apply(x, util.sin, zmath.sin)
    return value
|
|
|
|
def _cos(x):
    """Apply ``util.cos`` (real) or ``zmath.cos`` (complex) via ``_apply``."""
    value = _apply(x, util.cos, zmath.cos)
    return value
|
|
|
|
def _tan(x):
    """Apply ``util.tan`` (real) or ``zmath.tan`` (complex) via ``_apply``."""
    value = _apply(x, util.tan, zmath.tan)
    return value
|
|
|
|
def _arcsin(x):
    """Apply ``util.asin`` (real) or ``zmath.asin`` (complex) via ``_apply``."""
    value = _apply(x, util.asin, zmath.asin)
    return value
|
|
|
|
def _arccos(x):
    """Apply ``util.acos`` (real) or ``zmath.acos`` (complex) via ``_apply``."""
    value = _apply(x, util.acos, zmath.acos)
    return value
|
|
|
|
def _arctan(x):
    """Apply ``util.atan`` (real) or ``zmath.atan`` (complex) via ``_apply``."""
    value = _apply(x, util.atan, zmath.atan)
    return value
|
|
|
|
def _sinh(x):
    """Apply ``util.sinh`` (real) or ``zmath.sinh`` (complex) via ``_apply``."""
    value = _apply(x, util.sinh, zmath.sinh)
    return value
|
|
|
|
def _cosh(x):
    """Apply ``util.cosh`` (real) or ``zmath.cosh`` (complex) via ``_apply``."""
    value = _apply(x, util.cosh, zmath.cosh)
    return value
|
|
|
|
def _tanh(x):
    """Apply ``util.tanh`` (real) or ``zmath.tanh`` (complex) via ``_apply``."""
    value = _apply(x, util.tanh, zmath.tanh)
    return value
|
|
|
|
def _arcsinh(x):
    """Apply ``util.asinh`` (real) or ``zmath.asinh`` (complex) via ``_apply``."""
    value = _apply(x, util.asinh, zmath.asinh)
    return value
|
|
|
|
def _arccosh(x):
    """Apply ``util.acosh`` (real) or ``zmath.acosh`` (complex) via ``_apply``."""
    value = _apply(x, util.acosh, zmath.acosh)
    return value
|
|
|
|
def _arctanh(x):
    """Apply ``util.atanh`` (real) or ``zmath.atanh`` (complex) via ``_apply``."""
    value = _apply(x, util.atanh, zmath.atanh)
    return value
|
|
|
|
def _rad2deg(x):
    """Scale ``x`` by ``180 / pi``.

    ``x`` is first passed through ``util.to_float``; the factor is cast to
    the resulting type before multiplying.
    """
    as_float = util.to_float(x)
    FloatT = type(as_float)
    factor = 180.0 / util.PI
    return as_float * FloatT(factor)
|
|
|
|
def _deg2rad(x):
    """Scale ``x`` by ``pi / 180``.

    ``x`` is first passed through ``util.to_float``; the factor is cast to
    the resulting type before multiplying.
    """
    as_float = util.to_float(x)
    FloatT = type(as_float)
    factor = util.PI / 180.0
    return as_float * FloatT(factor)
|
|
|
|
def _arctan2(x, y):
    """Apply ``util.atan2`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.atan2)
    return value
|
|
|
|
def _hypot(x, y):
    """Apply ``util.hypot`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.hypot)
    return value
|
|
|
|
def _logaddexp(x, y):
    """Apply ``util.logaddexp`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.logaddexp)
    return value
|
|
|
|
def _logaddexp2(x, y):
    """Apply ``util.logaddexp2`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.logaddexp2)
    return value
|
|
|
|
def _isnan(x):
    """NaN test.

    Floats use ``util.isnan``; a complex value is NaN when either component
    is; every other type yields ``False``.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isnan(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        if util.isnan(x.real):
            return True
        return util.isnan(x.imag)
    else:
        return False
|
|
|
|
def _isinf(x):
    """Infinity test.

    Floats use ``util.isinf``; a complex value is infinite when either
    component is; every other type yields ``False``.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isinf(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        if util.isinf(x.real):
            return True
        return util.isinf(x.imag)
    else:
        return False
|
|
|
|
def _isfinite(x):
    """Finiteness test.

    Floats use ``util.isfinite``; a complex value is finite only when both
    components are; every other type yields ``True``.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.isfinite(x)
    elif isinstance(x, complex) or isinstance(x, complex64):
        if not util.isfinite(x.real):
            return False
        return util.isfinite(x.imag)
    else:
        return True
|
|
|
|
def _signbit(x):
    """Sign-bit test.

    Floats use ``util.signbit``; other types are compared against the
    default-constructed value of their own type.
    """
    if isinstance(x, float) or isinstance(x, float32) or isinstance(x, float16):
        return util.signbit(x)
    else:
        ValueT = type(x)
        zero = ValueT()
        return x < zero
|
|
|
|
def _copysign(x, y):
    """Apply ``util.copysign`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.copysign)
    return value
|
|
|
|
def _nextafter(x, y):
    """Apply ``util.nextafter`` to ``(x, y)`` through ``_apply2``."""
    value = _apply2(x, y, util.nextafter)
    return value
|
|
|
|
def _floor(x):
    """Apply ``util.floor`` through ``_apply`` (no complex handler)."""
    value = _apply(x, util.floor)
    return value
|
|
|
|
def _ceil(x):
    """Apply ``util.ceil`` through ``_apply`` (no complex handler)."""
    value = _apply(x, util.ceil)
    return value
|
|
|
|
def _trunc(x):
    """Apply ``util.trunc`` through ``_apply`` (no complex handler)."""
    value = _apply(x, util.trunc)
    return value
|
|
|
|
def _sign(x):
    """Sign of ``x``.

    Real values map to -1, 1 or the value itself (when it is neither
    below nor above zero).  For complex values the sign of the real part
    is used unless it is zero, in which case the sign of the imaginary
    part is used; the result is returned as a complex number with a zero
    imaginary part.  A complex NaN yields ``nan + 0j``.
    """

    def real_sign(v):
        # -1 / +1 for strictly negative / positive; otherwise v unchanged.
        V = type(v)
        zero = V(0)
        if v < zero:
            return V(-1)
        elif v > zero:
            return V(1)
        else:
            return v

    if isinstance(x, complex):
        if _isnan(x):
            return complex(util.nan64(), 0.0)
        if x.real:
            return complex(real_sign(x.real), 0)
        else:
            return complex(real_sign(x.imag), 0)
    elif isinstance(x, complex64):
        if _isnan(x):
            return complex64(util.nan64(), 0.0)
        if x.real:
            return complex64(real_sign(x.real), 0)
        else:
            return complex64(real_sign(x.imag), 0)
    else:
        return real_sign(x)
|
|
|
|
def _heaviside(x, y):
    """Heaviside step function of ``x`` with ``y`` as the value at zero.

    Returns 0 for negative ``x``, 1 for positive ``x``, ``y`` when ``x``
    equals zero, and ``x`` itself when none of the comparisons hold.
    Operands are routed through ``_apply2``.
    """

    def heaviside(x, y):
        def step(x, y, zero, one):
            # Shared comparison ladder for every float width.
            if x < zero:
                return zero
            elif x > zero:
                return one
            elif x == zero:
                return y
            else:
                return x

        if isinstance(x, float16) and isinstance(y, float16):
            return step(x, y, float16(0), float16(1))
        elif isinstance(x, float32) and isinstance(y, float32):
            return step(x, y, float32(0), float32(1))
        elif isinstance(x, float) and isinstance(y, float):
            return step(x, y, 0.0, 1.0)

    return _apply2(x, y, heaviside)
|
|
|
|
def _conj(x):
    """Complex conjugate; non-complex values are returned unchanged."""
    if not (isinstance(x, complex) or isinstance(x, complex64)):
        return x
    else:
        return x.conjugate()
|
|
|
|
def _gcd(x, y):
    """Greatest common divisor computed with Euclid's algorithm."""
    # NOTE: an integral-type guard (int / Int / UInt / byte, reporting
    # "gcd/lcm can only be used on integral types" via compile_error) is
    # intentionally not active here because it fails with optionals.
    while x:
        x, y = y % x, x
    return y
|
|
|
|
def _lcm(x, y):
    """Least common multiple derived from ``_gcd``; 0 when the gcd is 0."""
    divisor = _gcd(x, y)
    # Divide before multiplying, as the original expression does.
    return (x // divisor) * y if divisor else 0
|
|
|
|
def _reciprocal(x: T, T: type):
    """Return one over ``x``.

    Integral types (``int``, ``Int``, ``UInt``, ``byte``) use floor
    division; every other type uses true division.
    """
    one = T(1)
    if (
        isinstance(x, int) or
        isinstance(x, Int) or
        isinstance(x, UInt) or
        isinstance(x, byte)
    ):
        return one // x
    else:
        return one / x
|
|
|
|
def _logical_and(x, y):
    """Truth-value AND of ``x`` and ``y``; ``y`` is only tested if ``x`` is truthy."""
    if not bool(x):
        return False
    return bool(y)
|
|
|
|
def _logical_or(x, y):
    """Truth-value OR of ``x`` and ``y``; ``y`` is only tested if ``x`` is falsy."""
    if bool(x):
        return True
    return bool(y)
|
|
|
|
def _logical_xor(x, y):
    """Truth-value XOR: True when exactly one operand is truthy."""
    left = bool(x)
    right = bool(y)
    return left != right
|
|
|
|
def _logical_not(x):
    """Truth-value negation of ``x``."""
    truth = bool(x)
    return not truth
|
|
|
|
def _coerce_types_for_minmax(x, y):
    """Bring ``x`` and ``y`` to a common type for min/max comparisons.

    When one operand is ``complex`` or ``complex64`` and the other has a
    different type, both are returned as ``complex`` (a non-complex
    operand is cast to ``float`` first).  Otherwise both are cast to the
    type of ``util.coerce(type(x), type(y))``.
    """
    # Mixed cases with a complex left operand.
    if isinstance(x, complex):
        if isinstance(y, complex64):
            y_wide = complex(y)
            return x, y_wide
        elif not isinstance(y, complex):
            y_real = util.cast(y, float)
            return x, complex(y_real)
    elif isinstance(x, complex64):
        if isinstance(y, complex):
            x_wide = complex(x)
            return x_wide, y
        elif not isinstance(y, complex64):
            y_real = util.cast(y, float)
            return complex(x), complex(y_real)

    # Mixed cases with a complex right operand.
    if isinstance(y, complex):
        if isinstance(x, complex64):
            x_wide = complex(x)
            return x_wide, y
        elif not isinstance(x, complex):
            x_real = util.cast(x, float)
            return complex(x_real), y
    elif isinstance(y, complex64):
        if isinstance(x, complex):
            y_wide = complex(y)
            return x, y_wide
        elif not isinstance(x, complex64):
            x_real = util.cast(x, float)
            return complex(x_real), complex(y)

    # Remaining combinations: cast both to the coerced type.
    Common = type(util.coerce(type(x), type(y)))
    return util.cast(x, Common), util.cast(y, Common)
|
|
|
|
def _compare_le(x, y):
    """``x <= y``; complex values compare as ``(real, imag)`` tuples."""
    if isinstance(x, complex) or isinstance(x, complex64):
        lhs = (x.real, x.imag)
        rhs = (y.real, y.imag)
        return lhs <= rhs
    else:
        return x <= y
|
|
|
|
def _compare_ge(x, y):
    """``x >= y``; complex values compare as ``(real, imag)`` tuples."""
    if isinstance(x, complex) or isinstance(x, complex64):
        lhs = (x.real, x.imag)
        rhs = (y.real, y.imag)
        return lhs >= rhs
    else:
        return x >= y
|
|
|
|
def _maximum(x, y):
    """Larger of the two operands after type coercion; a NaN operand wins.

    If ``x`` is NaN it is returned, otherwise a NaN ``y`` is returned.
    """
    a, b = _coerce_types_for_minmax(x, y)

    if _isnan(a):
        return a

    if _isnan(b):
        return b

    if _compare_ge(a, b):
        return a
    else:
        return b
|
|
|
|
def _minimum(x, y):
    """Smaller of the two operands after type coercion; a NaN operand wins.

    If ``x`` is NaN it is returned, otherwise a NaN ``y`` is returned.
    """
    a, b = _coerce_types_for_minmax(x, y)

    if _isnan(a):
        return a

    if _isnan(b):
        return b

    if _compare_le(a, b):
        return a
    else:
        return b
|
|
|
|
def _fmax(x, y):
    """Larger of the two operands after type coercion, skipping NaN.

    A NaN ``y`` yields ``x``; otherwise a NaN ``x`` yields ``y``.
    """
    a, b = _coerce_types_for_minmax(x, y)

    if _isnan(b):
        return a

    if _isnan(a):
        return b

    if _compare_ge(a, b):
        return a
    else:
        return b
|
|
|
|
def _fmin(x, y):
    """Smaller of the two operands after type coercion, skipping NaN.

    A NaN ``y`` yields ``x``; otherwise a NaN ``x`` yields ``y``.
    """
    a, b = _coerce_types_for_minmax(x, y)

    if _isnan(b):
        return a

    if _isnan(a):
        return b

    if _compare_le(a, b):
        return a
    else:
        return b
|
|
|
|
def _divmod_float(x, y):
    """Floating-point divmod returning ``(floor_quotient, remainder)``.

    Built on ``util.cmod`` / ``util.cdiv``.  A non-zero remainder whose
    sign differs from the divisor's is shifted by ``y``; a zero remainder
    and a zero quotient each take the sign of ``y`` and of ``x / y``
    respectively.  When ``y`` is zero the raw quotient and remainder are
    returned.
    """
    F = type(x)
    zero = F(0)
    one = F(1)

    rem = util.cmod(x, y)
    if not y:
        return util.cdiv(x, y), rem

    quot = util.cdiv(x - rem, y)
    if rem:
        # Make the remainder carry the same sign as the divisor.
        if (y < zero) != (rem < zero):
            rem = rem + y
            quot = quot - one
    else:
        rem = util.copysign(zero, y)

    floored = F()
    if quot:
        # Snap the quotient to the nearest whole value.
        floored = util.floor(quot)
        if quot - floored > F(0.5):
            floored = floored + one
    else:
        floored = util.copysign(zero, util.cdiv(x, y))

    return floored, rem
|
|
|
|
def _divmod(x, y):
    """Return ``(quotient, remainder)`` for ``x`` and ``y``.

    Matching ``float16`` or ``float32`` pairs, and any pair containing a
    ``float`` (both cast to ``float``), go through ``_divmod_float``.
    Everything else uses ``//`` and ``%``.
    """
    if isinstance(x, float16) and isinstance(y, float16):
        return _divmod_float(x, y)

    if isinstance(x, float32) and isinstance(y, float32):
        return _divmod_float(x, y)

    if isinstance(x, float) or isinstance(y, float):
        lhs = util.cast(x, float)
        rhs = util.cast(y, float)
        return _divmod_float(lhs, rhs)

    quotient = x // y
    remainder = x % y
    return (quotient, remainder)
|
|
|
|
def _modf(x):
    """Apply ``util.modf`` through ``_apply`` (no complex handler)."""
    parts = _apply(x, util.modf)
    return parts
|
|
|
|
def _frexp(x):
    """Apply ``util.frexp`` and return its second result as ``i32``."""

    def split(v):
        mantissa, exponent = util.frexp(v)
        return mantissa, i32(exponent)

    return _apply(x, split)
|
|
|
|
def _spacing16(h: float16):
    """Spacing for a ``float16`` value, computed on its 16-bit pattern.

    Returns NaN when the exponent field is all ones, infinity for the bit
    pattern 0x7bff, and otherwise a value built directly from the
    exponent bits.
    """
    bits = util.bitcast(h, u16)
    exp_bits = bits & u16(0x7c00)
    frac_bits = bits & u16(0x03ff)
    smallest = u16(0x0001)

    if exp_bits == u16(0x7c00):
        # Exponent field all ones.
        return util.nan16()
    elif bits == u16(0x7bff):
        return util.inf16()
    elif (bits & u16(0x8000)) and not frac_bits:
        # Sign bit set and fraction zero: uses offsets one step larger
        # than the general case below.
        if exp_bits > u16(0x2c00):
            return util.bitcast(exp_bits - u16(0x2c00), float16)
        elif exp_bits > u16(0x0400):
            return util.bitcast(u16(1) << ((exp_bits >> u16(10)) - u16(2)), float16)
        else:
            return util.bitcast(smallest, float16)
    elif exp_bits > u16(0x2800):
        return util.bitcast(exp_bits - u16(0x2800), float16)
    elif exp_bits > u16(0x0400):
        return util.bitcast(u16(1) << ((exp_bits >> u16(10)) - u16(1)), float16)
    else:
        return util.bitcast(smallest, float16)
|
|
|
|
def _spacing(x):
    """Distance from ``x`` to the adjacent representable value.

    The step is taken away from zero: towards +inf for ``x >= 0`` and
    towards -inf otherwise, so the result carries the sign of ``x``.
    Infinite inputs yield NaN.  ``float16`` is handled by ``_spacing16``;
    non-float inputs are converted with ``util.to_float`` first.
    """
    if isinstance(x, float16):
        return _spacing16(x)
    elif isinstance(x, float32):
        if util.isinf32(x):
            return util.nan32()
        # Step in the direction of x's sign.  The target used to be
        # hard-coded to +inf, which gave wrong results for negative
        # float32 inputs (the float64 branch already used `p`).
        p = util.inf32() if x >= float32(0) else -util.inf32()
        return util.nextafter32(x, p) - x
    elif isinstance(x, float):
        x = util.cast(x, float)
        if util.isinf64(x):
            return util.nan64()
        p = util.inf64() if x >= 0 else -util.inf64()
        return util.nextafter64(x, p) - x
    else:
        return _spacing(util.to_float(x))
|