# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from .routines import empty as empty_array, asarray, atleast_1d, empty, ascontiguousarray, _check_out, concatenate
from .sorting import argsort
import util

# Distance (in elements) from the previous guess within which
# `_binary_search_with_guess` tries to confine its bisection range before
# falling back to searching the whole remaining range. The comments in that
# function describe this as restricting the search to items likely in cache.
_LIKELY_IN_CACHE_SIZE: Static[int] = 8

def _linear_search(key, arr, lenxp):
    for i in range(lenxp):
        if key < arr[i]:
            return i - 1
    return lenxp - 1

def _binary_search_with_guess(key, arr, lenxp, guess):
    """
    Locate ``key`` within the first ``lenxp`` elements of ``arr``, checking
    the neighbourhood of ``guess`` before falling back to bisection.

    Assumes ``arr`` is sorted in ascending order; this is not checked here.

    Parameters
    ----------
    key : value to locate.
    arr : indexable sequence of sample points.
    lenxp : number of elements of ``arr`` to consider.
    guess : index expected to be close to the answer (callers in this file
        pass the result of the previous search). It is clamped to
        ``[1, lenxp - 3]`` before use, so out-of-range values are accepted.

    Returns
    -------
    ``-1`` if ``key < arr[0]``; ``lenxp`` if ``key > arr[lenxp - 1]``;
    otherwise, for ascending ``arr``, the largest index ``i`` such that
    ``arr[i] <= key`` (which is ``lenxp - 1`` when ``key`` equals the last
    element).
    """
    imin = 0
    imax = lenxp
    # Handle keys outside of the arr range first
    if key > arr[lenxp - 1]:
        return lenxp
    elif key < arr[0]:
        return -1

    # If len <= 4 use linear search
    if lenxp <= 4:
        return _linear_search(key, arr, lenxp)

    # Clamp the guess so that guess - 1, guess + 1 and guess + 2 are all
    # valid indices (lenxp >= 5 at this point).
    if guess > lenxp - 3:
        guess = lenxp - 3
    if guess < 1:
        guess = 1

    # Check most likely values: guess - 1, guess, guess + 1
    if key < arr[guess]:
        if key < arr[guess - 1]:
            # Answer lies strictly below guess - 1; bisect over [imin, guess - 1).
            imax = guess - 1
            # Last attempt to restrict search to items in cache
            if guess > _LIKELY_IN_CACHE_SIZE and key >= arr[
                    guess - _LIKELY_IN_CACHE_SIZE]:
                imin = guess - _LIKELY_IN_CACHE_SIZE
        else:
            # arr[guess - 1] <= key < arr[guess]
            return guess - 1
    else:
        # key >= arr[guess]
        if key < arr[guess + 1]:
            return guess
        elif key < arr[guess + 2]:
            return guess + 1
        else:
            # key >= arr[guess + 2]; bisect over [guess + 2, imax).
            imin = guess + 2
            # Last attempt to restrict search to items in cache
            if guess < lenxp - _LIKELY_IN_CACHE_SIZE - 1 and key < arr[
                    guess + _LIKELY_IN_CACHE_SIZE]:
                imax = guess + _LIKELY_IN_CACHE_SIZE

    # Finally, find index by bisection
    # The loop narrows [imin, imax) to the first index whose element is
    # greater than key; the element just before it is the answer.
    while imin < imax:
        imid = imin + (imax - imin) // 2
        if key >= arr[imid]:
            imin = imid + 1
        else:
            imax = imid

    return imin - 1

def _compiled_interp_complex(x: Ptr[X],
                             lenx: int,
                             X: type,
                             xp: Ptr[T],
                             lenxp: int,
                             T: type,
                             fp: Ptr[D],
                             lenafp: int,
                             D: type,
                             left=None,
                             right=None):
    """
    Piecewise-linear interpolation of complex samples ``(xp, fp)`` evaluated
    at the points ``x``; real and imaginary parts are interpolated separately.

    Parameters
    ----------
    x, lenx : pointer to the query points and their count.
    xp, lenxp : pointer to the sample x-coordinates and their count.
        Presumably sorted ascending (required by the search helper); this is
        not checked here.
    fp, lenafp : pointer to the complex sample values and their count.
    X, T, D : element types of ``x``, ``xp`` and ``fp``.
    left : value used for queries below ``xp[0]``; ``fp[0]`` when ``None``.
    right : value used for queries above ``xp[lenxp - 1]``;
        ``fp[lenafp - 1]`` when ``None``.

    Returns
    -------
    A newly allocated complex array of length ``lenx``.

    Raises
    ------
    ValueError
        If ``lenxp`` is 0 or ``lenafp`` differs from ``lenxp``.
    """
    dx = xp
    dz = x
    dy = fp

    # Validate before dereferencing dy[0] / dy[lenafp - 1]: these are raw
    # pointers, so an empty or mismatched input would read out of bounds.
    # Same checks and messages as _compiled_interp.
    if lenxp == 0:
        raise ValueError("Array of sample points is empty")

    if lenafp != lenxp:
        raise ValueError("fp and xp are not of the same length.")

    af = empty_array(lenx, dtype=complex)

    # Get left and right fill values
    lval = dy[0] if left is None else complex(left)
    rval = dy[lenafp - 1] if right is None else complex(right)

    # binary_search_with_guess needs at least a 3 item long array
    if lenxp == 1:
        xp_val, fp_val = dx[0], dy[0]
        for i in range(lenx):
            x_val = dz[i]
            af[i] = lval if x_val < xp_val else (
                rval if x_val > xp_val else fp_val)
    else:
        j = 0

        # only pre-calculate slopes if there are relatively few of them
        slopes = empty_array(lenxp -
                             1, dtype=complex) if lenxp <= lenx else None

        if slopes is not None:
            for i in range(lenxp - 1):
                inv_dx = 1.0 / (dx[i + 1] - dx[i])
                slopes[i] = ((dy[i + 1] - dy[i]) * inv_dx)

        for i in range(lenx):
            x_val = dz[i]

            # NaN queries propagate to the output unchanged.
            if util.isnan(x_val):
                af[i] = x_val
                continue

            # The previous segment index is reused as the next search guess.
            j = _binary_search_with_guess(x_val, dx, lenxp, j)

            if j == -1:
                af[i] = lval
            elif j == lenxp:
                af[i] = rval
            elif j == lenxp - 1:
                af[i] = dy[j]
            elif dx[j] == x_val:
                # Exact hit on a sample point: skip the slope arithmetic.
                af[i] = dy[j]
            else:
                slope = slopes[j] if slopes is not None else (
                    (dy[j + 1] - dy[j]) / (dx[j + 1] - dx[j]))

                af_real = slope.real * (x_val - dx[j]) + dy[j].real
                af_imag = slope.imag * (x_val - dx[j]) + dy[j].imag

                # If the formula anchored at sample j yields NaN, retry
                # anchored at sample j + 1; if that is NaN too and both
                # endpoints agree, use the shared endpoint value.
                if util.isnan(af_real):
                    af_real = slope.real * (x_val - dx[j + 1]) + dy[j + 1].real
                    if util.isnan(af_real) and dy[j].real == dy[j + 1].real:
                        af_real = dy[j].real

                if util.isnan(af_imag):
                    af_imag = slope.imag * (x_val - dx[j + 1]) + dy[j + 1].imag
                    if util.isnan(af_imag) and dy[j].imag == dy[j + 1].imag:
                        af_imag = dy[j].imag

                af[i] = complex(af_real, af_imag)

    return af

def _compiled_interp(x: Ptr[X],
                     lenx: int,
                     X: type,
                     xp: Ptr[T],
                     lenxp: int,
                     T: type,
                     fp: Ptr[D],
                     lenafp: int,
                     D: type,
                     left=None,
                     right=None):
    """
    Piecewise-linear interpolation of the samples ``(xp, fp)`` evaluated at
    the points ``x``.

    Parameters
    ----------
    x, lenx : pointer to the query points and their count.
    xp, lenxp : pointer to the sample x-coordinates and their count.
        Presumably sorted ascending (required by the search helper); this is
        not checked here.
    fp, lenafp : pointer to the sample values and their count.
    X, T, D : element types of ``x``, ``xp`` and ``fp``.
    left : value used for queries below ``xp[0]``; ``fp[0]`` when ``None``.
    right : value used for queries above ``xp[lenxp - 1]``;
        ``fp[lenafp - 1]`` when ``None``.

    Returns
    -------
    A newly allocated float array of length ``lenx``.

    Raises
    ------
    ValueError
        If ``lenxp`` is 0 or ``lenafp`` differs from ``lenxp``.
    """
    if lenxp == 0:
        raise ValueError("Array of sample points is empty")

    if lenafp != lenxp:
        raise ValueError("fp and xp are not of the same length.")

    out = empty_array(lenx, dtype=float)

    # Values reported for queries that fall outside the sampled range.
    fill_left = fp[0] if left is None else float(left)
    fill_right = fp[lenafp - 1] if right is None else float(right)

    # A single sample cannot go through the guess-based search, so it is
    # handled by direct comparison.
    if lenxp == 1:
        lone_x = xp[0]
        lone_y = fp[0]

        for k in range(lenx):
            q = x[k]
            if q < lone_x:
                out[k] = fill_left
            elif q > lone_x:
                out[k] = fill_right
            else:
                out[k] = lone_y

        return out

    # Slopes are tabulated up front only when there are no more samples
    # than query points; otherwise each slope is computed on demand.
    slopes = None

    if lenxp <= lenx:
        slopes = [0.0] * (lenxp - 1)
        for s in range(lenxp - 1):
            slopes[s] = (fp[s + 1] - fp[s]) / (xp[s + 1] - xp[s])

    # Segment index from the previous query, reused as the next search guess.
    seg = 0

    for k in range(lenx):
        q = x[k]

        # NaN queries propagate to the output unchanged.
        if util.isnan(q):
            out[k] = q
            continue

        seg = _binary_search_with_guess(q, xp, lenxp, seg)

        if seg == -1:
            out[k] = fill_left
        elif seg == lenxp:
            out[k] = fill_right
        elif seg == lenxp - 1 or xp[seg] == q:
            # Last sample, or an exact hit: no slope arithmetic needed.
            out[k] = fp[seg]
        else:
            slope = slopes[seg] if slopes is not None else (
                fp[seg + 1] - fp[seg]) / (xp[seg + 1] - xp[seg])
            out[k] = slope * (q - xp[seg]) + fp[seg]

            # If anchoring at the left sample gives NaN, anchor at the right
            # one; if that is NaN too and both endpoints agree, use the
            # shared endpoint value.
            if util.isnan(out[k]):
                out[k] = slope * (q - xp[seg + 1]) + fp[seg + 1]

                if util.isnan(out[k]) and fp[seg] == fp[seg + 1]:
                    out[k] = fp[seg]

    return out

def _interp_type(dtype: type):
    """
    Return a default-constructed value whose type is the element type to use
    for the sample values: ``complex`` when ``dtype`` is ``complex64``,
    otherwise ``dtype`` itself.

    `interp` applies ``type()`` to the result to obtain the dtype it converts
    ``fp`` to.
    """
    # NOTE(review): the two branches return different types; presumably the
    # explicit if/else lets the compiler pick one per instantiation -- confirm
    # before restructuring.
    if dtype is complex64:
        return complex()
    else:
        return dtype()

def interp(x, xp, fp, left=None, right=None, period=None):
    """
    One-dimensional piecewise-linear interpolation of the samples
    ``(xp, fp)`` evaluated at ``x``.

    Parameters
    ----------
    x : query points; converted to an array of at least one dimension.
    xp : sample x-coordinates. Sorted internally only when ``period`` is
        given; otherwise presumably expected to be ascending already.
    fp : sample values. ``complex64`` input is converted to ``complex``;
        other dtypes are kept.
    left, right : fill values for queries outside the sampled range,
        forwarded to the interpolation helpers. Both are reset to ``None``
        when ``period`` is given.
    period : if not ``None``, ``x`` and ``xp`` are reduced modulo
        ``abs(period)``, the samples are sorted by ``xp``, and one wrapped
        sample is added at each end.

    Returns
    -------
    The array produced by `_compiled_interp_complex` when the converted
    ``fp`` dtype is ``complex``, otherwise by `_compiled_interp`. Both
    helpers are called with ``x.size`` as the number of query points.

    Raises
    ------
    ValueError
        If ``period`` is 0, or (with ``period``) ``xp`` and ``fp`` differ in
        length. The helpers may raise ``ValueError`` as well.
    """
    x = ascontiguousarray(atleast_1d(asarray(x)))
    xp = ascontiguousarray(atleast_1d(asarray(xp)))

    # Choose the working dtype for the sample values, then convert to it.
    fp = asarray(fp)
    work_dtype = type(_interp_type(fp.dtype))
    fp = ascontiguousarray(atleast_1d(asarray(fp, work_dtype)))

    if period is not None:
        if period == 0:
            raise ValueError("period must be a non-zero value")
        period = abs(period)
        # Fill values are discarded in periodic mode.
        left = None
        right = None

        if xp.ndim != 1 or fp.ndim != 1:
            compile_error("Data points must be 1-D sequences")
        if xp.shape[0] != fp.shape[0]:
            raise ValueError("fp and xp are not of the same length")

        # Reduce both coordinate sets modulo the period.
        x = (x % period + period) % period
        xp = (xp % period + period) % period

        # Sort the samples by reduced coordinate.
        order = argsort(xp)
        xp = xp[order]
        fp = fp[order]

        # Pad each end with the sample from the opposite end, shifted by one
        # period.
        xp = concatenate((xp[-1:] - period, xp, xp[0:1] + period))
        fp = concatenate((fp[-1:], fp, fp[0:1]))

    if fp.dtype is complex:
        return _compiled_interp_complex(x.data, x.size, x.dtype,
                                        xp.data, xp.size, xp.dtype,
                                        fp.data, fp.size, fp.dtype,
                                        left=left, right=right)
    else:
        return _compiled_interp(x.data, x.size, x.dtype,
                                xp.data, xp.size, xp.dtype,
                                fp.data, fp.size, fp.dtype,
                                left=left, right=right)