# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from pocketfft import _swap_direction, fft as raw_fft, ifft as raw_ifft, rfft as raw_rfft, irfft as raw_irfft
from ..routines import _check_out, asarray, empty, arange, roll
from ..ndarray import ndarray
from ..util import zero, cast, free, multirange, normalize_axis_index, tuple_range, tuple_set, tuple_insert, tuple_delete, sizeof

def _complex_dtype(dtype: type):
    if (dtype is complex64 or
        dtype is float32 or
        dtype is float16):
        return complex64()
    else:
        return complex()

def _fft(a, inv: bool, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """Run a 1-D complex transform along one axis of `a`.

    Parameters:
        a: input; converted with `asarray`.
        inv: when true `raw_ifft` is applied, otherwise `raw_fft`.
        n: number of points transformed along `axis`. Defaults to the
           length of that axis. Shorter input is zero-padded, longer
           input is truncated.
        axis: axis to transform; normalized against `a.ndim`.
        norm: forwarded unchanged to the raw transform.
        out: optional destination. Checked against the output shape with
             `_check_out`; its dtype must be the complex working type.

    Returns:
        `out` if it was given, otherwise a newly allocated array whose
        shape equals `a.shape` with `axis` replaced by the point count.

    Raises:
        ValueError: if the number of points is smaller than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m
    if n is not None:
        N = n
    # Reject empty / negative lengths before anything is allocated.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")

    out_shape = tuple_set(a.shape, axis, N)

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # The transform runs in place on a run of N contiguous T values. It can
    # use the output row directly only when the input already has type T,
    # the output is contiguous along `axis`, and no padding/truncation is
    # required; otherwise a scratch buffer is used.
    need_buf = (a.dtype is not T or
                ostride != sizeof(T) or
                N != m)
    buf = Ptr[T](N) if need_buf else Ptr[T]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        q = ans._ptr(idx1)
        pb = p.as_byte()
        data = buf if need_buf else Ptr[T](q.as_byte())

        # Gather (and convert) the input lane into the work area.
        for i in range(min(m, N)):
            data[i] = cast(Ptr[a.dtype](pb)[0], T)
            pb += istride

        # Zero-pad when more points were requested than are available.
        for i in range(m, N):
            data[i] = T(0.0, 0.0)

        if inv:
            raw_ifft(data, N, norm)
        else:
            raw_fft(data, N, norm)

        # Scatter the scratch buffer into the (possibly strided) output lane.
        if need_buf:
            qb = q.as_byte()
            for i in range(N):
                Ptr[ans.dtype](qb)[0] = data[i]
                qb += ostride

    if need_buf:
        free(buf)

    return ans

def fft(a, n: Optional[int] = None,
        axis: int = -1, norm: Optional[str] = None,
        out = None):
    """Forward 1-D transform along `axis`; delegates to `_fft` with `inv=False`."""
    return _fft(a, False, n, axis, norm, out)

def ifft(a, n: Optional[int] = None,
        axis: int = -1, norm: Optional[str] = None,
        out = None):
    """Inverse 1-D transform along `axis`; delegates to `_fft` with `inv=True`."""
    return _fft(a, True, n, axis, norm, out)

def rfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """Run a 1-D real-input transform along one axis of `a`.

    Parameters:
        a: input; converted with `asarray`. Each element is cast to the
           real type that matches the complex working type.
        n: number of input points used along `axis`. Defaults to the
           length of that axis. Shorter input is zero-padded, longer
           input is truncated.
        axis: axis to transform; normalized against `a.ndim`.
        norm: forwarded unchanged to `raw_rfft`.
        out: optional destination. Checked against the output shape with
             `_check_out`; its dtype must be the complex working type.

    Returns:
        `out` if it was given, otherwise a newly allocated complex array
        whose length along `axis` is `(n >> 1) + 1`.

    Raises:
        ValueError: if the number of input points is smaller than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m
    if n is not None:
        N = n
    # Reject empty / negative lengths before anything is allocated.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")
    M = (N >> 1) + 1

    out_shape = tuple_set(a.shape, axis, M)

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # The raw transform writes M contiguous values; a scratch output buffer
    # is needed only when the destination is strided along `axis`.
    need_obuf = (ostride != sizeof(T))
    # The input lane is always staged through a contiguous scratch buffer
    # and is never handed to raw_rfft directly.
    # NOTE(review): presumably so the caller's data cannot be modified by
    # the raw transform - confirm against pocketfft.
    ibuf = Ptr[T0](N)
    obuf = Ptr[T](M) if need_obuf else Ptr[T]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)

        # Gather (and convert) the input lane, then zero-pad up to N.
        pb = a._ptr(idx1).as_byte()
        for i in range(min(m, N)):
            ibuf[i] = cast(Ptr[a.dtype](pb)[0], T0)
            pb += istride
        for i in range(m, N):
            ibuf[i] = T0(0.0)

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        raw_rfft(ibuf, N, True, norm, out_data)

        # Scatter the scratch result into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    free(ibuf)
    if need_obuf:
        free(obuf)

    return ans

def irfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """Run a 1-D inverse real transform along one axis of `a`.

    Parameters:
        a: input; converted with `asarray`. Each element is cast to the
           complex working type.
        n: length of the real output along `axis`. When given,
           `(n >> 1) + 1` input points are used (zero-padded or
           truncated as needed). When omitted, all `m` input points are
           used and the output length is `(m - 1) * 2`.
        axis: axis to transform; normalized against `a.ndim`.
        norm: forwarded unchanged to `raw_irfft`.
        out: optional destination. Checked against the output shape with
             `_check_out`; its dtype must be the real working type.

    Returns:
        `out` if it was given, otherwise a newly allocated real array.

    Raises:
        ValueError: if the output length is smaller than 1 (for example a
            length-1 input with `n` omitted).
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m                 # number of complex input points consumed
    M = (m - 1) << 1      # number of real output points produced
    if n is not None:
        N = (n >> 1) + 1
        M = n
    # Reject empty / negative lengths before anything is allocated.
    if M < 1:
        raise ValueError(f"Invalid number of FFT data points ({M}) specified.")

    out_shape = tuple_set(a.shape, axis, M)

    # The input lane can be passed directly only when it already has type
    # T, is contiguous along `axis`, and needs no padding/truncation.
    need_ibuf = (a.dtype is not T or
                 istride != sizeof(T) or
                 N != m)

    if out is None:
        ans = empty(out_shape, T0)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T0:
            compile_error("'out' dtype must be float")
        ans = out

    ostride = ans.strides[axis]
    # The raw transform writes M contiguous values; a scratch output buffer
    # is needed only when the destination is strided along `axis`.
    need_obuf = (ostride != sizeof(T0))
    ibuf = Ptr[T](N) if need_ibuf else Ptr[T]()
    obuf = Ptr[T0](M) if need_obuf else Ptr[T0]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        if need_ibuf:
            # Gather (and convert) the input lane, then zero-pad up to N.
            pb = p.as_byte()
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T)
                pb += istride
            for i in range(m, N):
                ibuf[i] = T(0.0, 0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T](p.as_byte())

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        raw_irfft(in_data, M, norm, out_data)

        # Scatter the scratch result into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    if need_ibuf:
        free(ibuf)
    if need_obuf:
        free(obuf)

    return ans

def hfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """Transform `a` by applying `irfft` to its conjugate with the norm direction swapped.

    Parameters:
        a: input; converted with `asarray`.
        n: output length along `axis`. Defaults to
           `(a.shape[axis] - 1) * 2`.
        axis: axis to transform.
        norm: normalization mode; passed through `_swap_direction` before
              being handed to `irfft`.
        out: optional destination array, forwarded to `irfft` (which
             validates its shape and dtype).

    Returns:
        The array returned by `irfft` (`out` itself when it was given).
    """
    a = asarray(a)
    n1 = 0
    if n is None:
        n1 = (a.shape[axis] - 1) * 2
    else:
        n1 = n
    new_norm = _swap_direction(norm)
    # Forward the caller's `out` so the result is actually written there.
    return irfft(a.conj(), n1, axis, norm=new_norm, out=out)

def ihfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """Apply `rfft` with the norm direction swapped and conjugate the result in place.

    `n` defaults to the length of `a` along `axis`. `out`, when given, is
    forwarded to `rfft`.
    """
    arr = asarray(a)
    npts = 0
    if n is not None:
        npts = n
    else:
        npts = arr.shape[axis]
    swapped = _swap_direction(norm)
    res = rfft(arr, npts, axis, norm=swapped, out=out)
    # Conjugate every element of the rfft result in place.
    res.map(lambda z: z.conjugate(), inplace=True)
    return res

def _cook_nd_args(a, s = None, axes = None, invreal: Static[int] = False):
    """Resolve the `(s, axes)` pair used by the N-dimensional transforms.

    Parameters:
        a: array whose `shape` supplies default lengths.
        s: per-axis lengths, or None to take them from `a.shape`.
           An entry of -1 means "use the length of `a` along that axis".
        axes: axes to transform, or None.
        invreal: static flag; when set and `s` was not supplied, the last
                 length becomes `(a.shape[axes[-1]] - 1) * 2`.

    Returns:
        A tuple `(s2, axes1)` of resolved lengths and axes, both of the
        same static length.

    A compile error is issued when the lengths of `s` and `axes` differ.
    """
    # Remember (statically) whether the caller left `s` unspecified.
    shapeless: Static[int] = (s is None)

    if s is None:
        if axes is None:
            # No shape and no axes: every axis at its current length.
            s1 = a.shape
        else:
            # Lengths of `a` along the requested axes.
            s1 = tuple(a.shape[i] for i in axes)
    else:
        s1 = s

    if axes is None:
        if not shapeless:
            # `s` given without `axes`; currently a no-op placeholder.
            pass  # warning
        # Build negative indices for the trailing len(s1) axes
        # (assuming tuple_range(k) yields 0..k-1, this gives -k, ..., -1).
        r = tuple_range(staticlen(s1))[::-1]
        axes1 = tuple(-i - 1 for i in r)
    else:
        axes1 = axes

    if staticlen(s1) != staticlen(axes1):
        compile_error("Shape and axes have different lengths.")

    if invreal and shapeless:
        # Replace the last length with (input length - 1) * 2 along the last axis.
        s1 = tuple_set(s1, staticlen(s1) - 1, (a.shape[axes1[-1]] - 1) * 2)

    # Substitute the actual axis length for every -1 entry.
    s2 = tuple(a.shape[axes1[i]] if s1[i] == -1 else s1[i] for i in staticrange(staticlen(s1)))
    return s2, axes1

def _raw_fftnd(a, s = None, axes = None, function = fft, norm: Optional[str] = None, out = None):
    """Apply the 1-D transform `function` along each resolved axis, last axis first.

    `s` and `axes` are resolved by `_cook_nd_args`. With no axes the
    converted input is returned as is. `norm` and `out` are forwarded to
    every call of `function`.
    """
    arr = asarray(a)
    shape, dims = _cook_nd_args(arr, s, axes)
    if staticlen(dims) == 0:
        return arr
    last = len(dims) - 1
    # The first call is kept outside the loop: its result may have a
    # different element type than the input.
    res = function(arr, n=shape[last], axis=dims[last], norm=norm, out=out)
    for k in range(last - 1, -1, -1):
        res = function(res, n=shape[k], axis=dims[k], norm=norm, out=out)
    return res

def fftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional forward transform: `_raw_fftnd` with `fft` applied per axis."""
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)

def ifftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional inverse transform: `_raw_fftnd` with `ifft` applied per axis."""
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)

def fft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Forward transform over two axes (the last two by default) via `_raw_fftnd` and `fft`."""
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)

def ifft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Inverse transform over two axes (the last two by default) via `_raw_fftnd` and `ifft`."""
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)

def rfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional real-input transform.

    Applies `rfft` along the last resolved axis, then `fft` along each of
    the remaining axes in order. `s` and `axes` are resolved by
    `_cook_nd_args`; `norm` and `out` are forwarded to every call.
    """
    arr = asarray(a)
    shape, dims = _cook_nd_args(arr, s, axes)
    res = rfft(arr, n=shape[-1], axis=dims[-1], norm=norm, out=out)
    for k in range(len(dims) - 1):
        res = fft(res, n=shape[k], axis=dims[k], norm=norm, out=out)
    return res

def rfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Real-input transform over two axes (the last two by default); delegates to `rfftn`."""
    return rfftn(a, s=s, axes=axes, norm=norm, out=out)

def irfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional inverse real transform.

    Applies `ifft` along every resolved axis except the last, then `irfft`
    along the last one. `s` and `axes` are resolved by `_cook_nd_args`
    with `invreal=True`. `out` is forwarded only to the final `irfft`.
    """
    arr = asarray(a)
    shape, dims = _cook_nd_args(arr, s, axes, invreal=True)
    if staticlen(dims) <= 1:
        return irfft(arr, n=shape[-1], axis=dims[-1], norm=norm, out=out)
    # The first ifft is kept outside the loop: its result may have a
    # different element type than the input.
    tmp = ifft(arr, n=shape[0], axis=dims[0], norm=norm)
    for k in range(1, len(dims) - 1):
        tmp = ifft(tmp, n=shape[k], axis=dims[k], norm=norm)
    return irfft(tmp, n=shape[-1], axis=dims[-1], norm=norm, out=out)

def irfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Inverse real transform over two axes (the last two by default).

    Delegates to `irfftn`, forwarding `s`, `axes`, `norm` and the caller's
    `out` destination.
    """
    # Forward `out` so a caller-supplied destination is honored.
    return irfftn(a, s, axes, norm, out=out)


# Helpers

def fftshift(x, axes=None):
    """Roll `x` forward by half the axis length (floor division) along the chosen axes.

    `axes` may be None (all axes), a single int, or a sequence of ints.
    """
    arr = asarray(x)
    if axes is None:
        return roll(arr, tuple(n // 2 for n in arr.shape), tuple_range(arr.ndim))
    elif isinstance(axes, int):
        return roll(arr, arr.shape[axes] // 2, axes)
    else:
        return roll(arr, tuple(arr.shape[k] // 2 for k in axes), axes)

def ifftshift(x, axes=None):
    """Roll `x` backward by half the axis length (floor division) along the chosen axes.

    `axes` may be None (all axes), a single int, or a sequence of ints.
    """
    arr = asarray(x)
    if axes is None:
        return roll(arr, tuple(-(n // 2) for n in arr.shape), tuple_range(arr.ndim))
    elif isinstance(axes, int):
        return roll(arr, -(arr.shape[axes] // 2), axes)
    else:
        return roll(arr, tuple(-(arr.shape[k] // 2) for k in axes), axes)

def fftfreq(n: int, d: float = 1.0):
    """Return `n` sample frequencies for sample spacing `d`.

    The result holds the integers `0 .. (n-1)//2` followed by
    `-(n//2) .. -1`, each multiplied by `1.0 / (n * d)`.
    """
    scale = 1.0 / (n * d)
    bins = empty(n, int)
    split = (n - 1) // 2 + 1
    # Non-negative bins first, then the negative ones.
    bins[:split] = arange(0, split, dtype=int)
    bins[split:] = arange(-(n // 2), 0, dtype=int)
    return bins * scale

def rfftfreq(n: int, d: float = 1.0):
    """Return the `n//2 + 1` non-negative sample frequencies for sample spacing `d`.

    The result holds the integers `0 .. n//2`, each multiplied by
    `1.0 / (n * d)`.
    """
    scale = 1.0 / (n * d)
    count = n // 2 + 1
    return arange(0, count, dtype=int) * scale