mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from ..ndarray import ndarray
|
|
from ..routines import asarray
|
|
from ..util import strides as make_strides, tuple_range, normalize_axis_tuple
|
|
|
|
def as_strided(x, shape = None, strides = None, writeable: bool = True):
    """
    Build an ndarray over the data of ``x`` using an explicit shape and strides.

    Parameters
    ----------
    x
        Input; it is passed through ``asarray`` first.
    shape
        Tuple giving the shape of the result. When ``None``, ``x.shape``
        is used.
    strides
        Tuple giving the strides of the result. When ``None``, strides are
        derived from ``shape`` and ``x.dtype`` via ``make_strides``.
    writeable
        Forwarded when the function re-dispatches to itself; it is not
        otherwise consulted when the result is constructed here.

    Returns
    -------
    ndarray
        ``ndarray(shape, strides, x.data)``, i.e. an array constructed over
        the same ``data`` as ``x``.

    Notes
    -----
    A non-tuple ``shape`` or ``strides``, or a length mismatch between the
    two, is reported through ``compile_error``.
    """
    x = asarray(x)

    # Each missing argument is filled in by calling the function again with a
    # concrete value, so the code below only ever sees real shape/strides.
    if shape is None:
        return as_strided(x, shape=x.shape, strides=strides, writeable=writeable)

    if strides is None:
        # NOTE(review): the ``False`` flag is presumably the Fortran-order
        # switch (i.e. C-order strides are requested) -- confirm against
        # ``util.strides``.
        return as_strided(
            x,
            shape=shape,
            strides=make_strides(shape, False, x.dtype),
            writeable=writeable,
        )

    # Validation order is kept: shape type, strides type, then matching length.
    if not isinstance(shape, Tuple):
        compile_error("shape must be a tuple of integers")

    if not isinstance(strides, Tuple):
        compile_error("strides must be a tuple of integers")

    if staticlen(shape) != staticlen(strides):
        compile_error("shape and strides have different lengths")

    return ndarray(shape, strides, x.data)
|
|
|
|
def sliding_window_view(x, window_shape, axis = None, writeable: bool = False):
    """
    Create a sliding-window view of ``x``.

    The result keeps the (trimmed) dimensions of ``x`` and appends one extra
    dimension per window entry: its shape is the trimmed shape of ``x``
    followed by ``window_shape``, and its strides are ``x.strides`` followed
    by the strides of the selected axes.

    Parameters
    ----------
    x
        Input; it is passed through ``asarray`` first.
    window_shape
        An ``int`` (treated as a one-element tuple) or a tuple with one
        window size per selected axis.
    axis
        ``None`` to use every axis of ``x`` in order, otherwise the axis or
        axes to slide over. It is normalised with ``allow_duplicates=True``.
    writeable
        Passed on to ``as_strided``.

    Raises
    ------
    ValueError
        If any window size is negative, or if a window size is larger than
        the (already trimmed) length of its axis.

    Notes
    -----
    A mismatch between the number of window sizes and the number of selected
    axes is reported through ``compile_error``.
    """
    # Promote a bare integer window to a tuple and start over.
    if isinstance(window_shape, int):
        return sliding_window_view(x, window_shape=(window_shape,), axis=axis, writeable=writeable)

    x = asarray(x)
    ndim: Static[int] = staticlen(x.shape)

    for extent in window_shape:
        if extent < 0:
            raise ValueError("`window_shape` cannot contain negative values")

    if axis is None:
        axes = tuple_range(ndim)
    else:
        axes = normalize_axis_tuple(axis, ndim, allow_duplicates=True)

    if staticlen(window_shape) != staticlen(axes):
        compile_error("window_shape length does not match dimension of x")

    # Every window dimension steps through memory with the stride of the axis
    # it slides along.
    window_strides = tuple(x.strides[a] for a in axes)
    out_strides = x.strides + window_strides

    # Take a local copy of the shape and address it through an int pointer so
    # individual entries can be updated in place.
    # NOTE(review): assumes the shape tuple is stored as consecutive ints.
    trimmed_shape = x.shape
    trimmed_ptr = Ptr[int](__ptr__(trimmed_shape).as_byte())

    for k in range(len(axes)):
        a = axes[k]
        win = window_shape[k]
        if trimmed_ptr[a] < win:
            raise ValueError("window shape cannot be larger than input array shape")
        # An axis of length n holds n - win + 1 window positions.
        trimmed_ptr[a] -= win - 1

    out_shape = tuple(trimmed_shape) + window_shape
    return as_strided(x, shape=out_shape, strides=out_strides, writeable=writeable)
|