# Copyright (C) 2022-2025 Exaloop Inc.

from .ndarray import ndarray
import util

# Maximum number of array dimensions accepted by the overlap solvers; also
# sizes the fixed-capacity term/solution buffers used below.
MAXDIMS : Static[int] = 32

# Special `max_work` values: 0 only compares the arrays' memory extents
# (any deeper analysis reports "too hard"), -1 removes the work limit.
MAY_SHARE_BOUNDS : Static[int] = 0
MAY_SHARE_EXACT : Static[int] = -1

# Result codes of the overlap solver.
_MEM_OVERLAP_NO : Static[int] = 0  # no solution exists
_MEM_OVERLAP_YES : Static[int] = 1  # solution found
_MEM_OVERLAP_TOO_HARD: Static[int] = -1  # max_work exceeded
_MEM_OVERLAP_OVERFLOW: Static[int] = -2  # algorithm failed due to integer overflow
_MEM_OVERLAP_ERROR : Static[int] = -3  # invalid input

i128 = Int[128]


def _euclid(a1: int, a2: int):
    """Extended Euclidean algorithm.

    Returns ``(g, gamma, epsilon)`` where ``g = gcd(a1, a2)`` and
    ``gamma * a1 + epsilon * a2 == g``. The loop maintains
    ``a1 == gamma1*A1 + epsilon1*A2`` and ``a2 == gamma2*A1 + epsilon2*A2``
    for the original inputs ``A1``, ``A2``. Callers pass positive values
    (coefficients are validated in ``_solve_diophantine``).
    """
    gamma1 = 1
    gamma2 = 0
    epsilon1 = 0
    epsilon2 = 1
    a_gcd = 0
    gamma = 0
    epsilon = 0

    while True:
        if a2 > 0:
            r = a1 // a2
            a1 -= r * a2
            gamma1 -= r * gamma2
            epsilon1 -= r * epsilon2
        else:
            a_gcd = a1
            gamma = gamma1
            epsilon = epsilon1
            break

        if a1 > 0:
            r = a2 // a1
            a2 -= r * a1
            gamma2 -= r * gamma1
            epsilon2 -= r * epsilon1
        else:
            a_gcd = a2
            gamma = gamma2
            epsilon = epsilon2
            break

    return a_gcd, gamma, epsilon


# One term `a * x` of a bounded linear Diophantine equation; the variable x
# is constrained to 0 <= x <= ub.
@tuple
class DiophantineTerm:
    a: int   # coefficient
    ub: int  # inclusive upper bound of the term's variable

    def __new__(a: int, ub: int) -> DiophantineTerm:
        return (a, ub)

    # Ordering by coefficient only (used by `_sort`).
    def __lt__(self, other: DiophantineTerm):
        return self.a < other.a

    def with_a(self, a: int):
        return DiophantineTerm(a, self.ub)

    def with_ub(self, ub: int):
        return DiophantineTerm(self.a, ub)


def _sort(p: Ptr[DiophantineTerm], n: int):
    """Sort the ``n`` terms at ``p`` in place by decreasing coefficient ``a``.

    NumPy's solver orders terms by decreasing coefficient; the order does not
    affect correctness but does affect how much work (compared against
    ``max_work``) the search performs. ``DiophantineTerm.__lt__`` orders by
    increasing ``a``, so sort ascending and then reverse in place.
    """
    arr = ndarray((n,), p)
    arr.sort()
    i = 0
    j = n - 1
    while i < j:
        tmp = p[i]
        p[i] = p[j]
        p[j] = tmp
        i += 1
        j -= 1


# Checked 64-bit arithmetic via LLVM's *.with.overflow intrinsics; each
# wrapper returns (result, overflowed).
@pure
@llvm
def _safe_add_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.sadd.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.sadd.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c


def _safe_add(a: int, b: int):
    c, o = _safe_add_llvm(a, b)
    return (c, bool(o))


@pure
@llvm
def _safe_sub_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.ssub.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.ssub.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c


def _safe_sub(a: int, b: int):
    c, o = _safe_sub_llvm(a, b)
    return (c, bool(o))


@pure
@llvm
def _safe_mul_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.smul.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.smul.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c


def _safe_mul(a: int, b: int):
    c, o = _safe_mul_llvm(a, b)
    return (c, bool(o))


# Checked 128-bit arithmetic; each wrapper returns (result, overflowed).
@pure
@llvm
def _safe_add128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.sadd.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.sadd.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c


def _safe_add128(a: i128, b: i128):
    c, o = _safe_add128_llvm(a, b)
    return (c, bool(o))


@pure
@llvm
def _safe_sub128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.ssub.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.ssub.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c


def _safe_sub128(a: i128, b: i128):
    c, o = _safe_sub128_llvm(a, b)
    return (c, bool(o))


def _diophantine_precompute(n: int, E: Ptr[DiophantineTerm],
                            Ep: Ptr[DiophantineTerm], Gamma: Ptr[int],
                            Epsilon: Ptr[int]):
    """Precompute the reduced problems used by ``_diophantine_dfs``.

    For ``n >= 2`` terms, fills (for ``j`` in ``0 .. n-2``):

    * ``Ep[j].a`` = gcd(E[0].a, ..., E[j+1].a),
    * ``Gamma[j]``, ``Epsilon[j]`` = Bezout coefficients of the pair
      combined at step ``j``,
    * ``Ep[j].ub`` (for ``j < n-2``) = an upper bound on
      ``sum(E[i].a * x[i] for i <= j+1) / Ep[j].a``.

    Returns True if an integer overflow occurred, otherwise False.
    """
    a_gcd, gamma, epsilon = _euclid(E[0].a, E[1].a)

    Ep[0] = Ep[0].with_a(a_gcd)
    Gamma[0] = gamma
    Epsilon[0] = epsilon

    if n > 2:
        c1 = E[0].a // a_gcd
        c2 = E[1].a // a_gcd
        # Ep[0].ub = E[0].ub * c1 + E[1].ub * c2
        x1, o1 = _safe_mul(E[0].ub, c1)
        x2, o2 = _safe_mul(E[1].ub, c2)
        x3, o3 = _safe_add(x1, x2)
        Ep[0] = Ep[0].with_ub(x3)
        if o1 or o2 or o3:
            return True

    for j in range(2, n):
        a_gcd, gamma, epsilon = _euclid(Ep[j - 2].a, E[j].a)
        Ep[j - 1] = Ep[j - 1].with_a(a_gcd)
        Gamma[j - 1] = gamma
        Epsilon[j - 1] = epsilon

        # The bound of the last combined variable is never read.
        if j < n - 1:
            c1 = Ep[j - 2].a // a_gcd
            c2 = E[j].a // a_gcd
            # Ep[j-1].ub = c1 * Ep[j-2].ub + c2 * E[j].ub
            x1, o1 = _safe_mul(c1, Ep[j - 2].ub)
            x2, o2 = _safe_mul(c2, E[j].ub)
            x3, o3 = _safe_add(x1, x2)
            Ep[j - 1] = Ep[j - 1].with_ub(x3)
            if o1 or o2 or o3:
                return True

    return False


def _floordiv(a: i128, b: int):
    """Return ``floor(a / b)`` for a positive divisor ``b``.

    The quotient is corrected afterwards, so the result is the mathematical
    floor whether ``Int[128]`` division rounds toward zero or toward negative
    infinity.
    """
    d = i128(b)
    q = a // d
    if q * d > a:
        q -= i128(1)
    return q


def _ceildiv(a: i128, b: int):
    """Return ``ceil(a / b)`` for a positive divisor ``b``.

    The quotient is corrected afterwards, so the result is the mathematical
    ceiling whether ``Int[128]`` division rounds toward zero or toward
    negative infinity.
    """
    d = i128(b)
    q = a // d
    if q * d < a:
        q += i128(1)
    return q


def _to_64(x: i128):
    """Narrow ``x`` to a 64-bit int; returns ``(value, overflowed)``."""
    if x > i128(9223372036854775807) or x < i128(-9223372036854775808):
        return 0, True
    else:
        return int(x), False


def _diophantine_dfs(n: int, v: int, E: Ptr[DiophantineTerm],
                     Ep: Ptr[DiophantineTerm], Gamma: Ptr[int],
                     Epsilon: Ptr[int], b: int, max_work: int,
                     require_ub_nontrivial: bool, x: Ptr[int],
                     count: Ptr[int]):
    """Depth-first search for ``sum(E[j].a * x[j] for j <= v) == b`` with
    ``0 <= x[j] <= E[j].ub``.

    Variables ``v, v-1, ..., 2`` are enumerated. At ``v == 1`` the remaining
    two-variable problem is solved directly from the Bezout coefficients
    stored by ``_diophantine_precompute``. ``count[0]`` accumulates the work
    done and is compared against ``max_work`` (a negative ``max_work`` means
    unlimited). On success the solution is left in ``x``.
    Returns one of the ``_MEM_OVERLAP_*`` codes.
    """
    if max_work >= 0 and count[0] >= max_work:
        return _MEM_OVERLAP_TOO_HARD

    # Reduced problem: a1*y + a2*x[v] == b, where y stands for the combined
    # contribution of variables 0..v-1 (just x[0] when v == 1).
    if v == 1:
        a1 = E[0].a
        u1 = E[0].ub
    else:
        a1 = Ep[v - 2].a
        u1 = Ep[v - 2].ub

    a2 = E[v].a
    u2 = E[v].ub

    a_gcd = Ep[v - 1].a
    gamma = Gamma[v - 1]
    epsilon = Epsilon[v - 1]

    # No integer solution unless gcd(a1, a2) divides b.
    c = b // a_gcd
    r = b % a_gcd
    if r != 0:
        count[0] += 1
        return _MEM_OVERLAP_NO

    c1 = a2 // a_gcd
    c2 = a1 // a_gcd

    # All integer solutions are y = x10 + c1*t, x[v] = x20 - c2*t (t integer).
    # Restrict t so that 0 <= y <= u1 and 0 <= x[v] <= u2, working in 128 bits
    # so the products cannot overflow.
    x10 = i128(gamma) * i128(c)
    x20 = i128(epsilon) * i128(c)

    t_l1 = _ceildiv(-x10, c1)                  # y >= 0
    diff_l2, o_l2 = _safe_sub128(x20, i128(u2))
    t_l2 = _ceildiv(diff_l2, c2)               # x[v] <= u2
    diff_u1, o_u1 = _safe_sub128(i128(u1), x10)
    t_u1 = _floordiv(diff_u1, c1)              # y <= u1
    t_u2 = _floordiv(x20, c2)                  # x[v] >= 0

    if o_l2 or o_u1:
        return _MEM_OVERLAP_OVERFLOW

    # Intersect the admissible ranges: keep the tighter lower bound and the
    # tighter upper bound.
    if t_l2 > t_l1:
        t_l1 = t_l2
    if t_u1 > t_u2:
        t_u1 = t_u2

    if t_l1 > t_u1:
        count[0] += 1
        return _MEM_OVERLAP_NO

    # Shift t so its range becomes [0, t_u] and return to 64-bit values.
    t_l, o1 = _to_64(t_l1)
    t_u, o2 = _to_64(t_u1)

    x10, o3 = _safe_add128(x10, i128(c1) * i128(t_l))
    x20, o4 = _safe_sub128(x20, i128(c2) * i128(t_l))

    t_u, o5 = _safe_sub(t_u, t_l)
    t_l = 0
    x1, o6 = _to_64(x10)
    x2, o7 = _to_64(x20)

    if o1 or o2 or o3 or o4 or o5 or o6 or o7:
        return _MEM_OVERLAP_OVERFLOW

    if v == 1:
        # Base case: the smallest admissible t yields a solution.
        if t_u >= t_l:
            x[0] = x1 + (c1 * t_l)
            x[1] = x2 - (c2 * t_l)
            if require_ub_nontrivial:
                # Reject the trivial solution x[j] == ub/2 for every j.
                is_trivial = True
                for j in range(n):
                    if x[j] != E[j].ub // 2:
                        is_trivial = False
                        break
                if is_trivial:
                    count[0] += 1
                    return _MEM_OVERLAP_NO
            return _MEM_OVERLAP_YES
        count[0] += 1
        return _MEM_OVERLAP_NO
    else:
        # Try every admissible value of x[v] and recurse on the rest.
        t = t_l
        while t <= t_u:
            x[v] = x2 - (c2 * t)

            # b2 = b - a2 * x[v]
            prod, o_mul = _safe_mul(a2, x[v])
            b2, o_sub = _safe_sub(b, prod)
            if o_mul or o_sub:
                return _MEM_OVERLAP_OVERFLOW

            res = _diophantine_dfs(n, v - 1, E, Ep, Gamma, Epsilon, b2,
                                   max_work, require_ub_nontrivial, x, count)
            if res != _MEM_OVERLAP_NO:
                return res
            t += 1
        count[0] += 1
        return _MEM_OVERLAP_NO


def _solve_diophantine(n: int, E: Ptr[DiophantineTerm], b: int, max_work: int,
                       require_ub_nontrivial: bool, x: Ptr[int]):
    """Decide whether ``sum(E[j].a * x[j]) == b`` has an integer solution
    with ``0 <= x[j] <= E[j].ub`` for the ``n`` terms at ``E``.

    Every coefficient must be positive, otherwise ``_MEM_OVERLAP_ERROR`` is
    returned. When ``require_ub_nontrivial`` is set, every bound must be
    even, ``b`` is replaced by ``sum(a * ub / 2)``, and the solution with
    ``x[j] == ub / 2`` for all ``j`` is not accepted. A solution that is
    found is written to ``x``. Returns one of the ``_MEM_OVERLAP_*`` codes.
    """
    for j in range(n):
        if E[j].a <= 0:
            return _MEM_OVERLAP_ERROR
        elif E[j].ub < 0:
            return _MEM_OVERLAP_NO

    if require_ub_nontrivial:
        ub_sum = 0
        o = False
        for j in range(n):
            if E[j].ub % 2 != 0:
                return _MEM_OVERLAP_ERROR
            v1, o1 = _safe_mul(E[j].a, E[j].ub // 2)
            v2, o2 = _safe_add(ub_sum, v1)
            ub_sum = v2
            o = o or o1 or o2
        if o:
            return _MEM_OVERLAP_ERROR
        b = ub_sum

    if b < 0:
        return _MEM_OVERLAP_NO

    if n == 0:
        # A 0-variable problem only has the trivial solution.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO
        if b == 0:
            return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO
    elif n == 1:
        # A 1-variable problem only has the trivial solution.
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO
        if b % E[0].a == 0:
            x[0] = b // E[0].a
            if x[0] >= 0 and x[0] <= E[0].ub:
                return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO
    else:
        count = 0
        Ep = Ptr[DiophantineTerm](n)
        Epsilon = Ptr[int](n)
        Gamma = Ptr[int](n)

        if _diophantine_precompute(n, E, Ep, Gamma, Epsilon):
            res = _MEM_OVERLAP_OVERFLOW
        else:
            res = _diophantine_dfs(n, n - 1, E, Ep, Gamma, Epsilon, b,
                                   max_work, require_ub_nontrivial, x,
                                   __ptr__(count))

        util.free(Ep)
        util.free(Gamma)
        util.free(Epsilon)
        return res


def _diophantine_simplify(n: Ptr[int], E: Ptr[DiophantineTerm], b: int):
    """Simplify ``sum(E[j].a * x[j]) == b`` in place.

    Sorts the terms, merges terms with equal coefficients (adding their
    bounds), caps each bound at ``b // a`` and drops terms whose bound
    becomes zero. ``n[0]`` is updated to the new number of terms. Assumes all
    coefficients are positive. Obviously infeasible inputs (a negative bound
    or negative ``b``) are left untouched.
    Returns -1 if an overflow occurred while merging bounds, otherwise 0.
    """
    for j in range(n[0]):
        if E[j].ub < 0:
            return 0

    if b < 0:
        return 0

    _sort(E, n[0])

    # Combine terms with identical coefficients (adjacent after sorting).
    o = False
    m = n[0]
    i = 0
    for j in range(1, m):
        if E[i].a == E[j].a:
            v1, o1 = _safe_add(E[i].ub, E[j].ub)
            o = o or o1
            E[i] = E[i].with_ub(v1)
            n[0] -= 1
        else:
            i += 1
            if i != j:
                E[i] = E[j]

    # Trim bounds and remove variables that can only be zero.
    m = n[0]
    i = 0
    for j in range(m):
        E[j] = E[j].with_ub(min(E[j].ub, b // E[j].a))
        if E[j].ub == 0:
            n[0] -= 1
        else:
            if i != j:
                E[i] = E[j]
            i += 1

    if o:
        return -1
    else:
        return 0


def _offset_bounds_from_strides(arr: ndarray):
    """Return the half-open byte-offset range ``(lower, upper)`` spanned by
    ``arr`` relative to its data pointer; ``(0, 0)`` for an empty array."""
    lower = 0
    upper = 0
    nd = arr.ndim
    shape = arr.shape
    strides = arr.strides

    for i in range(nd):
        if shape[i] == 0:
            return (0, 0)

        # Negative strides extend the range downwards, positive ones upwards.
        max_axis_offset = strides[i] * (shape[i] - 1)
        if max_axis_offset > 0:
            upper += max_axis_offset
        else:
            lower += max_axis_offset

    upper += arr.itemsize
    return (lower, upper)


def _get_array_memory_extents(arr: ndarray):
    """Return ``(start, end, num_bytes)`` as ``u64``: the half-open address
    range touched by ``arr`` and its total element bytes."""
    low, upper = _offset_bounds_from_strides(arr)
    out_start = arr.data.as_byte().__int__() + low
    out_end = arr.data.as_byte().__int__() + upper

    num_bytes = arr.itemsize
    for j in range(arr.ndim):
        num_bytes *= arr.shape[j]

    return u64(out_start), u64(out_end), u64(num_bytes)


def _strides_to_terms(arr: ndarray, terms: Ptr[DiophantineTerm],
                      nterms: Ptr[int], skip_empty: bool):
    """Append one term ``(|stride|, dim - 1)`` per axis of ``arr`` at
    ``terms[nterms[0]]`` and advance ``nterms[0]``.

    With ``skip_empty``, axes of length <= 1 or stride 0 are skipped.
    Returns True if taking the absolute value of a stride overflowed.
    """
    for i in range(arr.ndim):
        if skip_empty:
            if arr.shape[i] <= 1 or arr.strides[i] == 0:
                continue

        k = nterms[0]
        terms[k] = terms[k].with_a(arr.strides[i])

        if terms[k].a < 0:
            terms[k] = terms[k].with_a(-terms[k].a)

        # Still negative after negation: the stride's magnitude overflowed.
        if terms[k].a < 0:
            return True

        terms[k] = terms[k].with_ub(arr.shape[i] - 1)
        nterms[0] += 1

    return False


def _solve_may_share_memory(a: ndarray, b: ndarray, max_work: int):
    """Determine whether arrays ``a`` and ``b`` address a common byte.

    First compares the memory extents. With ``max_work == 0`` no further
    analysis is done (``_MEM_OVERLAP_TOO_HARD``). Otherwise the question is
    posed as a bounded Diophantine equation over the absolute strides of both
    arrays (plus byte offsets within an item), with right-hand side
    ``min(end2 - 1 - start1, end1 - 1 - start2)``.
    Returns one of the ``_MEM_OVERLAP_*`` codes.
    """
    if a.ndim > MAXDIMS or b.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Fixed-capacity scratch buffers for the terms and the solution vector.
    terms_tuple = (DiophantineTerm(0, 0),) * (2*MAXDIMS + 2)
    x_tuple = (0,) * (2*MAXDIMS + 2)
    terms = Ptr[DiophantineTerm](__ptr__(terms_tuple).as_byte())
    x = Ptr[int](__ptr__(x_tuple).as_byte())

    start1, end1, size1 = _get_array_memory_extents(a)
    start2, end2, size2 = _get_array_memory_extents(b)

    if not (start1 < end2 and start2 < end1 and start1 < end1 and start2 < end2):
        return _MEM_OVERLAP_NO

    if max_work == 0:
        return _MEM_OVERLAP_TOO_HARD

    # Both candidates are non-negative thanks to the extents check above.
    uintp_rhs = min(end2 - u64(1) - start1, end1 - u64(1) - start2)
    if uintp_rhs > u64(9223372036854775807):
        return _MEM_OVERLAP_OVERFLOW
    rhs = int(uintp_rhs)

    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), True):
        return _MEM_OVERLAP_OVERFLOW
    if _strides_to_terms(b, terms, __ptr__(nterms), True):
        return _MEM_OVERLAP_OVERFLOW
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1
    if b.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=b.itemsize-1)
        nterms += 1

    if _diophantine_simplify(__ptr__(nterms), terms, rhs):
        return _MEM_OVERLAP_OVERFLOW

    return _solve_diophantine(nterms, terms, rhs, max_work, False, x)


def _solve_may_have_internal_overlap(a: ndarray, max_work: int):
    """Determine whether two different elements of ``a`` address a common
    byte.

    Contiguous arrays return ``_MEM_OVERLAP_NO`` immediately. Otherwise each
    term's bound is doubled and the solver looks for a solution of
    ``sum(a * x) == sum(a * ub / 2)`` other than ``x == ub / 2``.
    Returns one of the ``_MEM_OVERLAP_*`` codes.
    """
    if a.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    terms_tuple = (DiophantineTerm(0, 0),) * (MAXDIMS + 1)
    x_tuple = (0,) * (MAXDIMS + 1)
    terms = Ptr[DiophantineTerm](__ptr__(terms_tuple).as_byte())
    x = Ptr[int](__ptr__(x_tuple).as_byte())

    cc, fc = a._contig
    if cc or fc:
        return _MEM_OVERLAP_NO

    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), False):
        return _MEM_OVERLAP_OVERFLOW
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(a=1, ub=a.itemsize-1)
        nterms += 1

    # Drop empty terms; a zero stride over a non-trivial axis overlaps itself.
    i = 0
    for j in range(nterms):
        if terms[j].ub == 0:
            continue
        elif terms[j].ub < 0:
            return _MEM_OVERLAP_NO
        elif terms[j].a == 0:
            return _MEM_OVERLAP_YES
        if i != j:
            terms[i] = terms[j]
        i += 1
    nterms = i

    # Double bounds to turn "two distinct solutions" into one nontrivial one.
    for j in range(nterms):
        terms[j] = terms[j].with_ub(terms[j].ub * 2)

    # `_diophantine_simplify` is not used here because the solver must count
    # nontrivial solutions on the unsimplified problem; only sort.
    _sort(terms, nterms)

    return _solve_diophantine(nterms, terms, -1, max_work, True, x)


def _array_shaes_memory_impl(a: ndarray, b: ndarray, max_work: int,
                             raise_exception: bool):
    """Shared implementation of ``shares_memory`` / ``may_share_memory``.

    Undecided results (overflow, work limit) raise when ``raise_exception``
    is set and are reported as True otherwise. (The misspelled name is kept
    for compatibility with existing callers.)
    """
    if max_work < -2:
        raise ValueError("Invalid value for max_work")

    result = _solve_may_share_memory(a, b, max_work)

    if result == _MEM_OVERLAP_NO:
        return False
    elif result == _MEM_OVERLAP_YES:
        return True
    elif result == _MEM_OVERLAP_OVERFLOW:
        if raise_exception:
            raise OverflowError("Integer overflow in computing overlap")
        else:
            return True
    elif result == _MEM_OVERLAP_TOO_HARD:
        if raise_exception:
            raise util.TooHardError("Exceeded max_work")
        else:
            return True
    else:
        raise RuntimeError("Error in computing overlap")


def shares_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """Return True if ``a`` and ``b`` share at least one byte of memory.

    ``max_work`` bounds the search effort (default: exact, no limit).
    Raises ``OverflowError`` or ``util.TooHardError`` when the answer cannot
    be determined, and ``ValueError`` for ``max_work < -2``.
    """
    mw = 0
    if max_work is None:
        mw = MAY_SHARE_EXACT
    else:
        mw = max_work
    return _array_shaes_memory_impl(a, b, mw, True)


def may_share_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """Return True if ``a`` and ``b`` might share memory.

    By default only the memory extents are compared. When the answer cannot
    be determined within ``max_work``, True is returned instead of raising.
    """
    mw = 0
    if max_work is None:
        mw = MAY_SHARE_BOUNDS
    else:
        mw = max_work
    return _array_shaes_memory_impl(a, b, mw, False)
def setbufsize(size: int):
    """Set the size of the ufunc buffer.

    Codon-NumPy does not use ufunc buffers, so ``size`` is ignored. As in
    NumPy, the previous buffer size is returned; it is always 0 here,
    matching ``getbufsize``.
    """
    return 0  # Codon-NumPy does not use ufunc buffers


def getbufsize():
    """Return the ufunc buffer size; always 0 because Codon-NumPy does not
    use ufunc buffers."""
    return 0  # Codon-NumPy does not use ufunc buffers


def byte_bounds(a: ndarray):
    """Return ``(low, high)``: the address of the lowest byte used by ``a``
    and one past the address of its highest byte.

    Contiguous arrays span ``a.size * itemsize`` bytes from the data pointer.
    Otherwise each axis extends the range downwards (negative stride) or
    upwards (positive stride) by ``(shape - 1) * stride``, plus one item.
    """
    a_data = a.data
    astrides = a.strides
    ashape = a.shape
    a_low = a_data.as_byte().__int__()
    a_high = a_data.as_byte().__int__()
    bytes_a = a.itemsize

    cc, fc = a._contig
    if cc or fc:
        a_high += a.size * bytes_a
    else:
        for i in staticrange(staticlen(ashape)):
            shape = ashape[i]
            stride = astrides[i]
            if stride < 0:
                a_low += (shape - 1) * stride
            else:
                a_high += (shape - 1) * stride
        a_high += bytes_a

    return (a_low, a_high)