1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/misc.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

608 lines
16 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .ndarray import ndarray
import util
# Largest array rank accepted by the overlap solvers in this module; it also
# determines the size of their fixed scratch buffers.
MAXDIMS : Static[int] = 32
# Default ``max_work`` budgets. may_share_memory() defaults to MAY_SHARE_BOUNDS
# (0: only the address-range comparison is performed) and shares_memory()
# defaults to MAY_SHARE_EXACT (-1: a negative budget disables the work limit
# checked by the Diophantine search).
MAY_SHARE_BOUNDS : Static[int] = 0
MAY_SHARE_EXACT : Static[int] = -1
# Status codes returned by the Diophantine solver and the overlap routines.
_MEM_OVERLAP_NO : Static[int] = 0 # no solution exists
_MEM_OVERLAP_YES : Static[int] = 1 # solution found
_MEM_OVERLAP_TOO_HARD: Static[int] = -1 # max_work exceeded
_MEM_OVERLAP_OVERFLOW: Static[int] = -2 # algorithm failed due to integer overflow
_MEM_OVERLAP_ERROR : Static[int] = -3 # invalid input
# 128-bit signed integer type used for intermediate arithmetic in the solver.
i128 = Int[128]
def _euclid(a1: int, a2: int):
    """
    Extended Euclidean algorithm.

    Returns ``(g, gamma, epsilon)`` where ``g`` is the last non-zero
    remainder of the alternating reduction (the GCD of the inputs) and
    ``gamma * a1 + epsilon * a2 == g`` holds for the original arguments.

    NOTE(review): both loop tests are ``> 0``, so the inputs are presumably
    expected to be non-negative -- confirm against callers.
    """
    # Invariant: a1 == g1 * A1 + e1 * A2 and a2 == g2 * A1 + e2 * A2, where
    # A1 and A2 are the argument values on entry.
    g1, e1 = 1, 0
    g2, e2 = 0, 1
    a_gcd, gamma, epsilon = 0, 0, 0
    while True:
        if a2 <= 0:
            a_gcd, gamma, epsilon = a1, g1, e1
            break
        quot = a1 // a2
        a1 -= quot * a2
        g1 -= quot * g2
        e1 -= quot * e2

        if a1 <= 0:
            a_gcd, gamma, epsilon = a2, g2, e2
            break
        quot = a2 // a1
        a2 -= quot * a1
        g2 -= quot * g1
        e2 -= quot * e1
    return a_gcd, gamma, epsilon
# One term ``a * x`` of a bounded linear Diophantine equation: ``a`` is the
# coefficient and ``ub`` the inclusive upper bound of the unknown
# (``0 <= x <= ub``). Instances are immutable tuples; the ``with_*`` methods
# return modified copies.
@tuple
class DiophantineTerm:
    a: int   # coefficient
    ub: int  # upper bound of the associated unknown

    def __new__(a: int, ub: int) -> DiophantineTerm:
        return (a, ub)

    # Ordering considers the coefficient only (used when sorting terms).
    def __lt__(self, other: DiophantineTerm):
        return other.a > self.a

    # Copy of this term with the coefficient replaced.
    def with_a(self, a: int):
        return DiophantineTerm(a=a, ub=self.ub)

    # Copy of this term with the upper bound replaced.
    def with_ub(self, ub: int):
        return DiophantineTerm(a=self.a, ub=ub)
def _sort(p: Ptr[DiophantineTerm], n: int):
    """
    Sort the ``n`` terms starting at ``p`` in place.

    The buffer is wrapped in a 1-D ndarray so the array sort can be reused;
    ordering follows ``DiophantineTerm.__lt__`` (by coefficient).
    """
    ndarray((n,), p).sort()
# Thin binding of LLVM's signed 64-bit add-with-overflow intrinsic.
# Returns the result together with a one-bit overflow flag.
@pure
@llvm
def _safe_add_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.sadd.with.overflow.i64(i64, i64)
    %res = call {i64, i1} @llvm.sadd.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %res
def _safe_add(a: int, b: int):
    """Return ``(a + b, overflowed)`` with the overflow flag as a ``bool``."""
    total, flag = _safe_add_llvm(a, b)
    return (total, bool(flag))
# Thin binding of LLVM's signed 64-bit subtract-with-overflow intrinsic.
# Returns the result together with a one-bit overflow flag.
@pure
@llvm
def _safe_sub_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.ssub.with.overflow.i64(i64, i64)
    %res = call {i64, i1} @llvm.ssub.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %res
def _safe_sub(a: int, b: int):
    """Return ``(a - b, overflowed)`` with the overflow flag as a ``bool``."""
    diff, flag = _safe_sub_llvm(a, b)
    return (diff, bool(flag))
# Thin binding of LLVM's signed 64-bit multiply-with-overflow intrinsic.
# Returns the result together with a one-bit overflow flag.
@pure
@llvm
def _safe_mul_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.smul.with.overflow.i64(i64, i64)
    %res = call {i64, i1} @llvm.smul.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %res
def _safe_mul(a: int, b: int):
    """Return ``(a * b, overflowed)`` with the overflow flag as a ``bool``."""
    product, flag = _safe_mul_llvm(a, b)
    return (product, bool(flag))
# Thin binding of LLVM's signed 128-bit add-with-overflow intrinsic.
# Returns the result together with a one-bit overflow flag.
@pure
@llvm
def _safe_add128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.sadd.with.overflow.i128(i128, i128)
    %res = call {i128, i1} @llvm.sadd.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %res
def _safe_add128(a: i128, b: i128):
    """128-bit ``(a + b, overflowed)`` with the overflow flag as a ``bool``."""
    total, flag = _safe_add128_llvm(a, b)
    return (total, bool(flag))
# Thin binding of LLVM's signed 128-bit subtract-with-overflow intrinsic.
# Returns the result together with a one-bit overflow flag.
@pure
@llvm
def _safe_sub128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.ssub.with.overflow.i128(i128, i128)
    %res = call {i128, i1} @llvm.ssub.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %res
def _safe_sub128(a: i128, b: i128):
    """128-bit ``(a - b, overflowed)`` with the overflow flag as a ``bool``."""
    diff, flag = _safe_sub128_llvm(a, b)
    return (diff, bool(flag))
def _diophantine_precompute(n: int,
                            E: Ptr[DiophantineTerm],
                            Ep: Ptr[DiophantineTerm],
                            Gamma: Ptr[int],
                            Epsilon: Ptr[int]):
    """
    Fill the per-level tables consumed by ``_diophantine_dfs``.

    For every level ``k`` in ``0 .. n-2``:
      * ``Ep[k].a`` receives the GCD returned by ``_euclid`` for the previous
        level's GCD (``E[0].a`` at the first level) and ``E[k + 1].a``;
      * ``Gamma[k]`` / ``Epsilon[k]`` receive the matching ``_euclid``
        coefficients;
      * for all levels except the last, ``Ep[k].ub`` receives the combined
        upper bound of the two merged terms.

    ``E[0]`` and ``E[1]`` are read unconditionally, so ``n`` must be >= 2.

    Returns ``True`` if any bound computation overflowed, ``False`` otherwise.
    """
    g, gam, eps = _euclid(E[0].a, E[1].a)
    Ep[0] = Ep[0].with_a(g)
    Gamma[0] = gam
    Epsilon[0] = eps

    if n > 2:
        lhs, lhs_ovf = _safe_mul(E[0].ub, E[0].a // g)
        rhs, rhs_ovf = _safe_mul(E[1].ub, E[1].a // g)
        bound, sum_ovf = _safe_add(lhs, rhs)
        # The bound is stored before the overflow flags are inspected.
        Ep[0] = Ep[0].with_ub(bound)
        if lhs_ovf or rhs_ovf or sum_ovf:
            return True

    for j in range(2, n):
        k = j - 1
        g, gam, eps = _euclid(Ep[k - 1].a, E[j].a)
        Ep[k] = Ep[k].with_a(g)
        Gamma[k] = gam
        Epsilon[k] = eps

        # The last level needs no combined bound.
        if j < n - 1:
            lhs, lhs_ovf = _safe_mul(Ep[k - 1].a // g, Ep[k - 1].ub)
            rhs, rhs_ovf = _safe_mul(E[j].a // g, E[j].ub)
            bound, sum_ovf = _safe_add(lhs, rhs)
            Ep[k] = Ep[k].with_ub(bound)
            if lhs_ovf or rhs_ovf or sum_ovf:
                return True

    return False
def _floordiv(a: i128, b: int):
    """
    Divide the 128-bit value ``a`` by ``b``, lowering the quotient by one
    when ``a`` is negative and the division leaves a remainder.

    NOTE(review): this yields a floor only if ``//`` and ``%`` on ``Int[128]``
    truncate toward zero and ``b`` is positive -- confirm.
    """
    divisor = i128(b)
    quotient = a // divisor
    leftover = a % divisor
    if leftover != i128(0) and a < i128(0):
        quotient -= i128(1)
    return quotient
def _ceildiv(a: i128, b: int):
    """
    Divide the 128-bit value ``a`` by ``b``, raising the quotient by one
    when ``a`` is positive and the division leaves a remainder.

    NOTE(review): this yields a ceiling only if ``//`` and ``%`` on
    ``Int[128]`` truncate toward zero and ``b`` is positive -- confirm.
    """
    divisor = i128(b)
    quotient = a // divisor
    leftover = a % divisor
    if leftover != i128(0) and a > i128(0):
        quotient += i128(1)
    return quotient
def _to_64(x: i128):
    """
    Narrow a 128-bit value to ``int``.

    Returns ``(value, False)`` when ``x`` lies within the signed 64-bit range
    and ``(0, True)`` when it does not.
    """
    in_range = x <= i128(9223372036854775807) and x >= i128(-9223372036854775808)
    if in_range:
        return int(x), False
    return 0, True
def _diophantine_dfs(n: int,
                     v: int,
                     E: Ptr[DiophantineTerm],
                     Ep: Ptr[DiophantineTerm],
                     Gamma: Ptr[int],
                     Epsilon: Ptr[int],
                     b: int,
                     max_work: int,
                     require_ub_nontrivial: bool,
                     x: Ptr[int],
                     count: Ptr[int]):
    """
    Depth-first search for a solution of the bounded Diophantine problem
    ``sum(E[i].a * x[i]) == b`` with ``0 <= x[i] <= E[i].ub``.

    At level ``v`` the first ``v`` terms are treated as one reduced term
    (``E[0]`` when ``v == 1``, otherwise ``Ep[v - 2]``) and combined with
    ``E[v]``; each admissible value of ``x[v]`` is tried in turn and the
    function recurses with ``v - 1`` until ``v == 1``.

    Parameters:
      n         -- total number of terms
      v         -- index of the variable resolved at this level (>= 1)
      E         -- the problem terms
      Ep, Gamma, Epsilon -- tables filled by ``_diophantine_precompute``
      b         -- right-hand side for this level
      max_work  -- work budget; a negative value disables the limit
      require_ub_nontrivial -- if set, reject the solution in which every
                   ``x[j] == E[j].ub // 2``
      x         -- output buffer receiving the solution
      count     -- one-element work counter, incremented at every dead end

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if max_work >= 0 and count[0] >= max_work:
        return _MEM_OVERLAP_TOO_HARD

    # Fetch the precomputed values of the reduced two-term problem.
    if v == 1:
        a1 = E[0].a
        u1 = E[0].ub
    else:
        a1 = Ep[v - 2].a
        u1 = Ep[v - 2].ub
    a2 = E[v].a
    u2 = E[v].ub
    a_gcd = Ep[v - 1].a
    gamma = Gamma[v - 1]
    epsilon = Epsilon[v - 1]

    # b must be a multiple of the GCD for any solution to exist.
    c = b // a_gcd
    r = b % a_gcd
    if r != 0:
        count[0] += 1
        return _MEM_OVERLAP_NO

    c1 = a2 // a_gcd
    c2 = a1 // a_gcd

    # The candidates are
    #     x1 = gamma * c + c1 * t,    x2 = epsilon * c - c2 * t
    # for integer t, subject to 0 <= x1 <= u1 and 0 <= x2 <= u2.
    x10 = i128(gamma) * i128(c)
    x20 = i128(epsilon) * i128(c)

    t_l1 = _ceildiv(-x10, c1)                 # from x1 >= 0
    v1, o1 = _safe_sub128(x20, i128(u2))
    t_l2 = _ceildiv(v1, c2)                   # from x2 <= u2
    v2, o2 = _safe_sub128(i128(u1), x10)
    t_u1 = _floordiv(v2, c1)                  # from x1 <= u1
    t_u2 = _floordiv(x20, c2)                 # from x2 >= 0
    if o1 or o2:
        return _MEM_OVERLAP_OVERFLOW

    # Intersect the two intervals: lower bound is the larger of the two
    # lower limits, upper bound the smaller of the two upper limits.
    # (The lower bound must be stored back into t_l1; writing it to any
    # other name silently drops the x2 <= u2 constraint and lets the search
    # accept out-of-bounds solutions.)
    if t_l2 > t_l1:
        t_l1 = t_l2
    if t_u1 > t_u2:
        t_u1 = t_u2
    if t_l1 > t_u1:
        count[0] += 1
        return _MEM_OVERLAP_NO

    # Shift the parametrisation so that t runs over 0 .. t_u.
    t_l, o1 = _to_64(t_l1)
    t_u, o2 = _to_64(t_u1)
    x10, o3 = _safe_add128(x10, i128(c1) * i128(t_l))
    x20, o4 = _safe_sub128(x20, i128(c2) * i128(t_l))
    t_u, o5 = _safe_sub(t_u, t_l)
    t_l = 0
    x1, o6 = _to_64(x10)
    x2, o7 = _to_64(x20)
    if o1 or o2 or o3 or o4 or o5 or o6 or o7:
        return _MEM_OVERLAP_OVERFLOW

    if v == 1:
        # Base case: any t in range gives a solution; take the first.
        if t_u >= t_l:
            x[0] = x1 + (c1 * t_l)
            x[1] = x2 - (c2 * t_l)
            if require_ub_nontrivial:
                is_ub_trivial = True
                for j in range(n):
                    if x[j] != E[j].ub // 2:
                        is_ub_trivial = False
                        break
                if is_ub_trivial:
                    # Ignore the 'trivial' solution x[j] == ub[j] / 2.
                    count[0] += 1
                    return _MEM_OVERLAP_NO
            return _MEM_OVERLAP_YES
        count[0] += 1
        return _MEM_OVERLAP_NO
    else:
        # Recurse into every candidate value of x[v].
        t = t_l
        while t <= t_u:
            x[v] = x2 - (c2 * t)
            # b2 = b - a2 * x[v], with overflow detection.
            v1, o1 = _safe_mul(a2, x[v])
            b2, o2 = _safe_sub(b, v1)
            if o1 or o2:
                return _MEM_OVERLAP_OVERFLOW
            res = _diophantine_dfs(n, v - 1, E, Ep, Gamma, Epsilon,
                                   b2, max_work, require_ub_nontrivial,
                                   x, count)
            if res != _MEM_OVERLAP_NO:
                return res
            t += 1
        count[0] += 1
        return _MEM_OVERLAP_NO
def _solve_diophantine(n: int,
                       E: Ptr[DiophantineTerm],
                       b: int,
                       max_work: int,
                       require_ub_nontrivial: bool,
                       x: Ptr[int]):
    """
    Solve ``sum(E[i].a * x[i]) == b`` with ``0 <= x[i] <= E[i].ub``.

    When ``require_ub_nontrivial`` is set, the passed ``b`` is ignored and
    replaced by ``sum(E[i].a * E[i].ub // 2)``; every ``ub`` must then be
    even, and the solution ``x[i] == E[i].ub // 2`` is not accepted.

    Returns a ``_MEM_OVERLAP_*`` code: ``ERROR`` for a non-positive
    coefficient, an odd bound in non-trivial mode, or overflow while
    summing the bounds; ``OVERFLOW`` when precomputation overflows;
    otherwise the outcome of the search. On ``YES`` the solution is in ``x``.
    """
    # Input validation: coefficients must be positive, bounds non-negative.
    for j in range(n):
        if E[j].a <= 0:
            return _MEM_OVERLAP_ERROR
        elif E[j].ub < 0:
            return _MEM_OVERLAP_NO

    if require_ub_nontrivial:
        half_sum = 0
        overflowed = False
        for j in range(n):
            if E[j].ub % 2 != 0:
                return _MEM_OVERLAP_ERROR
            term, mul_ovf = _safe_mul(E[j].a, E[j].ub // 2)
            acc, add_ovf = _safe_add(half_sum, term)
            half_sum = acc
            overflowed = overflowed or mul_ovf or add_ovf
        if overflowed:
            return _MEM_OVERLAP_ERROR
        b = half_sum

    if b < 0:
        return _MEM_OVERLAP_NO

    if n == 0:
        # Only the empty (trivial) solution is possible.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO
        if b == 0:
            return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO
    elif n == 1:
        # A single unknown has at most one (hence trivial) solution.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO
        if b % E[0].a == 0:
            x[0] = b // E[0].a
            if x[0] >= 0 and x[0] <= E[0].ub:
                return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO
    else:
        # General case: precompute the per-level tables, then search.
        count = 0
        Ep = Ptr[DiophantineTerm](n)
        Epsilon = Ptr[int](n)
        Gamma = Ptr[int](n)
        if _diophantine_precompute(n, E, Ep, Gamma, Epsilon):
            status = _MEM_OVERLAP_OVERFLOW
        else:
            status = _diophantine_dfs(n, n - 1, E, Ep, Gamma, Epsilon, b,
                                      max_work, require_ub_nontrivial, x,
                                      __ptr__(count))
        util.free(Ep)
        util.free(Gamma)
        util.free(Epsilon)
        return status
def _diophantine_simplify(n: Ptr[int], E: Ptr[DiophantineTerm], b: int):
    """
    Simplify a bounded Diophantine problem in place.

    Steps: sort the terms by coefficient, merge terms that share a
    coefficient (their bounds are added), clamp every bound to
    ``b // a`` and drop terms whose bound becomes zero. ``n[0]`` is updated
    to the new number of terms.

    Returns ``-1`` if adding bounds overflowed and ``0`` otherwise. Obviously
    infeasible inputs (a negative bound or negative ``b``) are left untouched
    and also return ``0``.
    """
    for j in range(n[0]):
        if E[j].ub < 0:
            return 0
    if b < 0:
        return 0

    _sort(E, n[0])

    # Merge runs of equal coefficients; `dst` indexes the last kept term.
    overflowed = False
    total = n[0]
    dst = 0
    for src in range(1, total):
        if E[dst].a == E[src].a:
            merged, add_ovf = _safe_add(E[dst].ub, E[src].ub)
            overflowed = overflowed or add_ovf
            E[dst] = E[dst].with_ub(merged)
            n[0] -= 1
        else:
            dst += 1
            if dst != src:
                E[dst] = E[src]

    # Clamp bounds and compact away terms that can only be zero.
    total = n[0]
    dst = 0
    for src in range(total):
        E[src] = E[src].with_ub(min(E[src].ub, b // E[src].a))
        if E[src].ub == 0:
            n[0] -= 1
        else:
            if dst != src:
                E[dst] = E[src]
            dst += 1

    if overflowed:
        return -1
    else:
        return 0
def _offset_bounds_from_strides(arr: ndarray):
    """
    Return ``(lower, upper)`` byte offsets, relative to the data pointer,
    spanned by the elements of ``arr``.

    ``lower`` accumulates the reach of negative strides and ``upper`` the
    reach of positive ones plus one item size, so the range is half-open.
    An array with any zero-length axis yields ``(0, 0)``.
    """
    lo = 0
    hi = 0
    dims = arr.shape
    steps = arr.strides
    for axis in range(arr.ndim):
        extent = dims[axis]
        if extent == 0:
            return (0, 0)
        reach = steps[axis] * (extent - 1)
        if reach > 0:
            hi += reach
        else:
            lo += reach
    hi += arr.itemsize
    return (lo, hi)
def _get_array_memory_extents(arr: ndarray):
    """
    Return ``(start, end, num_bytes)`` for ``arr`` as ``u64`` values.

    ``start`` / ``end`` are the data address plus the offsets reported by
    ``_offset_bounds_from_strides``; ``num_bytes`` is the item size
    multiplied by every dimension.
    """
    lo_off, hi_off = _offset_bounds_from_strides(arr)
    base = arr.data.as_byte().__int__()
    first = base + lo_off
    last = base + hi_off
    total = arr.itemsize
    for axis in range(arr.ndim):
        total *= arr.shape[axis]
    return u64(first), u64(last), u64(total)
def _strides_to_terms(arr: ndarray,
                      terms: Ptr[DiophantineTerm],
                      nterms: Ptr[int],
                      skip_empty: bool):
    """
    Append one Diophantine term per axis of ``arr`` to ``terms``.

    Each term gets ``a = abs(stride)`` and ``ub = shape - 1``; ``nterms[0]``
    is advanced for every term written. With ``skip_empty`` set, axes of
    length <= 1 or with a zero stride are skipped.

    Returns ``True`` if a stride is still negative after negation (integer
    overflow), ``False`` otherwise.
    """
    for axis in range(arr.ndim):
        if skip_empty and (arr.shape[axis] <= 1 or arr.strides[axis] == 0):
            continue
        slot = nterms[0]
        coef = arr.strides[axis]
        if coef < 0:
            coef = -coef
        terms[slot] = terms[slot].with_a(coef)
        if coef < 0:
            # Negation did not produce a non-negative value: overflow.
            return True
        terms[slot] = terms[slot].with_ub(arr.shape[axis] - 1)
        nterms[0] = slot + 1
    return False
def _solve_may_share_memory(a: ndarray, b: ndarray, max_work: int):
    """
    Decide whether arrays ``a`` and ``b`` have a byte in common.

    First compares the address ranges of the two arrays; if they are
    disjoint the answer is ``_MEM_OVERLAP_NO``. With ``max_work == 0`` the
    function then stops with ``_MEM_OVERLAP_TOO_HARD``. Otherwise the strides
    and item sizes of both arrays are turned into a bounded Diophantine
    problem that is simplified and solved.

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if a.ndim > MAXDIMS or b.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Scratch space: one term per axis of each array plus one per item size.
    term_storage = (DiophantineTerm(0, 0),) * (2*MAXDIMS + 2)
    x_storage = (0,) * (2*MAXDIMS + 2)
    terms = Ptr[DiophantineTerm](__ptr__(term_storage).as_byte())
    x = Ptr[int](__ptr__(x_storage).as_byte())

    start1, end1, size1 = _get_array_memory_extents(a)
    start2, end2, size2 = _get_array_memory_extents(b)

    # Disjoint or empty address ranges cannot overlap.
    if start1 >= end2 or start2 >= end1 or start1 >= end1 or start2 >= end2:
        return _MEM_OVERLAP_NO

    if max_work == 0:
        # Only the bounds check was requested.
        return _MEM_OVERLAP_TOO_HARD

    rhs_unsigned = min(end2 - u64(1) - start1, end1 - u64(1) - start2)
    if rhs_unsigned > u64(9223372036854775807):
        return _MEM_OVERLAP_OVERFLOW
    rhs = int(rhs_unsigned)

    nterms = 0
    if (_strides_to_terms(a, terms, __ptr__(nterms), True) or
            _strides_to_terms(b, terms, __ptr__(nterms), True)):
        return _MEM_OVERLAP_OVERFLOW

    # Offsets within a single item of each array are unknowns as well.
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1
    if b.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=b.itemsize-1)
        nterms += 1

    if _diophantine_simplify(__ptr__(nterms), terms, rhs):
        return _MEM_OVERLAP_OVERFLOW

    return _solve_diophantine(nterms, terms, rhs, max_work, False, x)
def _solve_may_have_internal_overlap(a: ndarray, max_work: int):
    """
    Decide whether two distinct index tuples of ``a`` address the same byte.

    Contiguous arrays return ``_MEM_OVERLAP_NO`` immediately. Otherwise the
    strides and item size become Diophantine terms, terms with a zero bound
    are dropped, a zero stride on a non-trivial axis yields
    ``_MEM_OVERLAP_YES``, the remaining bounds are doubled, the terms are
    sorted and the problem is solved in "non-trivial solution" mode.

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if a.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    term_storage = (DiophantineTerm(0, 0),) * (MAXDIMS + 1)
    x_storage = (0,) * (MAXDIMS + 1)
    terms = Ptr[DiophantineTerm](__ptr__(term_storage).as_byte())
    x = Ptr[int](__ptr__(x_storage).as_byte())

    c_contig, f_contig = a._contig
    if c_contig or f_contig:
        return _MEM_OVERLAP_NO

    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), False):
        return _MEM_OVERLAP_OVERFLOW
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1

    # Compact the list, discarding terms whose unknown can only be zero.
    kept = 0
    for j in range(nterms):
        if terms[j].ub == 0:
            continue
        elif terms[j].ub < 0:
            return _MEM_OVERLAP_NO
        elif terms[j].a == 0:
            return _MEM_OVERLAP_YES
        if kept != j:
            terms[kept] = terms[j]
        kept += 1
    nterms = kept

    # Double the bounds to obtain the internal-overlap formulation.
    for j in range(nterms):
        terms[j] = terms[j].with_ub(terms[j].ub * 2)

    _sort(terms, nterms)

    return _solve_diophantine(nterms, terms, -1, max_work, True, x)
def _array_shaes_memory_impl(a: ndarray, b: ndarray, max_work: int, raise_exception: bool):
    """
    Shared implementation of ``shares_memory`` and ``may_share_memory``.

    Runs ``_solve_may_share_memory`` and maps its status code to a ``bool``.
    When the solver is inconclusive (overflow or work budget exceeded) the
    result is ``True`` unless ``raise_exception`` is set, in which case
    ``OverflowError`` or ``util.TooHardError`` is raised.

    Raises ``ValueError`` if ``max_work < -2`` and ``RuntimeError`` for any
    other solver status.
    """
    if max_work < -2:
        raise ValueError("Invalid value for max_work")

    status = _solve_may_share_memory(a, b, max_work)

    if status == _MEM_OVERLAP_NO:
        return False
    if status == _MEM_OVERLAP_YES:
        return True
    if status == _MEM_OVERLAP_OVERFLOW:
        if raise_exception:
            raise OverflowError("Integer overflow in computing overlap")
        return True
    if status == _MEM_OVERLAP_TOO_HARD:
        if raise_exception:
            raise util.TooHardError("Exceeded max_work")
        return True
    raise RuntimeError("Error in computing overlap")
def shares_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """
    Return whether arrays ``a`` and ``b`` share memory.

    ``max_work`` limits the effort spent on the exact check and defaults to
    ``MAY_SHARE_EXACT``. If the question cannot be settled, an exception is
    raised (``OverflowError`` or ``util.TooHardError``).
    """
    mw: int = MAY_SHARE_EXACT
    if max_work is not None:
        mw = max_work
    return _array_shaes_memory_impl(a, b, mw, True)
def may_share_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """
    Return whether arrays ``a`` and ``b`` might share memory.

    ``max_work`` limits the effort spent and defaults to
    ``MAY_SHARE_BOUNDS`` (address-range comparison only). If the question
    cannot be settled the answer is ``True``; no exception is raised for
    inconclusive results.
    """
    mw: int = MAY_SHARE_BOUNDS
    if max_work is not None:
        mw = max_work
    return _array_shaes_memory_impl(a, b, mw, False)
def setbufsize(size: int):
    """Accept a ufunc buffer size and ignore it."""
    # Codon-NumPy does not use ufunc buffers
    return None
def getbufsize():
    """Return the ufunc buffer size, which is always 0 here."""
    # Codon-NumPy does not use ufunc buffers
    return 0
def byte_bounds(a: ndarray):
    """
    Return ``(low, high)`` addresses bounding the memory used by ``a``.

    ``low`` is the address of the first byte and ``high`` is one past the
    last byte. For contiguous arrays ``high`` is simply ``low`` plus the total
    byte count; otherwise each axis extends ``low`` (negative stride) or
    ``high`` (non-negative stride) by ``(shape - 1) * stride``.
    """
    base = a.data.as_byte().__int__()
    low = base
    high = base
    item_bytes = a.itemsize
    dims = a.shape
    steps = a.strides

    c_contig, f_contig = a._contig
    if c_contig or f_contig:
        high += a.size * item_bytes
    else:
        for i in staticrange(staticlen(dims)):
            reach = (dims[i] - 1) * steps[i]
            if steps[i] < 0:
                low += reach
            else:
                high += reach
        high += item_bytes
    return (low, high)