mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
375 lines
10 KiB
Python
375 lines
10 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from pocketfft import _swap_direction, fft as raw_fft, ifft as raw_ifft, rfft as raw_rfft, irfft as raw_irfft
|
|
from ..routines import _check_out, asarray, empty, arange, roll
|
|
from ..ndarray import ndarray
|
|
from ..util import zero, cast, free, multirange, normalize_axis_index, tuple_range, tuple_set, tuple_insert, tuple_delete, sizeof
|
|
|
|
def _complex_dtype(dtype: type):
|
|
if (dtype is complex64 or
|
|
dtype is float32 or
|
|
dtype is float16):
|
|
return complex64()
|
|
else:
|
|
return complex()
|
|
|
|
def _fft(a, inv: bool, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    Shared driver for `fft` and `ifft`: applies a 1-D complex transform to
    every lane of `a` taken along `axis`.

    Parameters:
      a    - input, converted with `asarray`.
      inv  - if true each lane is passed to `raw_ifft`, otherwise `raw_fft`.
      n    - transform length; defaults to the length of `a` along `axis`.
             Lanes longer than `n` are truncated, shorter ones are padded
             with zeros.
      axis - axis to transform (normalized with `normalize_axis_index`).
      norm - forwarded unchanged to the raw pocketfft routine.
      out  - optional destination; validated with `_check_out` against the
             result shape and required (at compile time) to have the
             complex dtype chosen by `_complex_dtype`.

    Returns `out` when given, otherwise a newly allocated array whose
    shape equals `a.shape` with `axis` replaced by the transform length.

    Raises ValueError when the transform length is smaller than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m
    if n is not None:
        N = n

    # A non-positive length would otherwise reach the buffer allocation and
    # the output-shape computation below; reject it up front.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")

    out_shape = tuple_set(a.shape, axis, N)

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # The transform runs in place on `data`. It can run directly inside the
    # output lane only when no dtype conversion is needed, the output lane
    # is contiguous along `axis`, and the length is unchanged; otherwise a
    # scratch buffer of N elements is used and copied back afterwards.
    need_buf = (a.dtype is not T or
                ostride != sizeof(T) or
                N != m)
    buf = Ptr[T](N) if need_buf else Ptr[T]()

    # Visit every index combination of the non-transformed axes.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)  # first element of this lane
        p = a._ptr(idx1)
        q = ans._ptr(idx1)
        pb = p.as_byte()
        data = buf if need_buf else Ptr[T](q.as_byte())

        # Gather (and convert) the input lane, walking it by byte stride.
        for i in range(min(m, N)):
            data[i] = cast(Ptr[a.dtype](pb)[0], T)
            pb += istride

        # Zero-pad when the requested length exceeds the input length.
        for i in range(m, N):
            data[i] = T(0.0, 0.0)

        if inv:
            raw_ifft(data, N, norm)
        else:
            raw_fft(data, N, norm)

        # Scatter the scratch buffer into the (possibly strided) output lane.
        if need_buf:
            qb = q.as_byte()
            for i in range(N):
                Ptr[ans.dtype](qb)[0] = data[i]
                qb += ostride

    if need_buf:
        free(buf)

    return ans
|
|
|
|
def fft(a, n: Optional[int] = None,
        axis: int = -1, norm: Optional[str] = None,
        out = None):
    """
    One-dimensional forward transform of `a` along `axis`.

    Thin wrapper that calls `_fft` with `inv=False`; `n`, `axis`, `norm`
    and `out` are passed through unchanged.
    """
    return _fft(a, False, n, axis, norm, out)
|
|
|
|
def ifft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    One-dimensional inverse transform of `a` along `axis`.

    Thin wrapper that calls `_fft` with `inv=True`; `n`, `axis`, `norm`
    and `out` are passed through unchanged.
    """
    return _fft(a, True, n, axis, norm, out)
|
|
|
|
def rfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    One-dimensional transform of real input along `axis`, producing
    `(n >> 1) + 1` complex outputs per lane.

    Parameters:
      a    - input, converted with `asarray`; each element is cast to the
             real type matching the complex result dtype.
      n    - number of input points used; defaults to the length of `a`
             along `axis`. Lanes are truncated or zero-padded to `n`.
      axis - axis to transform (normalized with `normalize_axis_index`).
      norm - forwarded unchanged to `raw_rfft`.
      out  - optional destination; validated with `_check_out` and required
             (at compile time) to have the complex result dtype.

    Returns `out` when given, otherwise a newly allocated complex array.

    Raises ValueError when the number of input points is smaller than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m
    if n is not None:
        N = n

    # A non-positive length would otherwise reach the buffer allocation and
    # the output-shape computation below; reject it up front.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")

    M = (N >> 1) + 1

    out_shape = tuple_set(a.shape, axis, M)

    # The input is always staged through a scratch buffer. The original
    # code first computed
    #     a.dtype is not T0 or istride != sizeof(T0) or N != m
    # and then unconditionally overwrote the result with True, so the
    # direct-read branch below has never been taken.
    # NOTE(review): presumably reading the caller's memory directly is not
    # safe with raw_rfft -- confirm before re-enabling that condition.
    need_ibuf = True

    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out

    ostride = ans.strides[axis]
    # Results can be written straight into the output lane only when it is
    # contiguous along `axis`.
    need_obuf = (ostride != sizeof(ans.dtype))
    ibuf = Ptr[T0](N) if need_ibuf else Ptr[T0]()
    obuf = Ptr[T](M) if need_obuf else Ptr[T]()

    # Visit every index combination of the non-transformed axes.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)  # first element of this lane
        p = a._ptr(idx1)
        if need_ibuf:
            # Gather (and convert) the input lane, then zero-pad up to N.
            pb = p.as_byte()
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T0)
                pb += istride
            for i in range(m, N):
                ibuf[i] = T0(0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T0](p.as_byte())

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        raw_rfft(in_data, N, True, norm, out_data)

        # Scatter the scratch output into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    if need_ibuf:
        free(ibuf)
    if need_obuf:
        free(obuf)

    return ans
|
|
|
|
def irfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """
    Inverse of `rfft`: transforms complex lanes along `axis` into real
    output of length `n`.

    Parameters:
      a    - input, converted with `asarray`; each element is cast to the
             complex dtype chosen by `_complex_dtype`.
      n    - output length along `axis`. When omitted it is `2 * (m - 1)`
             for an input of length `m`. When given, `(n >> 1) + 1` input
             points are used, truncating or zero-padding each lane.
      axis - axis to transform (normalized with `normalize_axis_index`).
      norm - forwarded unchanged to `raw_irfft`.
      out  - optional destination; validated with `_check_out` and required
             (at compile time) to have the real dtype matching the complex
             working type.

    Returns `out` when given, otherwise a newly allocated real array.

    Raises ValueError when the output length is smaller than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]

    m = a.shape[axis]
    N = m               # number of complex input points consumed per lane
    M = (m - 1) << 1    # number of real output points produced per lane
    if n is not None:
        N = (n >> 1) + 1
        M = n

    # A non-positive output length (e.g. a length-1 input with no explicit
    # `n`) would otherwise reach the buffer allocation and the output-shape
    # computation below; reject it up front.
    if M < 1:
        raise ValueError(f"Invalid number of FFT data points ({M}) specified.")

    out_shape = tuple_set(a.shape, axis, M)

    # The input lane can be handed to raw_irfft directly only when it
    # already has the working dtype, is contiguous along `axis`, and needs
    # no truncation or padding.
    need_ibuf = (a.dtype is not T or
                 istride != sizeof(T) or
                 N != m)

    if out is None:
        ans = empty(out_shape, T0)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T0:
            compile_error("'out' dtype must be float")
        ans = out

    ostride = ans.strides[axis]
    # Results can be written straight into the output lane only when it is
    # contiguous along `axis`.
    need_obuf = (ostride != sizeof(ans.dtype))
    ibuf = Ptr[T](N) if need_ibuf else Ptr[T]()
    obuf = Ptr[T0](M) if need_obuf else Ptr[T0]()

    # Visit every index combination of the non-transformed axes.
    for idx in multirange(tuple_delete(a.shape, axis)):
        idx1 = tuple_insert(idx, axis, 0)  # first element of this lane
        p = a._ptr(idx1)
        if need_ibuf:
            # Gather (and convert) the input lane, then zero-pad up to N.
            pb = p.as_byte()
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T)
                pb += istride
            for i in range(m, N):
                ibuf[i] = T(0.0, 0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T](p.as_byte())

        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q

        raw_irfft(in_data, M, norm, out_data)

        # Scatter the scratch output into the strided output lane.
        if need_obuf:
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride

    if need_ibuf:
        free(ibuf)
    if need_obuf:
        free(obuf)

    return ans
|
|
|
|
def hfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    Transform of a signal with Hermitian symmetry, implemented as `irfft`
    of the conjugated input with the normalization direction swapped.

    Parameters:
      a    - input, converted with `asarray`.
      n    - output length along `axis`; defaults to `2 * (m - 1)` where
             `m` is the length of `a` along `axis`.
      axis - axis to transform.
      norm - normalization mode; passed through `_swap_direction` before
             being handed to `irfft`.
      out  - optional destination array, forwarded to `irfft`.

    Returns the array produced by `irfft`.
    """
    a = asarray(a)
    n1 = 0
    if n is None:
        n1 = (a.shape[axis] - 1) * 2
    else:
        n1 = n
    new_norm = _swap_direction(norm)
    # Forward the caller's `out`. The previous code passed `out=None` here,
    # so a supplied destination was silently ignored (unlike `ihfft`, which
    # forwards it).
    return irfft(a.conj(), n1, axis, norm=new_norm, out=out)
|
|
|
|
def ihfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """
    Inverse of `hfft`, implemented as `rfft` with the normalization
    direction swapped, followed by an in-place conjugation of the result.

    `n` defaults to the length of `a` along `axis`; `out` is forwarded to
    `rfft`. Returns the array produced by `rfft`.
    """
    a = asarray(a)
    npoints = 0
    if n is not None:
        npoints = n
    else:
        npoints = a.shape[axis]
    swapped = _swap_direction(norm)
    res = rfft(a, npoints, axis, norm=swapped, out=out)
    # Conjugate every element of the result in place.
    res.map(lambda x: x.conjugate(), inplace=True)
    return res
|
|
|
|
def _cook_nd_args(a, s = None, axes = None, invreal: Static[int] = False):
    """
    Resolve the `s` (per-axis lengths) and `axes` arguments of the N-D
    transforms into two tuples of equal length.

    - `s` omitted: the lengths are taken from `a.shape`, either for all
      axes (`axes` omitted) or for the listed `axes`.
    - `axes` omitted: the last `len(s)` axes are used, expressed as
      negative indices `(-len(s), ..., -1)`.
    - `invreal` set and `s` omitted: the last length becomes
      `2 * (m - 1)`, with `m` the length of `a` along the last axis.
    - Any entry of `s` equal to -1 is replaced by the length of `a` along
      the corresponding axis.

    A length mismatch between the two tuples is a compile-time error.
    Returns `(shape, axes)`.
    """
    shapeless: Static[int] = (s is None)

    if s is not None:
        dims = s
    else:
        if axes is None:
            dims = a.shape
        else:
            dims = tuple(a.shape[i] for i in axes)

    if axes is None:
        if not shapeless:
            pass  # warning
        # (k-1, ..., 0) mapped through -i-1 gives (-k, ..., -1).
        descending = tuple_range(staticlen(dims))[::-1]
        use_axes = tuple(-i - 1 for i in descending)
    else:
        use_axes = axes

    if staticlen(dims) != staticlen(use_axes):
        compile_error("Shape and axes have different lengths.")

    if invreal and shapeless:
        dims = tuple_set(dims, staticlen(dims) - 1, (a.shape[use_axes[-1]] - 1) * 2)

    resolved = tuple(a.shape[use_axes[i]] if dims[i] == -1 else dims[i] for i in staticrange(staticlen(dims)))
    return resolved, use_axes
|
|
|
|
def _raw_fftnd(a, s = None, axes = None, function = fft, norm: Optional[str] = None, out = None):
    """
    Apply the 1-D transform `function` along each requested axis in turn,
    starting with the last entry of `axes` and finishing with the first.

    `s` and `axes` are resolved with `_cook_nd_args`. With no axes to
    transform the converted input is returned as is. `norm` and `out` are
    forwarded to every call of `function`.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)
    if staticlen(axes) == 0:
        return a
    # The first call is kept outside the loop, as in the original layout.
    k = len(axes) - 1
    a = function(a, n=s[k], axis=axes[k], norm=norm, out=out)
    while k > 0:
        k -= 1
        a = function(a, n=s[k], axis=axes[k], norm=norm, out=out)
    return a
|
|
|
|
def fftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    N-dimensional forward transform: applies `fft` along each resolved
    axis via `_raw_fftnd`.
    """
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)
|
|
|
|
def ifftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    N-dimensional inverse transform: applies `ifft` along each resolved
    axis via `_raw_fftnd`.
    """
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)
|
|
|
|
def fft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """
    Two-dimensional forward transform: same as `fftn`, but `axes`
    defaults to the last two axes.
    """
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)
|
|
|
|
def ifft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """
    Two-dimensional inverse transform: same as `ifftn`, but `axes`
    defaults to the last two axes.
    """
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)
|
|
|
|
def rfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    N-dimensional transform of real input: `rfft` along the last resolved
    axis, then `fft` along each remaining axis in order.

    `s` and `axes` are resolved with `_cook_nd_args`; `norm` and `out` are
    forwarded to every 1-D call.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)
    # Real-to-complex step on the last axis first ...
    a = rfft(a, n=s[-1], axis=axes[-1], norm=norm, out=out)
    # ... then ordinary complex transforms over the other axes.
    for k in range(len(axes) - 1):
        a = fft(a, n=s[k], axis=axes[k], norm=norm, out=out)
    return a
|
|
|
|
def rfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """
    Two-dimensional transform of real input: same as `rfftn`, but `axes`
    defaults to the last two axes.
    """
    return rfftn(a, s=s, axes=axes, norm=norm, out=out)
|
|
|
|
def irfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    Inverse of `rfftn`: `ifft` along every resolved axis except the last,
    then `irfft` along the last one.

    `s` and `axes` are resolved with `_cook_nd_args(..., invreal=True)`.
    `out` is only passed to the final `irfft` call; the intermediate
    `ifft` calls allocate their own results.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes, invreal=True)
    if staticlen(axes) <= 1:
        return irfft(a, n=s[-1], axis=axes[-1], norm=norm, out=out)
    # The first complex step is kept outside the loop, as in the original
    # layout.
    spectrum = ifft(a, n=s[0], axis=axes[0], norm=norm)
    for k in range(1, len(axes) - 1):
        spectrum = ifft(spectrum, n=s[k], axis=axes[k], norm=norm)
    return irfft(spectrum, n=s[-1], axis=axes[-1], norm=norm, out=out)
|
|
|
|
def irfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """
    Two-dimensional inverse of `rfft2`: same as `irfftn`, but `axes`
    defaults to the last two axes.

    `s`, `axes`, `norm` and `out` are forwarded to `irfftn`.
    """
    # Forward the caller's `out`. The previous code passed `out=None`, so a
    # supplied destination was silently ignored (unlike `rfft2`, which
    # forwards it).
    return irfftn(a, s, axes, norm, out=out)
|
|
|
|
|
|
# Helpers
|
|
|
|
def fftshift(x, axes=None):
    """
    Roll `x` by half of its length (`dim // 2`) along the selected axes.

    `axes` may be omitted (all axes are shifted), a single int, or a
    collection of ints. The actual shifting is delegated to `roll`.
    """
    x = asarray(x)
    if axes is None:
        half = tuple(extent // 2 for extent in x.shape)
        return roll(x, half, tuple_range(x.ndim))
    elif isinstance(axes, int):
        return roll(x, x.shape[axes] // 2, axes)
    else:
        half = tuple(x.shape[k] // 2 for k in axes)
        return roll(x, half, axes)
|
|
|
|
def ifftshift(x, axes=None):
    """
    Roll `x` by `-(dim // 2)` along the selected axes, i.e. the opposite
    direction of `fftshift`.

    `axes` may be omitted (all axes are shifted), a single int, or a
    collection of ints. The actual shifting is delegated to `roll`.
    """
    x = asarray(x)
    if axes is None:
        back = tuple(-(extent // 2) for extent in x.shape)
        return roll(x, back, tuple_range(x.ndim))
    elif isinstance(axes, int):
        return roll(x, -(x.shape[axes] // 2), axes)
    else:
        back = tuple(-(x.shape[k] // 2) for k in axes)
        return roll(x, back, axes)
|
|
|
|
def fftfreq(n: int, d: float = 1.0):
    """
    Sample frequencies for a transform of length `n` with sample spacing
    `d`.

    The result holds the integers `0, 1, ..., (n-1)//2` followed by
    `-(n//2), ..., -1`, each multiplied by `1 / (n * d)`.
    """
    scale = 1.0 / (n * d)
    split = (n - 1) // 2 + 1  # count of non-negative frequency bins
    bins = empty(n, int)
    bins[:split] = arange(0, split, dtype=int)
    bins[split:] = arange(-(n // 2), 0, dtype=int)
    return bins * scale
|
|
|
|
def rfftfreq(n: int, d: float = 1.0):
    """
    Sample frequencies for a real-input transform of length `n` with
    sample spacing `d`.

    The result holds the integers `0, 1, ..., n//2`, each multiplied by
    `1 / (n * d)`.
    """
    scale = 1.0 / (n * d)
    bins = arange(0, n // 2 + 1, dtype=int)
    return bins * scale
|