1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/fft/__init__.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

375 lines
10 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from pocketfft import _swap_direction, fft as raw_fft, ifft as raw_ifft, rfft as raw_rfft, irfft as raw_irfft
from ..routines import _check_out, asarray, empty, arange, roll
from ..ndarray import ndarray
from ..util import zero, cast, free, multirange, normalize_axis_index, tuple_range, tuple_set, tuple_insert, tuple_delete, sizeof
def _complex_dtype(dtype: type):
    """
    Choose the complex element type used to transform inputs of `dtype`.

    Half/single precision inputs (float16, float32, complex64) map to
    complex64; every other element type maps to double-precision complex.
    An *instance* of the chosen type is returned, so callers recover the
    type itself with type(_complex_dtype(...)).
    """
    # The condition must stay inline in the `if` (not stored in a variable)
    # so it remains a static check; the two branches return different types.
    if (dtype is float16 or dtype is float32 or
            dtype is complex64):
        return complex64()
    else:
        return complex()
def _fft(a, inv: bool, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    Complex-to-complex 1-D FFT of `a` along `axis`; shared by fft and ifft.

    Parameters
    ----------
    a : array_like
        Input; converted with asarray, and each element is cast to the
        complex type chosen by _complex_dtype(a.dtype).
    inv : bool
        If True run the inverse transform (raw_ifft), else the forward one
        (raw_fft).
    n : int, optional
        Transform length. The input is truncated or zero-padded along `axis`
        to `n` points; defaults to the input length along `axis`.
    axis : int
        Axis to transform (negative values are normalized).
    norm : str, optional
        Normalization mode, passed through to pocketfft unchanged.
    out : ndarray, optional
        Destination; must have the output shape and the complex dtype T
        (checked with _check_out / compile_error).

    Returns
    -------
    ndarray
        The transformed array (`out` itself when given).

    Raises
    ------
    ValueError
        If the transform length is less than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]          # byte stride of the input lane
    m = a.shape[axis]
    N = m
    if n is not None:
        N = n
    # Reject empty / negative lengths up front instead of handing them to
    # the buffer allocation and pocketfft.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")
    out_shape = tuple_set(a.shape, axis, N)
    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out
    ostride = ans.strides[axis]        # byte stride of the output lane

    # Transform in place inside the output lane when no cast, re-striding or
    # resizing is needed; otherwise work in a contiguous scratch buffer.
    need_buf = (a.dtype is not T or
                ostride != sizeof(T) or
                N != m)
    buf = Ptr[T](N) if need_buf else Ptr[T]()
    for idx in multirange(tuple_delete(a.shape, axis)):
        # `idx` spans every dimension except `axis`; pinning `axis` at 0
        # gives the first element of the current 1-D lane.
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        q = ans._ptr(idx1)
        pb = p.as_byte()
        data = buf if need_buf else Ptr[T](q.as_byte())
        # Gather and cast the input lane, truncated to N points ...
        for i in range(min(m, N)):
            data[i] = cast(Ptr[a.dtype](pb)[0], T)
            pb += istride
        # ... then zero-pad up to N points.
        for i in range(m, N):
            data[i] = T(0.0, 0.0)
        if inv:
            raw_ifft(data, N, norm)
        else:
            raw_fft(data, N, norm)
        if need_buf:
            # Scatter the buffered result back into the (strided) output lane.
            qb = q.as_byte()
            for i in range(N):
                Ptr[ans.dtype](qb)[0] = data[i]
                qb += ostride
    if need_buf:
        free(buf)
    return ans
def fft(a, n: Optional[int] = None,
        axis: int = -1, norm: Optional[str] = None,
        out = None):
    """
    One-dimensional discrete Fourier transform of `a` along `axis`.

    Forwards to _fft with the forward direction selected; see _fft for the
    meaning of `n`, `axis`, `norm` and `out`.
    """
    forward = False  # inv flag for _fft
    return _fft(a, forward, n, axis, norm, out)
def ifft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    One-dimensional inverse discrete Fourier transform of `a` along `axis`.

    Forwards to _fft with the inverse direction selected; see _fft for the
    meaning of `n`, `axis`, `norm` and `out`.
    """
    inverse = True  # inv flag for _fft
    return _fft(a, inverse, n, axis, norm, out)
def rfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    One-dimensional FFT of real input along `axis`.

    Only n // 2 + 1 complex output points are produced along `axis`.

    Parameters
    ----------
    a : array_like
        Input; each element is cast to T0, the real counterpart of the
        complex type chosen by _complex_dtype(a.dtype).
    n : int, optional
        Number of input points; the input is truncated or zero-padded along
        `axis`. Defaults to the input length along `axis`.
    axis : int
        Axis to transform (negative values are normalized).
    norm : str, optional
        Normalization mode, passed through to pocketfft unchanged.
    out : ndarray, optional
        Destination; must have the output shape and the complex dtype T.

    Returns
    -------
    ndarray
        Complex array (`out` itself when given).

    Raises
    ------
    ValueError
        If the number of input points is less than 1.
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)                # real counterpart of T
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]          # byte stride of the input lane
    m = a.shape[axis]
    N = m                              # number of real input points
    if n is not None:
        N = n
    # Reject empty / negative lengths before allocating or calling pocketfft.
    if N < 1:
        raise ValueError(f"Invalid number of FFT data points ({N}) specified.")
    M = (N >> 1) + 1                   # number of complex output points
    out_shape = tuple_set(a.shape, axis, M)
    need_ibuf = (a.dtype is not T0 or
                 istride != sizeof(T0) or
                 N != m)
    # NOTE(review): the zero-copy input condition above is unconditionally
    # overridden, so every input lane is copied into `ibuf`. Presumably this
    # is because raw_rfft may modify its input in place -- confirm before
    # re-enabling the zero-copy path.
    need_ibuf = True
    if out is None:
        ans = empty(out_shape, T)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T:
            compile_error("'out' dtype must be complex")
        ans = out
    ostride = ans.strides[axis]
    # A strided output lane cannot be written directly; stage it in `obuf`.
    need_obuf = (ans.strides[axis] != sizeof(ans.dtype))
    ibuf = Ptr[T0](N) if need_ibuf else Ptr[T0]()
    obuf = Ptr[T](M) if need_obuf else Ptr[T]()
    for idx in multirange(tuple_delete(a.shape, axis)):
        # Pin `axis` at 0 to address the first element of this 1-D lane.
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        if need_ibuf:
            # Gather, cast, truncate to N points, then zero-pad.
            pb = p.as_byte()
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T0)
                pb += istride
            for i in range(m, N):
                ibuf[i] = T0(0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T0](p.as_byte())
        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q
        raw_rfft(in_data, N, True, norm, out_data)
        if need_obuf:
            # Scatter the staged result into the strided output lane.
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride
    if need_ibuf:
        free(ibuf)
    if need_obuf:
        free(obuf)
    return ans
def irfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """
    Inverse of rfft: real output from a half spectrum along `axis`.

    Parameters
    ----------
    a : array_like
        Half-spectrum input; each element is cast to the complex type chosen
        by _complex_dtype(a.dtype).
    n : int, optional
        Length of the real output along `axis`; n // 2 + 1 input points are
        consumed (the input is truncated or zero-padded to that many).
        Defaults to 2 * (a.shape[axis] - 1), consuming every input point.
    axis : int
        Axis to transform (negative values are normalized).
    norm : str, optional
        Normalization mode, passed through to pocketfft unchanged.
    out : ndarray, optional
        Destination; must have the output shape and the real dtype T0.

    Returns
    -------
    ndarray
        Real array (`out` itself when given).

    Raises
    ------
    ValueError
        If the output length is less than 1 (e.g. a length-1 axis with the
        default `n`).
    """
    a = asarray(a)
    T = type(_complex_dtype(a.dtype))
    T0 = type(T().real)                # real counterpart of T
    axis = normalize_axis_index(axis, a.ndim)
    istride = a.strides[axis]          # byte stride of the input lane
    m = a.shape[axis]
    N = m                              # complex input points consumed
    M = (m - 1) << 1                   # real output points
    if n is not None:
        N = (n >> 1) + 1
        M = n
    # Reject empty / negative output lengths before allocating or calling
    # pocketfft.
    if M < 1:
        raise ValueError(f"Invalid number of FFT data points ({M}) specified.")
    out_shape = tuple_set(a.shape, axis, M)
    # NOTE(review): when no buffer is needed the input lane itself is handed
    # to raw_irfft (unlike rfft, which always copies). If raw_irfft writes to
    # its input this would clobber `a` -- confirm.
    need_ibuf = (a.dtype is not T or
                 istride != sizeof(T) or
                 N != m)
    if out is None:
        ans = empty(out_shape, T0)
    else:
        _check_out(out, out_shape)
        if out.dtype is not T0:
            compile_error("'out' dtype must be float")
        ans = out
    ostride = ans.strides[axis]
    # A strided output lane cannot be written directly; stage it in `obuf`.
    need_obuf = (ans.strides[axis] != sizeof(ans.dtype))
    ibuf = Ptr[T](N) if need_ibuf else Ptr[T]()
    obuf = Ptr[T0](M) if need_obuf else Ptr[T0]()
    for idx in multirange(tuple_delete(a.shape, axis)):
        # Pin `axis` at 0 to address the first element of this 1-D lane.
        idx1 = tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        if need_ibuf:
            # Gather, cast, truncate to N points, then zero-pad.
            pb = p.as_byte()
            for i in range(min(m, N)):
                ibuf[i] = cast(Ptr[a.dtype](pb)[0], T)
                pb += istride
            for i in range(m, N):
                ibuf[i] = T(0.0, 0.0)
            in_data = ibuf
        else:
            in_data = Ptr[T](p.as_byte())
        q = ans._ptr(idx1)
        if need_obuf:
            out_data = obuf
        else:
            out_data = q
        raw_irfft(in_data, M, norm, out_data)
        if need_obuf:
            # Scatter the staged result into the strided output lane.
            qb = q.as_byte()
            for i in range(M):
                Ptr[ans.dtype](qb)[0] = out_data[i]
                qb += ostride
    if need_ibuf:
        free(ibuf)
    if need_obuf:
        free(obuf)
    return ans
def hfft(a, n: Optional[int] = None,
         axis: int = -1, norm: Optional[str] = None,
         out = None):
    """
    FFT of a Hermitian-symmetric signal, given its half spectrum.

    Computed as irfft of the conjugated input with the normalization
    direction swapped (via _swap_direction); the result is real.

    Parameters
    ----------
    a : array_like
        Half-spectrum input.
    n : int, optional
        Length of the real output along `axis`; defaults to
        2 * (a.shape[axis] - 1).
    axis : int
        Axis to transform.
    norm : str, optional
        Normalization mode; its direction is swapped before calling irfft.
    out : ndarray, optional
        Destination for the real result, forwarded to irfft (which requires
        the output shape and the float dtype).

    Returns
    -------
    ndarray
        Real array of length n along `axis` (`out` itself when given).
    """
    a = asarray(a)
    n1 = 0
    if n is None:
        n1 = (a.shape[axis] - 1) * 2
    else:
        n1 = n
    new_norm = _swap_direction(norm)
    # Forward `out` so a caller-supplied buffer actually receives the result
    # instead of being silently ignored.
    return irfft(a.conj(), n1, axis, norm=new_norm, out=out)
def ihfft(a, n: Optional[int] = None,
          axis: int = -1, norm: Optional[str] = None,
          out = None):
    """
    Inverse FFT of a real signal, returning the Hermitian half spectrum.

    Computed as rfft with the normalization direction swapped, followed by
    an in-place complex conjugation of the result.

    `n` defaults to the input length along `axis`; `out`, when given, is
    forwarded to rfft and is the array returned.
    """
    arr = asarray(a)
    length = 0
    if n is not None:
        length = n
    else:
        length = arr.shape[axis]
    result = rfft(arr, length, axis, norm=_swap_direction(norm), out=out)
    result.map(lambda z: z.conjugate(), inplace=True)
    return result
def _cook_nd_args(a, s = None, axes = None, invreal: Static[int] = False):
    """
    Normalize the (s, axes) arguments of the n-dimensional transforms.

    Returns (s2, axes1): a tuple of per-axis transform lengths and the
    matching tuple of axes.

    - s None, axes None:   lengths are a.shape and axes are (-k, ..., -1),
                           where k is the number of lengths.
    - s None, axes given:  lengths are a.shape[i] for each i in axes.
    - s given, axes None:  axes are the last len(s) axes, (-k, ..., -1).
    - A length of -1 in s is replaced by the input size along that axis.

    invreal (passed as True by irfftn): when s was not given, the last length
    becomes (a.shape[axes1[-1]] - 1) * 2, matching irfft's default output
    length for a half spectrum.

    s and axes of different lengths are rejected at compile time.
    """
    # Static flag: whether the caller supplied s (decided at compile time).
    shapeless: Static[int] = (s is None)
    if s is None:
        if axes is None:
            s1 = a.shape
        else:
            s1 = tuple(a.shape[i] for i in axes)
    else:
        s1 = s
    if axes is None:
        if not shapeless:
            pass # placeholder: no warning is emitted when s is given without axes
        # r = (k-1, ..., 0), so axes1 = (-k, ..., -1): the last k axes.
        r = tuple_range(staticlen(s1))[::-1]
        axes1 = tuple(-i - 1 for i in r)
    else:
        axes1 = axes
    if staticlen(s1) != staticlen(axes1):
        compile_error("Shape and axes have different lengths.")
    if invreal and shapeless:
        # Inverse real transform: the last axis holds a half spectrum, so the
        # default full output length along it is 2 * (size - 1).
        s1 = tuple_set(s1, staticlen(s1) - 1, (a.shape[axes1[-1]] - 1) * 2)
    # Resolve -1 entries to the input size along the corresponding axis.
    s2 = tuple(a.shape[axes1[i]] if s1[i] == -1 else s1[i] for i in staticrange(staticlen(s1)))
    return s2, axes1
def _raw_fftnd(a, s = None, axes = None, function = fft, norm: Optional[str] = None, out = None):
    """
    Apply the 1-D transform `function` over several axes, last axis first.

    `s` and `axes` are normalized by _cook_nd_args. With no axes the input
    (after asarray) is returned untransformed. Otherwise `function` is
    applied to the last listed axis, then to each earlier one down to the
    first, with `n`, `norm` and `out` forwarded on every call.
    """
    arr = asarray(a)
    shape, ax = _cook_nd_args(arr, s, axes)
    if staticlen(ax) == 0:
        return arr
    last = len(ax) - 1
    # The first call is made outside the loop, presumably so `res` takes the
    # transform's (possibly complex) element type before being reassigned.
    res = function(arr, n=shape[last], axis=ax[last], norm=norm, out=out)
    for k in range(last - 1, -1, -1):
        res = function(res, n=shape[k], axis=ax[k], norm=norm, out=out)
    return res
def fftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional FFT over `axes` (all axes when both s and axes are None)."""
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)
def ifftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """N-dimensional inverse FFT over `axes` (all axes when s and axes are None)."""
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)
def fft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Two-dimensional FFT; by default over the last two axes."""
    return _raw_fftnd(a, s=s, axes=axes, function=fft, norm=norm, out=out)
def ifft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Two-dimensional inverse FFT; by default over the last two axes."""
    return _raw_fftnd(a, s=s, axes=axes, function=ifft, norm=norm, out=out)
def rfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    N-dimensional FFT of real input.

    rfft is applied over the last axis in `axes`; fft is then applied over
    each remaining axis, in order. `out` is forwarded to every call.
    """
    arr = asarray(a)
    shape, ax = _cook_nd_args(arr, s, axes)
    res = rfft(arr, shape[-1], ax[-1], norm, out=out)
    for k in range(len(ax) - 1):
        res = fft(res, shape[k], ax[k], norm, out=out)
    return res
def rfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """Two-dimensional FFT of real input; by default over the last two axes."""
    return rfftn(a, s=s, axes=axes, norm=norm, out=out)
def irfftn(a, s=None, axes=None, norm: Optional[str] = None, out = None):
    """
    Inverse of rfftn: real output from an N-dimensional half spectrum.

    ifft is applied over every axis in `axes` except the last, in order;
    irfft over the last axis then produces the real result, written to `out`
    when given. When `s` is omitted, the last output length defaults to
    2 * (input size - 1) along that axis (see _cook_nd_args, invreal).
    """
    arr = asarray(a)
    shape, ax = _cook_nd_args(arr, s, axes, invreal=True)
    if staticlen(ax) > 1:
        spec = ifft(arr, shape[0], ax[0], norm)
        for k in range(1, len(ax) - 1):
            spec = ifft(spec, shape[k], ax[k], norm)
        return irfft(spec, shape[-1], ax[-1], norm, out=out)
    else:
        return irfft(arr, shape[-1], ax[-1], norm, out=out)
def irfft2(a, s=None, axes=(-2, -1), norm: Optional[str] = None, out = None):
    """
    Inverse of rfft2: real 2-D output from a half spectrum over `axes`.

    Delegates to irfftn. `out`, when given, is forwarded so it receives the
    real result (it must have the output shape and the float dtype).
    """
    return irfftn(a, s, axes, norm, out=out)
# Helpers
def fftshift(x, axes=None):
    """
    Shift the zero-frequency component to the center of the spectrum.

    Each selected axis is rolled forward by floor(size / 2). `axes` may be
    None (shift every axis), a single int, or an iterable of axis indices.
    """
    arr = asarray(x)
    if axes is None:
        return roll(arr, tuple(size // 2 for size in arr.shape),
                    tuple_range(arr.ndim))
    elif isinstance(axes, int):
        return roll(arr, arr.shape[axes] // 2, axes)
    else:
        return roll(arr, tuple(arr.shape[k] // 2 for k in axes), axes)
def ifftshift(x, axes=None):
    """
    Inverse of fftshift: move the centered zero frequency back to index 0.

    Each selected axis is rolled backward by floor(size / 2). `axes` may be
    None (every axis), a single int, or an iterable of axis indices.
    """
    arr = asarray(x)
    if axes is None:
        return roll(arr, tuple(-(size // 2) for size in arr.shape),
                    tuple_range(arr.ndim))
    elif isinstance(axes, int):
        return roll(arr, -(arr.shape[axes] // 2), axes)
    else:
        return roll(arr, tuple(-(arr.shape[k] // 2) for k in axes), axes)
def fftfreq(n: int, d: float = 1.0):
    """
    Sample frequencies for an `n`-point DFT with sample spacing `d`.

    Returns [0, 1, ..., (n-1)//2, -(n//2), ..., -1] scaled by 1 / (n * d).
    """
    scale = 1.0 / (n * d)
    freqs = empty(n, int)
    half = (n - 1) // 2 + 1            # number of non-negative frequencies
    freqs[:half] = arange(0, half, dtype=int)
    freqs[half:] = arange(-(n // 2), 0, dtype=int)
    return freqs * scale
def rfftfreq(n: int, d: float = 1.0):
    """
    Sample frequencies for rfft of an `n`-point signal with spacing `d`.

    Returns [0, 1, ..., n//2] scaled by 1 / (n * d).
    """
    scale = 1.0 / (n * d)
    return arange(0, n // 2 + 1, dtype=int) * scale