mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
608 lines
16 KiB
Python
608 lines
16 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from .ndarray import ndarray
|
|
import util
|
|
|
|
# Largest array rank accepted by the overlap solvers in this module; it also
# sizes their fixed scratch buffers.
MAXDIMS : Static[int] = 32

# `max_work` presets used as defaults by `may_share_memory` and
# `shares_memory` respectively. A `max_work` of 0 stops after the memory-bounds
# check; a negative value disables the work limit of the exact search.
MAY_SHARE_BOUNDS : Static[int] = 0
MAY_SHARE_EXACT : Static[int] = -1

# Result codes returned by the overlap solvers.
_MEM_OVERLAP_NO : Static[int] = 0 # no solution exists
_MEM_OVERLAP_YES : Static[int] = 1 # solution found
_MEM_OVERLAP_TOO_HARD: Static[int] = -1 # max_work exceeded
_MEM_OVERLAP_OVERFLOW: Static[int] = -2 # algorithm failed due to integer overflow
_MEM_OVERLAP_ERROR : Static[int] = -3 # invalid input

# 128-bit signed integer type used for intermediate products in the solver.
i128 = Int[128]
|
|
|
|
def _euclid(a1: int, a2: int):
    """
    Extended Euclidean algorithm for two positive integers.

    Returns a tuple ``(g, gamma, epsilon)`` where ``g`` is the greatest
    common divisor of the inputs and ``gamma * a1 + epsilon * a2 == g``
    (with `a1`, `a2` taken as the values passed in).
    """
    # (coef1, cof1) are the Bezout coefficients of the running `a1`,
    # (coef2, cof2) those of the running `a2`.
    coef1, coef2 = 1, 0
    cof1, cof2 = 0, 1

    while True:
        # Reduce a1 modulo a2; once a2 has reached zero, a1 holds the gcd.
        if a2 <= 0:
            return a1, coef1, cof1
        q = a1 // a2
        a1 -= q * a2
        coef1 -= q * coef2
        cof1 -= q * cof2

        # Reduce a2 modulo a1; once a1 has reached zero, a2 holds the gcd.
        if a1 <= 0:
            return a2, coef2, cof2
        q = a2 // a1
        a2 -= q * a1
        coef2 -= q * coef1
        cof2 -= q * cof1
|
|
|
|
# One term `a * x` of a bounded linear Diophantine equation
#     sum(a[i] * x[i]) == b,   0 <= x[i] <= ub[i]
# Declared as a tuple class, so "modifying" a term means building a new one
# (see `with_a` / `with_ub`).
@tuple
class DiophantineTerm:
    a: int  # coefficient of the unknown
    ub: int  # inclusive upper bound of the unknown

    def __new__(a: int, ub: int) -> DiophantineTerm:
        return (a, ub)

    # Terms are ordered by coefficient only; the bound is ignored.
    def __lt__(self, other: DiophantineTerm):
        return self.a < other.a

    # Copy of this term with the coefficient replaced by `a`.
    def with_a(self, a: int):
        return DiophantineTerm(a, self.ub)

    # Copy of this term with the upper bound replaced by `ub`.
    def with_ub(self, ub: int):
        return DiophantineTerm(self.a, ub)
|
|
|
|
def _sort(p: Ptr[DiophantineTerm], n: int):
    """
    Sort the `n` terms stored at `p`.

    The buffer is wrapped in a 1-D ndarray of length `n` and that array's
    `sort` method is called; callers rely on the terms at `p` being ordered
    afterwards (presumably via `DiophantineTerm.__lt__`, i.e. by coefficient).
    """
    ndarray((n,), p).sort()
|
|
|
|
# Signed 64-bit addition through the `llvm.sadd.with.overflow.i64` intrinsic.
# Returns the intrinsic's {result, overflow-bit} pair unchanged; use
# `_safe_add` for a version with a `bool` flag.
@pure
@llvm
def _safe_add_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.sadd.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.sadd.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c
|
|
|
|
def _safe_add(a: int, b: int):
    """
    Overflow-checked 64-bit addition.

    Returns ``(a + b, overflowed)`` where `overflowed` is a `bool` derived
    from the overflow bit reported by `_safe_add_llvm`.
    """
    pair = _safe_add_llvm(a, b)
    return (pair[0], bool(pair[1]))
|
|
|
|
# Signed 64-bit subtraction through the `llvm.ssub.with.overflow.i64`
# intrinsic. Returns the intrinsic's {result, overflow-bit} pair unchanged;
# use `_safe_sub` for a version with a `bool` flag.
@pure
@llvm
def _safe_sub_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.ssub.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.ssub.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c
|
|
|
|
def _safe_sub(a: int, b: int):
    """
    Overflow-checked 64-bit subtraction.

    Returns ``(a - b, overflowed)`` where `overflowed` is a `bool` derived
    from the overflow bit reported by `_safe_sub_llvm`.
    """
    pair = _safe_sub_llvm(a, b)
    return (pair[0], bool(pair[1]))
|
|
|
|
# Signed 64-bit multiplication through the `llvm.smul.with.overflow.i64`
# intrinsic. Returns the intrinsic's {result, overflow-bit} pair unchanged;
# use `_safe_mul` for a version with a `bool` flag.
@pure
@llvm
def _safe_mul_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.smul.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.smul.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c
|
|
|
|
def _safe_mul(a: int, b: int):
    """
    Overflow-checked 64-bit multiplication.

    Returns ``(a * b, overflowed)`` where `overflowed` is a `bool` derived
    from the overflow bit reported by `_safe_mul_llvm`.
    """
    pair = _safe_mul_llvm(a, b)
    return (pair[0], bool(pair[1]))
|
|
|
|
# Signed 128-bit addition through the `llvm.sadd.with.overflow.i128`
# intrinsic. Returns the intrinsic's {result, overflow-bit} pair unchanged;
# use `_safe_add128` for a version with a `bool` flag.
@pure
@llvm
def _safe_add128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.sadd.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.sadd.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c
|
|
|
|
def _safe_add128(a: i128, b: i128):
    """
    Overflow-checked 128-bit addition.

    Returns ``(a + b, overflowed)`` where `overflowed` is a `bool` derived
    from the overflow bit reported by `_safe_add128_llvm`.
    """
    pair = _safe_add128_llvm(a, b)
    return (pair[0], bool(pair[1]))
|
|
|
|
# Signed 128-bit subtraction through the `llvm.ssub.with.overflow.i128`
# intrinsic. Returns the intrinsic's {result, overflow-bit} pair unchanged;
# use `_safe_sub128` for a version with a `bool` flag.
@pure
@llvm
def _safe_sub128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.ssub.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.ssub.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c
|
|
|
|
def _safe_sub128(a: i128, b: i128):
    """
    Overflow-checked 128-bit subtraction.

    Returns ``(a - b, overflowed)`` where `overflowed` is a `bool` derived
    from the overflow bit reported by `_safe_sub128_llvm`.
    """
    pair = _safe_sub128_llvm(a, b)
    return (pair[0], bool(pair[1]))
|
|
|
|
def _diophantine_precompute(n: int,
                            E: Ptr[DiophantineTerm],
                            Ep: Ptr[DiophantineTerm],
                            Gamma: Ptr[int],
                            Epsilon: Ptr[int]):
    """
    Precompute the chain of reduced two-variable problems used by
    `_diophantine_dfs`.

    For ``k = 0 .. n-2`` this fills:

    - ``Ep[k].a``   : gcd of the coefficients ``E[0].a .. E[k+1].a``
    - ``Gamma[k]``, ``Epsilon[k]`` : the Bezout coefficients returned by
      `_euclid` for that gcd step
    - ``Ep[k].ub``  : upper bound of the combined variable (set for every
      ``k`` except the last one)

    Reads ``E[0]`` and ``E[1]`` unconditionally, so `n` is expected to be at
    least 2. Returns True if a bound computation overflowed, else False.
    """
    g, bez_g, bez_e = _euclid(E[0].a, E[1].a)
    Ep[0] = Ep[0].with_a(g)
    Gamma[0] = bez_g
    Epsilon[0] = bez_e

    if n > 2:
        # Ep[0].ub = E[0].ub * (E[0].a / g) + E[1].ub * (E[1].a / g)
        left, of_left = _safe_mul(E[0].ub, E[0].a // g)
        right, of_right = _safe_mul(E[1].ub, E[1].a // g)
        bound, of_sum = _safe_add(left, right)
        Ep[0] = Ep[0].with_ub(bound)
        if of_left or of_right or of_sum:
            return True

    for j in range(2, n):
        prev = Ep[j - 2]
        cur = E[j]

        g, bez_g, bez_e = _euclid(prev.a, cur.a)
        Ep[j - 1] = Ep[j - 1].with_a(g)
        Gamma[j - 1] = bez_g
        Epsilon[j - 1] = bez_e

        # The bound of the last combined variable is never needed.
        if j < n - 1:
            # Ep[j-1].ub = (prev.a / g) * prev.ub + (cur.a / g) * cur.ub
            left, of_left = _safe_mul(prev.a // g, prev.ub)
            right, of_right = _safe_mul(cur.a // g, cur.ub)
            bound, of_sum = _safe_add(left, right)
            Ep[j - 1] = Ep[j - 1].with_ub(bound)
            if of_left or of_right or of_sum:
                return True

    return False
|
|
|
|
def _floordiv(a: i128, b: int):
    """
    Divide the 128-bit value `a` by `b`, rounding toward negative infinity.

    The quotient from `//` is decremented when `a` is negative and the
    remainder is non-zero.
    NOTE(review): this presumes `//` and `%` on `i128` truncate toward zero
    and that `b` is positive -- confirm against the Int[N] implementation.
    """
    divisor = i128(b)
    quotient = a // divisor
    if a < i128(0) and a % divisor != i128(0):
        quotient -= i128(1)
    return quotient
|
|
|
|
def _ceildiv(a: i128, b: int):
    """
    Divide the 128-bit value `a` by `b`, rounding toward positive infinity.

    The quotient from `//` is incremented when `a` is positive and the
    remainder is non-zero.
    NOTE(review): this presumes `//` and `%` on `i128` truncate toward zero
    and that `b` is positive -- confirm against the Int[N] implementation.
    """
    divisor = i128(b)
    quotient = a // divisor
    if a > i128(0) and a % divisor != i128(0):
        quotient += i128(1)
    return quotient
|
|
|
|
def _to_64(x: i128):
    """
    Narrow a 128-bit integer to a 64-bit `int`.

    Returns ``(value, overflowed)``: ``(int(x), False)`` when `x` lies in the
    signed 64-bit range, otherwise ``(0, True)``.
    """
    fits = x >= i128(-9223372036854775808) and x <= i128(9223372036854775807)
    if fits:
        return int(x), False
    return 0, True
|
|
|
|
def _diophantine_dfs(n: int,
                     v: int,
                     E: Ptr[DiophantineTerm],
                     Ep: Ptr[DiophantineTerm],
                     Gamma: Ptr[int],
                     Epsilon: Ptr[int],
                     b: int,
                     max_work: int,
                     require_ub_nontrivial: bool,
                     x: Ptr[int],
                     count: Ptr[int]):
    """
    Depth-first search for a solution of the bounded Diophantine problem.

    At level `v` the terms ``E[0..v-1]`` are treated as one combined variable
    (described by ``Ep[v-2]``, or ``E[0]`` when ``v == 1``) paired with
    ``E[v]``; every admissible value of ``x[v]`` is tried and the remaining
    right-hand side is passed down to level ``v - 1``.

    Parameters:
      n        - total number of terms
      v        - index of the term handled at this level (``1 <= v < n``)
      E        - the terms of the equation
      Ep, Gamma, Epsilon - tables filled by `_diophantine_precompute`
      b        - right-hand side still to be matched by ``E[0..v]``
      max_work - work limit; a negative value means unlimited
      require_ub_nontrivial - reject the solution ``x[j] == E[j].ub // 2``
                 for all ``j``
      x        - output solution vector (written as the search proceeds)
      count    - pointer to the work counter shared by all levels

    Returns `_MEM_OVERLAP_YES`, `_MEM_OVERLAP_NO`, `_MEM_OVERLAP_TOO_HARD`
    (work limit reached) or `_MEM_OVERLAP_OVERFLOW`.
    """
    if max_work >= 0 and count[0] >= max_work:
        return _MEM_OVERLAP_TOO_HARD

    # Fetch the precomputed values of the reduced two-variable problem.
    if v == 1:
        a1 = E[0].a
        u1 = E[0].ub
    else:
        a1 = Ep[v - 2].a
        u1 = Ep[v - 2].ub

    a2 = E[v].a
    u2 = E[v].ub

    a_gcd = Ep[v - 1].a
    gamma = Gamma[v - 1]
    epsilon = Epsilon[v - 1]

    # The right-hand side must be a multiple of the gcd to be reachable.
    c = b // a_gcd
    r = b % a_gcd
    if r != 0:
        count[0] += 1
        return _MEM_OVERLAP_NO

    c1 = a2 // a_gcd
    c2 = a1 // a_gcd

    # The candidates are parametrised by an integer t:
    #     x1 = gamma*c + c1*t,    0 <= x1 <= u1
    #     x2 = epsilon*c - c2*t,  0 <= x2 <= u2
    # The products are formed in 128 bits because they may not fit in 64.
    x10 = i128(gamma) * i128(c)
    x20 = i128(epsilon) * i128(c)

    t_l1 = _ceildiv(-x10, c1)            # from x1 >= 0
    v1, o1 = _safe_sub128(x20, i128(u2))
    t_l2 = _ceildiv(v1, c2)              # from x2 <= u2

    v2, o2 = _safe_sub128(i128(u1), x10)
    t_u1 = _floordiv(v2, c1)             # from x1 <= u1
    t_u2 = _floordiv(x20, c2)            # from x2 >= 0

    if o1 or o2:
        return _MEM_OVERLAP_OVERFLOW

    # Intersect the two intervals: [t_l1, t_u1] becomes the admissible range.
    if t_l2 > t_l1:
        t_l1 = t_l2

    if t_u1 > t_u2:
        t_u1 = t_u2

    if t_l1 > t_u1:
        count[0] += 1
        return _MEM_OVERLAP_NO

    t_l, o1 = _to_64(t_l1)
    t_u, o2 = _to_64(t_u1)

    # Shift the parametrisation so that the admissible range starts at t = 0.
    x10, o3 = _safe_add128(x10, i128(c1) * i128(t_l))
    x20, o4 = _safe_sub128(x20, i128(c2) * i128(t_l))

    t_u, o5 = _safe_sub(t_u, t_l)
    t_l = 0
    x1, o6 = _to_64(x10)
    x2, o7 = _to_64(x20)

    if o1 or o2 or o3 or o4 or o5 or o6 or o7:
        return _MEM_OVERLAP_OVERFLOW

    if v == 1:
        # Base case: any t in range gives a solution; take the first one.
        if t_u >= t_l:
            x[0] = x1 + (c1 * t_l)
            x[1] = x2 - (c2 * t_l)
            if require_ub_nontrivial:
                # The solution is "trivial" when every x[j] equals ub[j] / 2.
                is_ub_trivial = True
                for j in range(n):
                    if x[j] != E[j].ub // 2:
                        is_ub_trivial = False
                        break

                if is_ub_trivial:
                    # Ignore the trivial solution.
                    count[0] += 1
                    return _MEM_OVERLAP_NO

            return _MEM_OVERLAP_YES

        count[0] += 1
        return _MEM_OVERLAP_NO
    else:
        # Recurse into every candidate value of x[v].
        t = t_l
        while t <= t_u:
            x[v] = x2 - (c2 * t)

            # b2 = b - a2 * x[v]
            v1, o1 = _safe_mul(a2, x[v])
            b2, o2 = _safe_sub(b, v1)

            if o1 or o2:
                return _MEM_OVERLAP_OVERFLOW

            res = _diophantine_dfs(n, v - 1, E, Ep, Gamma, Epsilon,
                                   b2, max_work, require_ub_nontrivial,
                                   x, count)
            if res != _MEM_OVERLAP_NO:
                return res
            t += 1

        count[0] += 1
        return _MEM_OVERLAP_NO
|
|
|
|
def _solve_diophantine(n: int,
                       E: Ptr[DiophantineTerm],
                       b: int,
                       max_work: int,
                       require_ub_nontrivial: bool,
                       x: Ptr[int]):
    """
    Solve ``sum(E[j].a * x[j]) == b`` subject to ``0 <= x[j] <= E[j].ub``.

    All coefficients must be positive. When `require_ub_nontrivial` is set,
    every bound must be even, `b` is replaced by ``sum(a * ub / 2)`` and the
    solution ``x[j] == ub[j] / 2`` is not accepted.

    Returns `_MEM_OVERLAP_YES` or `_MEM_OVERLAP_NO`, `_MEM_OVERLAP_ERROR` for
    invalid input, or `_MEM_OVERLAP_TOO_HARD` / `_MEM_OVERLAP_OVERFLOW` from
    the search. The one-term and multi-term paths write into `x`.
    """
    # Validate the terms.
    for j in range(n):
        term = E[j]
        if term.a <= 0:
            return _MEM_OVERLAP_ERROR
        if term.ub < 0:
            return _MEM_OVERLAP_NO

    if require_ub_nontrivial:
        # Derive the right-hand side from the (even) bounds.
        half_sum = 0
        overflowed = False
        for j in range(n):
            term = E[j]
            if term.ub % 2 != 0:
                return _MEM_OVERLAP_ERROR

            prod, of_mul = _safe_mul(term.a, term.ub // 2)
            half_sum, of_add = _safe_add(half_sum, prod)
            overflowed = overflowed or of_mul or of_add

        if overflowed:
            return _MEM_OVERLAP_ERROR
        b = half_sum

    if b < 0:
        return _MEM_OVERLAP_NO

    if n == 0:
        # Without variables only the trivial solution can exist.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO
        if b == 0:
            return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO

    if n == 1:
        # With one variable only the trivial solution can exist.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO

        if b % E[0].a == 0:
            x[0] = b // E[0].a
            if x[0] >= 0 and x[0] <= E[0].ub:
                return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO

    # General case: precompute the reduced problems, then search.
    work_done = 0
    Ep = Ptr[DiophantineTerm](n)
    Epsilon = Ptr[int](n)
    Gamma = Ptr[int](n)

    if _diophantine_precompute(n, E, Ep, Gamma, Epsilon):
        res = _MEM_OVERLAP_OVERFLOW
    else:
        res = _diophantine_dfs(n, n - 1, E, Ep, Gamma, Epsilon, b, max_work,
                               require_ub_nontrivial, x, __ptr__(work_done))

    util.free(Ep)
    util.free(Gamma)
    util.free(Epsilon)
    return res
|
|
|
|
def _diophantine_simplify(n: Ptr[int], E: Ptr[DiophantineTerm], b: int):
    """
    Simplify the Diophantine problem in place.

    Sorts the terms, merges terms with equal coefficients (adding their
    bounds), clamps each bound to ``b // a`` and drops terms whose bound
    becomes zero. ``n[0]`` is updated to the new number of terms.

    Problems with a negative bound or negative `b` are left untouched.
    Returns -1 if adding two bounds overflowed, otherwise 0.
    """
    # Obviously infeasible problems are not worth simplifying.
    for k in range(n[0]):
        if E[k].ub < 0:
            return 0
    if b < 0:
        return 0

    _sort(E, n[0])

    overflowed = False

    # Merge runs of identical coefficients; `dst` is the last kept term.
    total = n[0]
    dst = 0
    for src in range(1, total):
        if E[dst].a == E[src].a:
            merged, of_add = _safe_add(E[dst].ub, E[src].ub)
            overflowed = overflowed or of_add
            E[dst] = E[dst].with_ub(merged)
            n[0] -= 1
        else:
            dst += 1
            if dst != src:
                E[dst] = E[src]

    # Trim the bounds and compact away variables that must be zero.
    total = n[0]
    dst = 0
    for src in range(total):
        trimmed = E[src].with_ub(min(E[src].ub, b // E[src].a))
        E[src] = trimmed
        if trimmed.ub == 0:
            n[0] -= 1
        else:
            if dst != src:
                E[dst] = trimmed
            dst += 1

    if overflowed:
        return -1
    return 0
|
|
|
|
def _offset_bounds_from_strides(arr: ndarray):
    """
    Compute the byte-offset range covered by `arr` relative to its data
    pointer.

    Returns ``(lower, upper)`` as a half-open range: `lower` accumulates the
    non-positive per-axis extents, `upper` the positive ones plus one
    itemsize. An array with a zero-length axis yields ``(0, 0)``.
    """
    dims = arr.shape
    steps = arr.strides
    lo = 0
    hi = 0

    for axis in range(arr.ndim):
        extent = dims[axis]
        if extent == 0:
            # Empty array: report an empty range.
            return (0, 0)

        # Offset of the last element along this axis; it extends the range
        # upwards or downwards depending on the sign of the stride.
        reach = steps[axis] * (extent - 1)
        if reach > 0:
            hi += reach
        else:
            lo += reach

    return (lo, hi + arr.itemsize)
|
|
|
|
def _get_array_memory_extents(arr: ndarray):
    """
    Return ``(start, end, num_bytes)`` for `arr` as unsigned 64-bit values.

    `start` and `end` are the data address plus the lower/upper offsets from
    `_offset_bounds_from_strides`; `num_bytes` is the itemsize multiplied by
    every dimension.
    """
    lo_off, hi_off = _offset_bounds_from_strides(arr)
    base = arr.data.as_byte().__int__()

    nbytes = arr.itemsize
    for axis in range(arr.ndim):
        nbytes *= arr.shape[axis]

    return u64(base + lo_off), u64(base + hi_off), u64(nbytes)
|
|
|
|
def _strides_to_terms(arr: ndarray,
                      terms: Ptr[DiophantineTerm],
                      nterms: Ptr[int],
                      skip_empty: bool):
    """
    Append one Diophantine term per axis of `arr` to `terms`.

    Each term gets the absolute value of the axis stride as coefficient and
    ``shape - 1`` as upper bound; ``nterms[0]`` is advanced for every term
    written. With `skip_empty`, axes of length <= 1 or with stride 0 are
    left out.

    Returns True if a stride is still negative after negation (integer
    overflow), otherwise False.
    """
    for axis in range(arr.ndim):
        if skip_empty and (arr.shape[axis] <= 1 or arr.strides[axis] == 0):
            continue

        slot = nterms[0]

        coeff = arr.strides[axis]
        if coeff < 0:
            coeff = -coeff
        terms[slot] = terms[slot].with_a(coeff)

        # Still negative after negation: the stride has no positive
        # counterpart in 64 bits.
        if coeff < 0:
            return True

        terms[slot] = DiophantineTerm(coeff, arr.shape[axis] - 1)
        nterms[0] = slot + 1

    return False
|
|
|
|
def _solve_may_share_memory(a: ndarray, b: ndarray, max_work: int):
    """
    Decide whether arrays `a` and `b` have any byte in common.

    First compares the memory extents of the two arrays; if they are
    disjoint the answer is `_MEM_OVERLAP_NO`. With ``max_work == 0`` the
    function stops there and returns `_MEM_OVERLAP_TOO_HARD`. Otherwise the
    question is turned into a bounded Diophantine problem over the strides
    of both arrays and handed to `_solve_diophantine`.

    Returns one of the `_MEM_OVERLAP_*` codes.
    """
    if a.ndim > MAXDIMS or b.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Fixed-size local scratch buffers for the terms and the solution vector,
    # accessed through raw pointers: room for one term per axis of each
    # array plus one itemsize term per array.
    terms_tuple = (DiophantineTerm(0, 0),) * (2*MAXDIMS + 2)
    x_tuple = (0,) * (2*MAXDIMS + 2)
    terms = Ptr[DiophantineTerm](__ptr__(terms_tuple).as_byte())
    x = Ptr[int](__ptr__(x_tuple).as_byte())

    start1, end1, size1 = _get_array_memory_extents(a)
    start2, end2, size2 = _get_array_memory_extents(b)

    # Disjoint (or empty) address ranges cannot overlap.
    if not (start1 < end2 and start2 < end1 and start1 < end1 and start2 < end2):
        return _MEM_OVERLAP_NO

    # Bounds-only mode: the exact check is not attempted.
    if max_work == 0:
        return _MEM_OVERLAP_TOO_HARD

    # Right-hand side of the equation: the smaller of the two distances
    # between one array's start and the other array's last byte. Both are
    # non-negative because of the extent check above.
    uintp_rhs = min(end2 - u64(1) - start1, end1 - u64(1) - start2)
    if uintp_rhs > u64(9223372036854775807):
        # Does not fit in a signed 64-bit integer.
        return _MEM_OVERLAP_OVERFLOW
    rhs = int(uintp_rhs)

    # Collect the terms: |stride| per non-trivial axis of both arrays ...
    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), True):
        return _MEM_OVERLAP_OVERFLOW
    if _strides_to_terms(b, terms, __ptr__(nterms), True):
        return _MEM_OVERLAP_OVERFLOW
    # ... plus a unit-coefficient term for the bytes inside one element.
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1
    if b.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=b.itemsize-1)
        nterms += 1

    # A non-zero result from the simplification signals integer overflow.
    if _diophantine_simplify(__ptr__(nterms), terms, rhs):
        return _MEM_OVERLAP_OVERFLOW

    return _solve_diophantine(nterms, terms, rhs, max_work, False, x)
|
|
|
|
def _solve_may_have_internal_overlap(a: ndarray, max_work: int):
    """
    Decide whether two different index tuples of `a` address the same byte.

    Contiguous arrays are answered immediately with `_MEM_OVERLAP_NO`.
    Otherwise the strides are converted to Diophantine terms with doubled
    bounds and `_solve_diophantine` is asked for a non-trivial solution.

    Returns one of the `_MEM_OVERLAP_*` codes.
    """
    if a.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Fixed-size local scratch buffers accessed through raw pointers: one
    # term per axis plus one itemsize term.
    terms_tuple = (DiophantineTerm(0, 0),) * (MAXDIMS + 1)
    x_tuple = (0,) * (MAXDIMS + 1)
    terms = Ptr[DiophantineTerm](__ptr__(terms_tuple).as_byte())
    x = Ptr[int](__ptr__(x_tuple).as_byte())

    # Quick case: C- or F-contiguous arrays are reported as non-overlapping.
    cc, fc = a._contig
    if cc or fc:
        return _MEM_OVERLAP_NO


    # Collect |stride| terms for every axis (empty axes included) and a
    # unit-coefficient term for the bytes inside one element.
    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), False):
        return _MEM_OVERLAP_OVERFLOW
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1

    # Drop terms with a zero bound and settle the degenerate cases:
    # a negative bound means an empty axis (no overlap), a zero stride with
    # a positive bound means the same bytes are visited repeatedly (overlap).
    i = 0
    for j in range(nterms):
        if terms[j].ub == 0:
            continue
        elif terms[j].ub < 0:
            return _MEM_OVERLAP_NO
        elif terms[j].a == 0:
            return _MEM_OVERLAP_YES

        if i != j:
            terms[i] = terms[j]
        i += 1
    nterms = i

    # Double the bounds to obtain the internal-overlap formulation; the
    # solver then looks for a solution other than x[j] == original bound.
    for j in range(nterms):
        terms[j] = terms[j].with_ub(terms[j].ub * 2)

    # Only sort here; `_diophantine_simplify` is not used on this problem.
    _sort(terms, nterms)

    # The right-hand side argument is recomputed by the solver when a
    # non-trivial solution is required, so -1 is only a placeholder.
    return _solve_diophantine(nterms, terms, -1, max_work, True, x)
|
|
|
|
def _array_shaes_memory_impl(a: ndarray, b: ndarray, max_work: int, raise_exception: bool):
    """
    Shared implementation of `shares_memory` and `may_share_memory`.

    Translates the result code of `_solve_may_share_memory` into a bool.
    When the solver overflows or exceeds `max_work`, the answer is True
    unless `raise_exception` is set, in which case `OverflowError` or
    `util.TooHardError` is raised.

    Raises `ValueError` if `max_work` is below -2 and `RuntimeError` for any
    other solver result.
    """
    if max_work < -2:
        raise ValueError("Invalid value for max_work")

    status = _solve_may_share_memory(a, b, max_work)

    if status == _MEM_OVERLAP_NO:
        return False

    if status == _MEM_OVERLAP_YES:
        return True

    if status == _MEM_OVERLAP_OVERFLOW:
        if raise_exception:
            raise OverflowError("Integer overflow in computing overlap")
        # Undecided, so answer conservatively.
        return True

    if status == _MEM_OVERLAP_TOO_HARD:
        if raise_exception:
            raise util.TooHardError("Exceeded max_work")
        # Undecided, so answer conservatively.
        return True

    raise RuntimeError("Error in computing overlap")
|
|
|
|
def shares_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """
    Determine whether arrays `a` and `b` share memory.

    `max_work` limits the effort spent on the exact check and defaults to
    `MAY_SHARE_EXACT`. Overflow or an exceeded work limit is reported by an
    exception rather than by a conservative answer.
    """
    if max_work is not None:
        return _array_shaes_memory_impl(a, b, max_work, True)
    return _array_shaes_memory_impl(a, b, MAY_SHARE_EXACT, True)
|
|
|
|
def may_share_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """
    Determine whether arrays `a` and `b` might share memory.

    `max_work` limits the effort spent on the check and defaults to
    `MAY_SHARE_BOUNDS`. A True result does not guarantee overlap: when the
    check cannot be decided, True is returned instead of raising.
    """
    if max_work is not None:
        return _array_shaes_memory_impl(a, b, max_work, False)
    return _array_shaes_memory_impl(a, b, MAY_SHARE_BOUNDS, False)
|
|
|
|
def setbufsize(size: int):
    """
    Accept a ufunc buffer size and ignore it.

    Provided for API compatibility only; `size` is not stored and nothing is
    returned.
    """
    pass # Codon-NumPy does not use ufunc buffers
|
|
|
|
def getbufsize():
    """
    Return the ufunc buffer size, which is always 0 here.

    Provided for API compatibility only.
    """
    return 0 # Codon-NumPy does not use ufunc buffers
|
|
|
|
def byte_bounds(a: ndarray):
    """
    Return ``(low, high)`` addresses bounding the memory of array `a`.

    `low` is the address of the first byte and `high` the address just past
    the last byte. For contiguous arrays the range is the data pointer plus
    ``a.size * a.itemsize``; otherwise it is built up from the per-axis
    strides.
    """
    base = a.data.as_byte().__int__()
    elem_bytes = a.itemsize
    c_contig, f_contig = a._contig

    if c_contig or f_contig:
        return (base, base + a.size * elem_bytes)

    dims = a.shape
    steps = a.strides
    low = base
    high = base
    for i in staticrange(staticlen(dims)):
        # Offset of the last element along this axis; negative strides
        # extend the range downwards, all others upwards.
        reach = (dims[i] - 1) * steps[i]
        if steps[i] < 0:
            low += reach
        else:
            high += reach

    return (low, high + elem_bytes)
|