# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from .ndarray import ndarray
import util

# Largest array rank accepted by the overlap solvers below; it also sizes their
# fixed-length term and solution buffers.
MAXDIMS             : Static[int] = 32
# Default max_work of may_share_memory(): with a work limit of 0 only the
# memory-extent (bounds) comparison is performed.
MAY_SHARE_BOUNDS    : Static[int] = 0
# Default max_work of shares_memory(): a negative limit disables the work cap
# of the exact search.
MAY_SHARE_EXACT     : Static[int] = -1
# Status codes returned by the Diophantine / overlap solvers.
_MEM_OVERLAP_NO      : Static[int] = 0   # no solution exists
_MEM_OVERLAP_YES     : Static[int] = 1   # solution found
_MEM_OVERLAP_TOO_HARD: Static[int] = -1  # max_work exceeded
_MEM_OVERLAP_OVERFLOW: Static[int] = -2  # algorithm failed due to integer overflow
_MEM_OVERLAP_ERROR   : Static[int] = -3  # invalid input

# 128-bit signed integer type used for intermediate values in the search.
i128 = Int[128]

def _euclid(a1: int, a2: int):
    """Extended Euclidean algorithm.

    Returns ``(g, gamma, epsilon)`` where ``g`` is the value left in the
    non-zero operand when the other one reaches zero and ``gamma``/``epsilon``
    are the cofactors tracked alongside it, i.e. for positive inputs
    ``gamma * a1 + epsilon * a2 == g``.
    """
    # (g1, e1) are the cofactors of the running a1; (g2, e2) those of a2.
    g1, e1 = 1, 0
    g2, e2 = 0, 1

    while True:
        # Reduce a1 modulo a2, or stop once a2 has been exhausted.
        if a2 <= 0:
            return a1, g1, e1
        q = a1 // a2
        a1 -= q * a2
        g1 -= q * g2
        e1 -= q * e2

        # Symmetric step: reduce a2 modulo a1, or stop once a1 is exhausted.
        if a1 <= 0:
            return a2, g2, e2
        q = a2 // a1
        a2 -= q * a1
        g2 -= q * g1
        e2 -= q * e1

# One term of a bounded linear Diophantine equation sum(a * x) == b:
# `a` is the coefficient and `ub` the inclusive upper bound of its variable
# (the solvers in this file look for solutions with 0 <= x <= ub).
@tuple
class DiophantineTerm:
    a: int   # coefficient
    ub: int  # inclusive upper bound of the associated variable

    def __new__(a: int, ub: int) -> DiophantineTerm:
        return (a, ub)

    # Terms are ordered by coefficient only; `ub` does not take part.
    def __lt__(self, other: DiophantineTerm):
        return self.a < other.a

    # Return a copy of this term with the coefficient replaced.
    def with_a(self, a: int):
        return DiophantineTerm(a, self.ub)

    # Return a copy of this term with the upper bound replaced.
    def with_ub(self, ub: int):
        return DiophantineTerm(self.a, ub)

def _sort(p: Ptr[DiophantineTerm], n: int):
    """Sort the ``n`` terms stored at ``p`` in place.

    The buffer is wrapped in a 1-D ndarray over the same pointer and sorted
    through ``ndarray.sort``.
    """
    ndarray((n,), p).sort()

# Signed 64-bit addition through the llvm.sadd.with.overflow.i64 intrinsic.
# Returns the intrinsic's result pair: the sum and a 1-bit overflow flag.
@pure
@llvm
def _safe_add_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.sadd.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.sadd.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c

def _safe_add(a: int, b: int):
    """Return ``(a + b, overflowed)`` with the overflow flag as a ``bool``."""
    total, flag = _safe_add_llvm(a, b)
    overflowed = bool(flag)
    return total, overflowed

# Signed 64-bit subtraction through the llvm.ssub.with.overflow.i64 intrinsic.
# Returns the intrinsic's result pair: the difference and a 1-bit overflow flag.
@pure
@llvm
def _safe_sub_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.ssub.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.ssub.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c

def _safe_sub(a: int, b: int):
    """Return ``(a - b, overflowed)`` with the overflow flag as a ``bool``."""
    diff, flag = _safe_sub_llvm(a, b)
    overflowed = bool(flag)
    return diff, overflowed

# Signed 64-bit multiplication through the llvm.smul.with.overflow.i64
# intrinsic. Returns the intrinsic's result pair: the product and a 1-bit
# overflow flag.
@pure
@llvm
def _safe_mul_llvm(a: int, b: int) -> Tuple[int, UInt[1]]:
    declare {i64, i1} @llvm.smul.with.overflow.i64(i64, i64)
    %c = call {i64, i1} @llvm.smul.with.overflow.i64(i64 %a, i64 %b)
    ret {i64, i1} %c

def _safe_mul(a: int, b: int):
    """Return ``(a * b, overflowed)`` with the overflow flag as a ``bool``."""
    product, flag = _safe_mul_llvm(a, b)
    overflowed = bool(flag)
    return product, overflowed

# Signed 128-bit addition through the llvm.sadd.with.overflow.i128 intrinsic.
# Returns the intrinsic's result pair: the sum and a 1-bit overflow flag.
@pure
@llvm
def _safe_add128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.sadd.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.sadd.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c

def _safe_add128(a: i128, b: i128):
    """Return ``(a + b, overflowed)`` for 128-bit operands; flag is a ``bool``."""
    total, flag = _safe_add128_llvm(a, b)
    overflowed = bool(flag)
    return total, overflowed

# Signed 128-bit subtraction through the llvm.ssub.with.overflow.i128
# intrinsic. Returns the intrinsic's result pair: the difference and a 1-bit
# overflow flag.
@pure
@llvm
def _safe_sub128_llvm(a: i128, b: i128) -> Tuple[i128, UInt[1]]:
    declare {i128, i1} @llvm.ssub.with.overflow.i128(i128, i128)
    %c = call {i128, i1} @llvm.ssub.with.overflow.i128(i128 %a, i128 %b)
    ret {i128, i1} %c

def _safe_sub128(a: i128, b: i128):
    """Return ``(a - b, overflowed)`` for 128-bit operands; flag is a ``bool``."""
    diff, flag = _safe_sub128_llvm(a, b)
    overflowed = bool(flag)
    return diff, overflowed

def _diophantine_precompute(n: int,
                            E: Ptr[DiophantineTerm],
                            Ep: Ptr[DiophantineTerm],
                            Gamma: Ptr[int],
                            Epsilon: Ptr[int]):
    """Fill the auxiliary tables used by ``_diophantine_dfs``.

    ``Ep[k].a`` receives the result of ``_euclid`` applied to the first
    ``k + 2`` coefficients of ``E`` (chained pairwise), ``Gamma[k]`` and
    ``Epsilon[k]`` the matching cofactors, and ``Ep[k].ub`` (for every reduced
    term except the last) the combined upper bound of the merged variable.

    Reads ``E[0]`` and ``E[1]`` unconditionally, so callers must pass
    ``n >= 2``. Returns ``True`` if computing a combined bound overflowed,
    ``False`` otherwise.
    """
    g, gam, eps = _euclid(E[0].a, E[1].a)
    Gamma[0] = gam
    Epsilon[0] = eps
    Ep[0] = Ep[0].with_a(g)

    if n > 2:
        # Bound of the variable that merges terms 0 and 1.
        lhs, lhs_ovf = _safe_mul(E[0].ub, E[0].a // g)
        rhs, rhs_ovf = _safe_mul(E[1].ub, E[1].a // g)
        bound, sum_ovf = _safe_add(lhs, rhs)
        Ep[0] = Ep[0].with_ub(bound)
        if lhs_ovf or rhs_ovf or sum_ovf:
            return True

    for k in range(2, n):
        reduced = Ep[k - 2]
        term = E[k]

        g, gam, eps = _euclid(reduced.a, term.a)
        Gamma[k - 1] = gam
        Epsilon[k - 1] = eps
        Ep[k - 1] = Ep[k - 1].with_a(g)

        # The last reduced term needs no bound.
        if k < n - 1:
            lhs, lhs_ovf = _safe_mul(reduced.a // g, reduced.ub)
            rhs, rhs_ovf = _safe_mul(term.a // g, term.ub)
            bound, sum_ovf = _safe_add(lhs, rhs)
            Ep[k - 1] = Ep[k - 1].with_ub(bound)
            if lhs_ovf or rhs_ovf or sum_ovf:
                return True

    return False

def _floordiv(a: i128, b: int):
    """Divide 128-bit ``a`` by ``b``, stepping the quotient down by one when
    ``a`` is negative and the remainder is non-zero."""
    divisor = i128(b)
    quotient = a // divisor
    if a < i128(0) and a % divisor != i128(0):
        quotient -= i128(1)
    return quotient

def _ceildiv(a: i128, b: int):
    """Divide 128-bit ``a`` by ``b``, stepping the quotient up by one when
    ``a`` is positive and the remainder is non-zero."""
    divisor = i128(b)
    quotient = a // divisor
    if a > i128(0) and a % divisor != i128(0):
        quotient += i128(1)
    return quotient

def _to_64(x: i128):
    """Narrow a 128-bit value to ``int``.

    Returns ``(value, False)`` when ``x`` lies in the signed 64-bit range and
    ``(0, True)`` when it does not.
    """
    fits = x <= i128(9223372036854775807) and x >= i128(-9223372036854775808)
    if fits:
        return int(x), False
    return 0, True

def _diophantine_dfs(n: int,
                     v: int,
                     E: Ptr[DiophantineTerm],
                     Ep: Ptr[DiophantineTerm],
                     Gamma: Ptr[int],
                     Epsilon: Ptr[int],
                     b: int,
                     max_work: int,
                     require_ub_nontrivial: bool,
                     x: Ptr[int],
                     count: Ptr[int]):
    """Depth-first search for a bounded solution of ``sum(E[j].a * x[j]) == b``.

    At level ``v`` the equation is reduced to two variables: ``x[v]`` and a
    merged variable standing for terms ``0..v-1`` (described by ``Ep[v-2]``,
    or ``E[0]`` itself when ``v == 1``). Every admissible value of ``x[v]`` is
    tried and the remainder is solved recursively at level ``v - 1``.

    Parameters:
      n        -- number of terms in ``E``
      v        -- index of the term handled at this level (``1 <= v < n``)
      E        -- the terms of the equation
      Ep, Gamma, Epsilon -- tables filled by ``_diophantine_precompute``
      b        -- right-hand side of the equation at this level
      max_work -- work limit compared against ``count[0]``; negative disables it
      require_ub_nontrivial -- reject the solution ``x[j] == E[j].ub // 2``
      x        -- output buffer receiving the solution
      count    -- work counter, incremented whenever a branch is exhausted

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if max_work >= 0 and count[0] >= max_work:
        return _MEM_OVERLAP_TOO_HARD

    # Fetch the two-variable problem for this level.
    if v == 1:
        a1 = E[0].a
        u1 = E[0].ub
    else:
        a1 = Ep[v - 2].a
        u1 = Ep[v - 2].ub

    a2 = E[v].a
    u2 = E[v].ub

    a_gcd = Ep[v - 1].a
    gamma = Gamma[v - 1]
    epsilon = Epsilon[v - 1]

    # No integer solution unless the gcd divides the right-hand side.
    c = b // a_gcd
    r = b % a_gcd
    if r != 0:
        count[0] += 1
        return _MEM_OVERLAP_NO

    c1 = a2 // a_gcd
    c2 = a1 // a_gcd

    # The solutions of a1*x1 + a2*x2 == b are
    #     x1 = gamma*c   + c1*t
    #     x2 = epsilon*c - c2*t        (t integer)
    # and t is restricted by 0 <= x1 <= u1 and 0 <= x2 <= u2.
    x10 = i128(gamma) * i128(c)
    x20 = i128(epsilon) * i128(c)

    t_l1 = _ceildiv(-x10, c1)                 # from x1 >= 0
    v1, o1 = _safe_sub128(x20, i128(u2))
    t_l2 = _ceildiv(v1, c2)                   # from x2 <= u2

    v2, o2 = _safe_sub128(i128(u1), x10)
    t_u1 = _floordiv(v2, c1)                  # from x1 <= u1
    t_u2 = _floordiv(x20, c2)                 # from x2 >= 0

    if o1 or o2:
        return _MEM_OVERLAP_OVERFLOW

    # Intersect the bounds: largest lower bound, smallest upper bound.
    # (Fixed: the lower bound used to be stored in a misspelled, unused local
    # `tl1`, so the x2 <= u2 constraint was never applied.)
    if t_l2 > t_l1:
        t_l1 = t_l2

    if t_u1 > t_u2:
        t_u1 = t_u2

    if t_l1 > t_u1:
        count[0] += 1
        return _MEM_OVERLAP_NO

    t_l, o1 = _to_64(t_l1)
    t_u, o2 = _to_64(t_u1)

    # Shift the parametrisation so that t runs over 0..t_u.
    x10, o3 = _safe_add128(x10, i128(c1) * i128(t_l))
    x20, o4 = _safe_sub128(x20, i128(c2) * i128(t_l))

    t_u, o5 = _safe_sub(t_u, t_l)
    t_l = 0
    x1, o6 = _to_64(x10)
    x2, o7 = _to_64(x20)

    if o1 or o2 or o3 or o4 or o5 or o6 or o7:
        return _MEM_OVERLAP_OVERFLOW

    if v == 1:
        # Base case: any admissible t yields a full solution.
        if t_u >= t_l:
            x[0] = x1 + (c1 * t_l)
            x[1] = x2 - (c2 * t_l)
            if require_ub_nontrivial:
                is_ub_trivial = True
                for j in range(n):
                    if x[j] != E[j].ub // 2:
                        is_ub_trivial = False
                        break

                if is_ub_trivial:
                    # Ignore the 'trivial' solution x[j] == ub[j] // 2.
                    count[0] += 1
                    return _MEM_OVERLAP_NO

            return _MEM_OVERLAP_YES

        count[0] += 1
        return _MEM_OVERLAP_NO
    else:
        # Recurse into every candidate value of x[v].
        t = t_l
        while t <= t_u:
            x[v] = x2 - (c2 * t)
            v1, o1 = _safe_mul(a2, x[v])
            b2, o2 = _safe_sub(b, v1)

            if o1 or o2:
                return _MEM_OVERLAP_OVERFLOW

            res = _diophantine_dfs(n, v - 1, E, Ep, Gamma, Epsilon,
                                   b2, max_work, require_ub_nontrivial,
                                   x, count)
            if res != _MEM_OVERLAP_NO:
                return res
            t += 1

        count[0] += 1
        return _MEM_OVERLAP_NO

def _solve_diophantine(n: int,
                       E: Ptr[DiophantineTerm],
                       b: int,
                       max_work: int,
                       require_ub_nontrivial: bool,
                       x: Ptr[int]):
    """Decide whether ``sum(E[j].a * x[j]) == b`` has a solution with
    ``0 <= x[j] <= E[j].ub``.

    When ``require_ub_nontrivial`` is set, the passed ``b`` is ignored and
    replaced by ``sum(E[j].a * E[j].ub // 2)``; every ``ub`` must then be even
    and the solution ``x[j] == E[j].ub // 2`` is not accepted.

    Returns a ``_MEM_OVERLAP_*`` status code; on ``_MEM_OVERLAP_YES`` a
    solution has been written to ``x`` (for ``n >= 1``).
    """
    # Validate the terms.
    for k in range(n):
        term = E[k]
        if term.a <= 0:
            return _MEM_OVERLAP_ERROR
        if term.ub < 0:
            return _MEM_OVERLAP_NO

    target = b
    if require_ub_nontrivial:
        half_sum = 0
        overflowed = False
        for k in range(n):
            term = E[k]
            if term.ub % 2 != 0:
                return _MEM_OVERLAP_ERROR

            prod, mul_ovf = _safe_mul(term.a, term.ub // 2)
            half_sum, add_ovf = _safe_add(half_sum, prod)
            overflowed = overflowed or mul_ovf or add_ovf

        if overflowed:
            return _MEM_OVERLAP_ERROR
        target = half_sum

    if target < 0:
        return _MEM_OVERLAP_NO

    if n == 0:
        # Empty sum: only the right-hand side 0 is reachable.
        if not require_ub_nontrivial and target == 0:
            return _MEM_OVERLAP_YES
        return _MEM_OVERLAP_NO

    if n == 1:
        if require_ub_nontrivial:
            return _MEM_OVERLAP_NO

        coeff = E[0].a
        if target % coeff == 0:
            x[0] = target // coeff
            if x[0] >= 0 and x[0] <= E[0].ub:
                return _MEM_OVERLAP_YES

        return _MEM_OVERLAP_NO

    # General case: precompute the reduction tables, then search.
    count = 0
    Ep = Ptr[DiophantineTerm](n)
    Epsilon = Ptr[int](n)
    Gamma = Ptr[int](n)

    if _diophantine_precompute(n, E, Ep, Gamma, Epsilon):
        res = _MEM_OVERLAP_OVERFLOW
    else:
        res = _diophantine_dfs(n, n - 1, E, Ep, Gamma, Epsilon, target,
                               max_work, require_ub_nontrivial, x,
                               __ptr__(count))

    util.free(Ep)
    util.free(Gamma)
    util.free(Epsilon)
    return res

def _diophantine_simplify(n: Ptr[int], E: Ptr[DiophantineTerm], b: int):
    """Simplify the problem ``sum(E[j].a * x[j]) == b`` in place.

    Sorts the terms, merges terms with equal coefficients (adding their upper
    bounds), clamps every bound to ``b // a`` and drops terms whose bound
    becomes 0. ``n[0]`` is updated to the new term count.

    Returns ``-1`` if merging bounds overflowed and ``0`` otherwise; problems
    with a negative bound or negative ``b`` are left untouched (returning 0).
    """
    # Leave obviously infeasible problems alone.
    for k in range(n[0]):
        if E[k].ub < 0:
            return 0

    if b < 0:
        return 0

    _sort(E, n[0])

    # Merge runs of identical coefficients; `dst` is the last kept slot.
    overflowed = False
    total = n[0]
    dst = 0
    for src in range(1, total):
        if E[dst].a != E[src].a:
            dst += 1
            if dst != src:
                E[dst] = E[src]
        else:
            merged, flag = _safe_add(E[dst].ub, E[src].ub)
            overflowed = overflowed or flag
            E[dst] = E[dst].with_ub(merged)
            n[0] -= 1

    # Clamp the bounds and compact away variables forced to zero.
    total = n[0]
    dst = 0
    for src in range(total):
        E[src] = E[src].with_ub(min(E[src].ub, b // E[src].a))
        if E[src].ub != 0:
            if dst != src:
                E[dst] = E[src]
            dst += 1
        else:
            n[0] -= 1

    if overflowed:
        return -1
    return 0

def _offset_bounds_from_strides(arr: ndarray):
    """Return the half-open byte-offset range ``(lower, upper)`` spanned by
    ``arr`` relative to its data pointer.

    Axes with a positive extent ``stride * (dim - 1)`` widen the range upwards,
    all others widen it downwards; ``upper`` additionally includes one item.
    An array with a zero-length axis yields ``(0, 0)``.
    """
    dims = arr.shape
    steps = arr.strides
    lo = 0
    hi = 0

    for axis in range(arr.ndim):
        extent = dims[axis]
        if extent == 0:
            return (0, 0)

        span = steps[axis] * (extent - 1)
        if span <= 0:
            lo += span
        else:
            hi += span

    return (lo, hi + arr.itemsize)

def _get_array_memory_extents(arr: ndarray):
    """Return ``(start, end, num_bytes)`` for ``arr`` as ``u64`` values.

    ``start``/``end`` are the data address plus the offsets reported by
    ``_offset_bounds_from_strides``; ``num_bytes`` is the item size multiplied
    by every dimension.
    """
    base = arr.data.as_byte().__int__()
    lo_off, hi_off = _offset_bounds_from_strides(arr)

    total = arr.itemsize
    for axis in range(arr.ndim):
        total *= arr.shape[axis]

    return u64(base + lo_off), u64(base + hi_off), u64(total)

def _strides_to_terms(arr: ndarray,
                      terms: Ptr[DiophantineTerm],
                      nterms: Ptr[int],
                      skip_empty: bool):
    """Append one term per axis of ``arr`` to ``terms``.

    Each term has ``a = abs(stride)`` and ``ub = dim - 1``; ``nterms[0]`` is
    advanced for every term written. With ``skip_empty`` set, axes of length
    <= 1 or with stride 0 are left out.

    Returns ``True`` if a stride could not be negated (the absolute value is
    still negative), ``False`` otherwise.
    """
    for axis in range(arr.ndim):
        if skip_empty and (arr.shape[axis] <= 1 or arr.strides[axis] == 0):
            continue

        slot = nterms[0]
        coeff = arr.strides[axis]
        if coeff < 0:
            coeff = -coeff
        terms[slot] = terms[slot].with_a(coeff)

        # Still negative after negation: the stride overflowed.
        if coeff < 0:
            return True

        terms[slot] = terms[slot].with_ub(arr.shape[axis] - 1)
        nterms[0] = slot + 1

    return False

def _solve_may_share_memory(a: ndarray, b: ndarray, max_work: int):
    """Determine whether arrays ``a`` and ``b`` overlap in memory.

    First compares the memory extents; if those intersect and ``max_work`` is
    non-zero, the strides and item sizes of both arrays are turned into a
    bounded Diophantine problem that is simplified and solved.

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if a.ndim > MAXDIMS or b.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Fixed-size tuples provide the backing storage for the raw buffers.
    term_storage = (DiophantineTerm(0, 0),) * (2*MAXDIMS + 2)
    x_storage = (0,) * (2*MAXDIMS + 2)
    terms = Ptr[DiophantineTerm](__ptr__(term_storage).as_byte())
    x = Ptr[int](__ptr__(x_storage).as_byte())

    a_lo, a_hi, a_nbytes = _get_array_memory_extents(a)
    b_lo, b_hi, b_nbytes = _get_array_memory_extents(b)

    # Disjoint or empty extents cannot overlap.
    if a_lo >= b_hi or b_lo >= a_hi or a_lo >= a_hi or b_lo >= b_hi:
        return _MEM_OVERLAP_NO

    if max_work == 0:
        return _MEM_OVERLAP_TOO_HARD

    one = u64(1)
    gap = min(b_hi - one - a_lo, a_hi - one - b_lo)
    if gap > u64(9223372036854775807):
        return _MEM_OVERLAP_OVERFLOW
    rhs = int(gap)

    nterms = 0
    if (_strides_to_terms(a, terms, __ptr__(nterms), True) or
            _strides_to_terms(b, terms, __ptr__(nterms), True)):
        return _MEM_OVERLAP_OVERFLOW

    # One extra unit-coefficient term per array covers the bytes of an item.
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(1, a.itemsize - 1)
        nterms += 1
    if b.itemsize > 1:
        terms[nterms] = DiophantineTerm(1, b.itemsize - 1)
        nterms += 1

    if _diophantine_simplify(__ptr__(nterms), terms, rhs):
        return _MEM_OVERLAP_OVERFLOW

    return _solve_diophantine(nterms, terms, rhs, max_work, False, x)

def _solve_may_have_internal_overlap(a: ndarray, max_work: int):
    """Determine whether two different index tuples of ``a`` can address the
    same memory.

    Contiguous arrays are answered immediately. Otherwise the strides and item
    size become Diophantine terms with doubled bounds, which are sorted and
    handed to ``_solve_diophantine`` with ``require_ub_nontrivial`` set.

    Returns one of the ``_MEM_OVERLAP_*`` status codes.
    """
    if a.ndim > MAXDIMS:
        compile_error("maximum array dimension exceeded")

    # Fixed-size tuples provide the backing storage for the raw buffers.
    term_storage = (DiophantineTerm(0, 0),) * (MAXDIMS + 1)
    x_storage = (0,) * (MAXDIMS + 1)
    terms = Ptr[DiophantineTerm](__ptr__(term_storage).as_byte())
    x = Ptr[int](__ptr__(x_storage).as_byte())

    c_contig, f_contig = a._contig
    if c_contig or f_contig:
        return _MEM_OVERLAP_NO

    nterms = 0
    if _strides_to_terms(a, terms, __ptr__(nterms), False):
        return _MEM_OVERLAP_OVERFLOW
    if a.itemsize > 1:
        terms[nterms] = DiophantineTerm(1, a.itemsize - 1)
        nterms += 1

    # Compact the terms: drop length-1 axes, bail out on the decided cases.
    kept = 0
    for src in range(nterms):
        term = terms[src]
        if term.ub == 0:
            continue
        if term.ub < 0:
            return _MEM_OVERLAP_NO
        if term.a == 0:
            return _MEM_OVERLAP_YES

        if kept != src:
            terms[kept] = term
        kept += 1
    nterms = kept

    # Double every bound to obtain the internal-overlap problem.
    for k in range(nterms):
        terms[k] = terms[k].with_ub(terms[k].ub * 2)

    _sort(terms, nterms)

    return _solve_diophantine(nterms, terms, -1, max_work, True, x)

def _array_shaes_memory_impl(a: ndarray, b: ndarray, max_work: int, raise_exception: bool):
    """Shared implementation of ``shares_memory`` and ``may_share_memory``.

    Runs ``_solve_may_share_memory`` and maps its status code to a ``bool``.
    When the solver overflows or exceeds ``max_work``, the answer is ``True``
    unless ``raise_exception`` is set, in which case ``OverflowError`` or
    ``util.TooHardError`` is raised.

    Raises ``ValueError`` for ``max_work < -2`` and ``RuntimeError`` for any
    other solver status.
    """
    if max_work < -2:
        raise ValueError("Invalid value for max_work")

    status = _solve_may_share_memory(a, b, max_work)

    if status == _MEM_OVERLAP_NO:
        return False
    if status == _MEM_OVERLAP_YES:
        return True

    if status == _MEM_OVERLAP_OVERFLOW:
        if raise_exception:
            raise OverflowError("Integer overflow in computing overlap")
        # Undecided: answer conservatively.
        return True

    if status == _MEM_OVERLAP_TOO_HARD:
        if raise_exception:
            raise util.TooHardError("Exceeded max_work")
        # Undecided: answer conservatively.
        return True

    raise RuntimeError("Error in computing overlap")

def shares_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """Return whether ``a`` and ``b`` share memory.

    ``max_work`` limits the search effort and defaults to ``MAY_SHARE_EXACT``.
    Undecidable cases (overflow, work limit exceeded) raise instead of
    returning.
    """
    limit: int = MAY_SHARE_EXACT
    if max_work is not None:
        limit = max_work
    return _array_shaes_memory_impl(a, b, limit, True)

def may_share_memory(a: ndarray, b: ndarray, max_work: Optional[int] = None):
    """Return whether ``a`` and ``b`` might share memory.

    ``max_work`` limits the search effort and defaults to ``MAY_SHARE_BOUNDS``.
    Undecidable cases (overflow, work limit exceeded) are reported as ``True``.
    """
    limit: int = MAY_SHARE_BOUNDS
    if max_work is not None:
        limit = max_work
    return _array_shaes_memory_impl(a, b, limit, False)

def setbufsize(size: int):
    """Accept a ufunc buffer size and do nothing with it.

    Codon-NumPy does not use ufunc buffers, so ``size`` is ignored.
    """
    return None

def getbufsize():
    """Return the ufunc buffer size, which is always 0.

    Codon-NumPy does not use ufunc buffers.
    """
    unused_buffer_size = 0
    return unused_buffer_size

def byte_bounds(a: ndarray):
    """Return ``(low, high)`` addresses bounding the memory of ``a``.

    ``low`` is the address of the lowest byte reached by any element and
    ``high`` is one item past the highest element's address. Contiguous arrays
    use ``size * itemsize`` directly; otherwise each axis extends the range by
    ``(dim - 1) * stride`` in the direction of its stride's sign.
    """
    base = a.data.as_byte().__int__()
    item_bytes = a.itemsize
    c_contig, f_contig = a._contig

    if c_contig or f_contig:
        return (base, base + a.size * item_bytes)

    dims = a.shape
    steps = a.strides
    low = base
    high = base + item_bytes
    for axis in staticrange(staticlen(dims)):
        reach = (dims[axis] - 1) * steps[axis]
        if steps[axis] < 0:
            low += reach
        else:
            high += reach
    return (low, high)