1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/lib/stride_tricks.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

60 lines
2.1 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from ..ndarray import ndarray
from ..routines import asarray
from ..util import strides as make_strides, tuple_range, normalize_axis_tuple
def as_strided(x, shape = None, strides = None, writeable: bool = True):
    # Build an ndarray over the data of `x` using an explicit shape and strides.
    #
    # Parameters:
    #   x         - input; converted with asarray() before use.
    #   shape     - tuple giving the result shape; when None, x.shape is used.
    #   strides   - tuple giving the result strides; when None, strides are
    #               derived from `shape` and x.dtype via make_strides().
    #   writeable - forwarded to the recursive calls below.
    #               NOTE(review): not otherwise consulted in this body, so the
    #               returned array is built the same way for True and False --
    #               confirm whether read-only views are meant to be supported.
    #
    # Returns:
    #   ndarray(shape, strides, x.data), i.e. an array constructed from x's
    #   `data` member with the requested shape and strides.
    #
    # Compile-time errors (compile_error) are issued when `shape` or `strides`
    # is not a Tuple, or when their static lengths differ.
    x = asarray(x)

    # Missing shape: call again with x.shape filled in, so the remainder of the
    # function only ever sees a concrete shape value.
    if shape is None:
        return as_strided(x, shape=x.shape, strides=strides, writeable=writeable)

    # Missing strides: compute default strides for `shape` and call again.
    # The meaning of the `False` argument is defined by ..util.strides
    # (presumably selects C/row-major order -- TODO confirm).
    if strides is None:
        st = make_strides(shape, False, x.dtype)
        return as_strided(x, shape=shape, strides=st, writeable=writeable)

    # Type and arity validation, reported through compile_error rather than
    # raised exceptions.
    if not isinstance(shape, Tuple):
        compile_error("shape must be a tuple of integers")

    if not isinstance(strides, Tuple):
        compile_error("strides must be a tuple of integers")

    if staticlen(shape) != staticlen(strides):
        compile_error("shape and strides have different lengths")

    return ndarray(shape, strides, x.data)
def sliding_window_view(x, window_shape, axis = None, writeable: bool = False):
    # Create a sliding-window view of `x`: the result shape is the trimmed
    # shape of `x` followed by `window_shape`, produced through as_strided().
    #
    # Parameters:
    #   x            - input; converted with asarray().
    #   window_shape - int or tuple of ints giving the window size along each
    #                  windowed axis; a bare int is wrapped into a 1-tuple.
    #   axis         - axes the windows slide over; when None, the tuple
    #                  returned by tuple_range(ndim) is used (presumably every
    #                  axis of x in order -- TODO confirm against ..util).
    #                  Otherwise normalised by normalize_axis_tuple() with
    #                  duplicates allowed.
    #   writeable    - forwarded unchanged to as_strided().
    #
    # Raises:
    #   ValueError - if any window size is negative, or if a window size
    #                exceeds the (already trimmed) length of its axis.
    # A compile-time error (compile_error) is issued when the static length of
    # `window_shape` differs from the static length of the axis tuple.

    # Normalise an int window into a 1-tuple by calling ourselves again.
    if isinstance(window_shape, int):
        return sliding_window_view(x, window_shape=(window_shape,), axis=axis, writeable=writeable)

    x = asarray(x)
    ndim: Static[int] = staticlen(x.shape)

    for w in window_shape:
        if w < 0:
            raise ValueError("`window_shape` cannot contain negative values")

    if axis is None:
        ax = tuple_range(ndim)
    else:
        ax = normalize_axis_tuple(axis, ndim, allow_duplicates=True)

    # One window size is required per selected axis.
    if staticlen(window_shape) != staticlen(ax):
        compile_error("window_shape length does not match dimension of x")

    # Result strides: the original strides, followed by the stride of each
    # windowed axis (stepping inside a window moves along that same axis).
    out_strides = x.strides + tuple(x.strides[a] for a in ax)

    # Local copy of the shape, plus an int pointer to that local so individual
    # entries can be updated in place by runtime index (presumably because the
    # tuple cannot be item-assigned directly -- TODO confirm).
    x_shape_trimmed = x.shape
    px_shape_trimmed = Ptr[int](__ptr__(x_shape_trimmed).as_byte())

    for i in range(len(ax)):
        a = ax[i]
        dim = window_shape[i]
        # The comparison uses the current trimmed length, so an axis listed
        # more than once is checked against what earlier windows left over.
        if px_shape_trimmed[a] < dim:
            raise ValueError(
                "window shape cannot be larger than input array shape")
        # A window of size `dim` fits at (length - dim + 1) positions.
        px_shape_trimmed[a] -= dim - 1

    out_shape = tuple(x_shape_trimmed) + window_shape
    return as_strided(x, strides=out_strides, shape=out_shape, writeable=writeable)