# Copyright (C) 2022-2025 Exaloop Inc.

from pocketfft import _swap_direction, fft as raw_fft, ifft as raw_ifft, rfft as raw_rfft, irfft as raw_irfft
from ..routines import _check_out, asarray, empty, arange, roll
from ..ndarray import ndarray
from ..util import zero, cast, free, multirange, normalize_axis_index, tuple_range, tuple_set, tuple_insert, tuple_delete, sizeof


def _complex_dtype(dtype: type):
    """Return a default-constructed value of the complex type paired with `dtype`.

    `complex64`, `float32` and `float16` map to `complex64`; every other
    dtype maps to `complex`. Callers take `type(...)` of the result.
    """
    if (dtype is complex64 or dtype is float32 or dtype is float16):
        return complex64()
    else:
        return complex()


def _fft(a, inv: bool, n: Optional[int] = None, axis: int = -1,
         norm: Optional[str] = None, out = None):
    """Shared implementation of the 1-D complex transforms `fft` and `ifft`.

    Args:
        a: Input; converted with `asarray`.
        inv: If true `raw_ifft` is applied, otherwise `raw_fft`.
        n: Transform length along `axis`; defaults to `a.shape[axis]`.
            Input is truncated or zero-padded to this length.
        axis: Axis to transform (normalized with `normalize_axis_index`).
        norm: Normalization mode, forwarded unchanged to pocketfft.
        out: Optional destination array. It is validated with `_check_out`
            and must have the complex dtype derived from `a.dtype`
            (compile-time error otherwise).

    Returns:
        The array holding the result (`out` when it was supplied).
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]
    m = a.shape[axis]
    N = m
    if n is not None:
        N = n
    M = N
    out_shape = tuple_set(a.shape, axis, M)

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # The transform runs in place on a run of N consecutive T values. The
    # output lane itself can serve as that storage only when no dtype
    # conversion, no length change and no strided output are involved.
    need_buf = (a.dtype is not T or ostride != sizeof(T) or N != m)
    buf = Ptr[T](N) if need_buf else Ptr[T]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        q = ans._ptr(idx1)
        pb = p.as_byte()
        data = buf if need_buf else Ptr[T](q.as_byte())

        # Gather (and convert) the input lane, truncating to N if needed.
        for i in range(min(m, N)):
            data[i] = cast(Ptr[a.dtype](pb)[0], T)
            pb += istride

        # Zero-pad when the requested length exceeds the input length.
        for i in range(m, N):
            data[i] = T(0.0, 0.0)

        if inv:
            raw_ifft(data, N, norm)
        else:
            raw_fft(data, N, norm)

        # Scatter the scratch buffer into the (possibly strided) output lane.
        if need_buf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = data[i]
                qb += ostride

    if need_buf:
        free(buf)

    return ans


def fft(a, n: Optional[int] = None, axis: int = -1,
        norm: Optional[str] = None, out = None):
    """1-D forward transform of `a` along `axis`; see `_fft` for the arguments."""
    return _fft(a, inv=False, n=n, axis=axis, norm=norm, out=out)


def ifft(a, n: Optional[int] = None, axis: int = -1,
         norm: Optional[str] = None, out = None):
    """1-D inverse transform of `a` along `axis`; see `_fft` for the arguments."""
    return _fft(a, inv=True, n=n, axis=axis, norm=norm, out=out)


def rfft(a, n: Optional[int] = None, axis: int = -1,
         norm: Optional[str] = None, out = None):
    """1-D transform of real input, producing `(n >> 1) + 1` complex outputs per lane.

    Args:
        a: Input; converted with `asarray`. Elements are cast to the real
            type matching the derived complex dtype.
        n: Number of input points used along `axis`; defaults to
            `a.shape[axis]`. Input is truncated or zero-padded to this length.
        axis: Axis to transform.
        norm: Normalization mode, forwarded unchanged to pocketfft.
        out: Optional destination array; validated with `_check_out` and
            required to have the derived complex dtype.

    Returns:
        The array holding the result (`out` when it was supplied).
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]
    m = a.shape[axis]
    N = m
    if n is not None:
        N = n
    M = (N >> 1) + 1
    out_shape = tuple_set(a.shape, axis, M)
    need_ibuf = (a.dtype is not T0 or istride != sizeof(T0) or N != m)
    # NOTE(review): the value computed above is immediately overridden, so the
    # input is always staged through a scratch buffer. Presumably raw_rfft is
    # not safe to run directly on the caller's data -- confirm before removing.
    need_ibuf = True

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # A scratch output buffer is needed whenever the output lane is strided.
    need_obuf = (ostride != sizeof(ans.dtype))
    ibuf = Ptr[T0](N) if need_ibuf else Ptr[T0]()
    obuf = Ptr[T](M) if need_obuf else Ptr[T]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)

        if need_ibuf:
            pb = p.as_byte()
            # Gather (and convert) the input lane, truncating to N if needed.
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T0)
                pb += istride
            # Zero-pad up to the requested length.
            for i in range(m, N):
                ibuf[i] = T0(0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T0](p.as_byte())

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        # The meaning of the third positional argument is defined by pocketfft.
        raw_rfft(in_data, N, True, norm, out_data)

        # Scatter the scratch buffer into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    if need_ibuf:
        free(ibuf)

    if need_obuf:
        free(obuf)

    return ans


def irfft(a, n: Optional[int] = None, axis: int = -1,
          norm: Optional[str] = None, out = None):
    """Inverse of `rfft`: complex input lanes to real output lanes.

    Args:
        a: Input; converted with `asarray`. Elements are cast to the derived
            complex dtype.
        n: Output length along `axis`. Defaults to `(a.shape[axis] - 1) << 1`.
            When given, `(n >> 1) + 1` input points are consumed per lane
            (input is truncated or zero-padded to that count).
        axis: Axis to transform.
        norm: Normalization mode, forwarded unchanged to pocketfft.
        out: Optional destination array; validated with `_check_out` and
            required to have the real dtype matching the derived complex dtype.

    Returns:
        The array holding the result (`out` when it was supplied).
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]
    m = a.shape[axis]
    N = m                # number of complex inputs consumed per lane
    M = (m - 1) << 1     # number of real outputs produced per lane
    if n is not None:
        N = (n >> 1) + 1
        M = n
    out_shape = tuple_set(a.shape, axis, M)
    need_ibuf = (a.dtype is not T or istride != sizeof(T) or N != m)

    if out is None:
        ans = empty(out_shape, T0)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T0:
            compile_error("'out' dtype must be float")
        ans = out

    ostride = ans.strides[axis]
    # A scratch output buffer is needed whenever the output lane is strided.
    need_obuf = (ostride != sizeof(ans.dtype))
    ibuf = Ptr[T](N) if need_ibuf else Ptr[T]()
    obuf = Ptr[T0](M) if need_obuf else Ptr[T0]()

    # Visit every 1-D lane along `axis`.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)

        if need_ibuf:
            pb = p.as_byte()
            # Gather (and convert) the input lane, truncating to N if needed.
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T)
                pb += istride
            # Zero-pad up to the required input count.
            for i in range(m, N):
                ibuf[i] = T(0.0, 0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T](p.as_byte())

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        raw_irfft(in_data, M, norm, out_data)

        # Scatter the scratch buffer into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    if need_ibuf:
        free(ibuf)

    if need_obuf:
        free(obuf)

    return ans


def hfft(a, n: Optional[int] = None, axis: int = -1,
         norm: Optional[str] = None, out = None):
    """Compute `irfft` of the conjugate of `a` with `norm` passed through `_swap_direction`.

    Args:
        a: Input; converted with `asarray`.
        n: Output length; defaults to `(a.shape[axis] - 1) * 2`.
        axis: Axis to transform.
        norm: Normalization mode; mapped through `_swap_direction` before
            being handed to `irfft`.
        out: Optional destination array, forwarded to `irfft`.

    Returns:
        The result of the underlying `irfft` call.
    """
    a = asarray(a)
    n1 = 0
    if n is None:
        n1 = (a.shape[axis] - 1) * 2
    else:
        n1 = n
    new_norm = _swap_direction(norm)
    # Forward the caller's `out`; it was previously dropped (out=None), which
    # made the parameter a silent no-op.
    output = irfft(a.conj(), n1, axis, norm=new_norm, out=out)
    return output


def ihfft(a, n: Optional[int] = None, axis: int = -1,
          norm: Optional[str] = None, out = None):
    """Compute the conjugate of `rfft(a)` with `norm` passed through `_swap_direction`.

    Args:
        a: Input; converted with `asarray`.
        n: Number of input points; defaults to `a.shape[axis]`.
        axis: Axis to transform.
        norm: Normalization mode; mapped through `_swap_direction` before
            being handed to `rfft`.
        out: Optional destination array, forwarded to `rfft`.

    Returns:
        The array returned by `rfft`, conjugated in place.
    """
    a = asarray(a)
    n1 = 0
    if n is None:
        n1 = a.shape[axis]
    else:
        n1 = n
    new_norm = _swap_direction(norm)
    out = rfft(a, n1, axis, norm=new_norm, out=out)
    out.map(lambda x: x.conjugate(), inplace=True)
    return out


def _cook_nd_args(a, s = None, axes = None, invreal: Static[int] = False):
    """Resolve the `s` / `axes` arguments of the N-D transforms.

    Args:
        a: Array whose shape supplies defaults.
        s: Per-axis lengths, or None to take them from `a.shape`.
        axes: Axes to transform, or None for the last `len(s)` axes.
        invreal: When true and `s` was not given, the last length is set to
            `(a.shape[axes[-1]] - 1) * 2`.

    Returns:
        `(shape, axes)` tuples of equal length. Any `-1` entry in `s` is
        replaced with the size of `a` along the corresponding axis.
    """
    shapeless: Static[int] = (s is None)
    if s is None:
        if axes is None:
            s1 = a.shape
        else:
            s1 = tuple(a.shape[i] for i in axes)
    else:
        s1 = s

    if axes is None:
        if not shapeless:
            pass  # placeholder: no warning is issued here
        # Build (-k, ..., -2, -1) where k = len(s1).
        r = tuple_range(staticlen(s1))[::-1]
        axes1 = tuple(-i - 1 for i in r)
    else:
        axes1 = axes

    if staticlen(s1) != staticlen(axes1):
        compile_error("Shape and axes have different lengths.")

    if invreal and shapeless:
        s1 = tuple_set(s1, staticlen(s1) - 1, (a.shape[axes1[-1]] - 1) * 2)

    s2 = tuple(a.shape[axes1[i]] if s1[i] == -1 else s1[i]
               for i in staticrange(staticlen(s1)))
    return s2, axes1


def _raw_fftnd(a, s = None, axes = None, function = fft,
               norm: Optional[str] = None, out = None):
    """Apply the 1-D transform `function` along each of `axes`, last axis first.

    `s` and `axes` are resolved by `_cook_nd_args`. With no axes the input
    array is returned as-is. `out` is forwarded to every 1-D call.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)

    if staticlen(axes) == 0:
        return a

    # The last axis is handled outside the loop.
    # NOTE(review): presumably so that the value rebound inside the loop keeps
    # a single type after the first transform -- confirm.
    i = len(axes) - 1
    a = function(a, n=s[i], axis=axes[i], norm=norm, out=out)

    for i in range(len(axes) - 2, -1, -1):
        a = function(a, n=s[i], axis=axes[i], norm=norm, out=out)

    return a


def fftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-D forward transform: `fft` applied over `axes` via `_raw_fftnd`."""
    return _raw_fftnd(a, s, axes, fft, norm, out=out)


def ifftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-D inverse transform: `ifft` applied over `axes` via `_raw_fftnd`."""
    return _raw_fftnd(a, s, axes, ifft, norm, out=out)


def fft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Forward transform over two axes (the last two by default)."""
    return _raw_fftnd(a, s, axes, fft, norm, out=out)


def ifft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Inverse transform over two axes (the last two by default)."""
    return _raw_fftnd(a, s, axes, ifft, norm, out=out)


def rfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-D transform of real input.

    `rfft` is applied along the last of `axes`, then `fft` along each of the
    remaining axes in order. `out` is forwarded to every 1-D call.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)
    a = rfft(a, s[-1], axes[-1], norm, out=out)
    for i in range(len(axes) - 1):
        a = fft(a, s[i], axes[i], norm, out=out)
    return a


def rfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """`rfftn` over two axes (the last two by default)."""
    return rfftn(a, s, axes, norm, out=out)


def irfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """Inverse of `rfftn`.

    `ifft` is applied along every axis except the last of `axes`, then
    `irfft` along the last one. Only the final `irfft` call receives `out`.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes, invreal=True)

    if staticlen(axes) <= 1:
        return irfft(a, s[-1], axes[-1], norm, out=out)

    b = ifft(a, s[0], axes[0], norm)
    for i in range(1, len(axes) - 1):
        b = ifft(b, s[i], axes[i], norm)

    return irfft(b, s[-1], axes[-1], norm, out=out)


def irfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """`irfftn` over two axes (the last two by default)."""
    # Forward the caller's `out`; it was previously dropped (out=None), which
    # made the parameter a silent no-op.
    return irfftn(a, s, axes, norm, out=out)


# Helpers

def fftshift(x, axes=None):
    """Roll `x` by `size // 2` along the selected axes.

    Args:
        x: Input; converted with `asarray`.
        axes: None (all axes), a single int, or a collection of axes.

    Returns:
        The result of `roll(x, shift, axes)`.
    """
    x = asarray(x)
    if axes is None:
        axes1 = tuple_range(x.ndim)
        shift = tuple(dim // 2 for dim in x.shape)
    elif isinstance(axes, int):
        axes1 = axes
        shift = x.shape[axes] // 2
    else:
        axes1 = axes
        shift = tuple(x.shape[ax] // 2 for ax in axes)
    return roll(x, shift, axes1)


def ifftshift(x, axes=None):
    """Roll `x` by `-(size // 2)` along the selected axes.

    Args:
        x: Input; converted with `asarray`.
        axes: None (all axes), a single int, or a collection of axes.

    Returns:
        The result of `roll(x, shift, axes)`.
    """
    x = asarray(x)
    if axes is None:
        axes1 = tuple_range(x.ndim)
        shift = tuple(-(dim // 2) for dim in x.shape)
    elif isinstance(axes, int):
        axes1 = axes
        shift = -(x.shape[axes] // 2)
    else:
        axes1 = axes
        shift = tuple(-(x.shape[ax] // 2) for ax in axes)
    return roll(x, shift, axes1)


def fftfreq(n: int, d: float = 1.0):
    """Return the length-`n` array of sample frequencies for spacing `d`.

    The array holds the integers `0 .. (n - 1) // 2` followed by
    `-(n // 2) .. -1`, each multiplied by `1.0 / (n * d)`.
    """
    val = 1.0 / (n * d)
    results = empty(n, int)
    N = (n - 1) // 2 + 1
    # Non-negative part.
    p1 = arange(0, N, dtype=int)
    results[:N] = p1
    # Negative part.
    p2 = arange(-(n // 2), 0, dtype=int)
    results[N:] = p2
    return results * val


def rfftfreq(n: int, d: float = 1.0):
    """Return the integers `0 .. n // 2` multiplied by `1.0 / (n * d)`."""
    val = 1.0 / (n * d)
    N = n // 2 + 1
    results = arange(0, N, dtype=int)
    return results * val