codon/stdlib/numpy/indexing.codon

566 lines
18 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
import util
from .ndarray import *
from .routines import *
def _multiindex(indexes, shape, index: int = 0):
    # Resolve each per-axis index into a (start, stop, step, length) tuple.
    # Integer entries become a length-1 run at the normalized position;
    # slice entries are resolved against the axis size with adjust_indices.
    # `index` is the axis number of indexes[0]; it is forwarded to
    # util.normalize_index (presumably for its error message).
    if staticlen(indexes) != staticlen(shape):
        compile_error("[internal error] bad multi-index")
    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    axis_len = shape[0]
    # Deeper axes are resolved first, exactly as before, so any IndexError
    # from a later axis is raised before one from this axis.
    tail = _multiindex(indexes[1:], shape[1:], index + 1)

    if not isinstance(head, int):
        return (head.adjust_indices(axis_len), *tail)

    pos = util.normalize_index(head, index, axis_len)
    return ((pos, pos + 1, 1, 1), *tail)
def _keep_axes(indexes, index: int = 0):
    # Positions (counted from `index`) of the entries in `indexes` that are
    # not plain integers, i.e. the axes that survive basic indexing.
    # Integer entries collapse their axis and are skipped.
    if staticlen(indexes) == 0:
        return ()
    kept_after = _keep_axes(indexes[1:], index + 1)
    if isinstance(indexes[0], int):
        return kept_after
    return (index, *kept_after)
def _base_offset(mindices, strides):
    # Offset of the first selected element: the sum over axes of the start
    # position (element 0 of each (start, stop, step, length) tuple) times
    # that axis' stride.
    total = 0
    for axis in staticrange(staticlen(mindices)):
        start = mindices[axis][0]
        total += start * strides[axis]
    return total
def _new_shape(mindex, keep):
    # Shape of the resulting view: the selection length (element 3 of each
    # (start, stop, step, length) tuple) of every kept axis, in order.
    return tuple(mindex[axis][3] for axis in keep)
def _new_strides(mindex, strides, keep):
    # Strides of the resulting view: each kept axis' original stride scaled
    # by its step (element 2 of the (start, stop, step, length) tuple).
    return tuple(strides[axis] * mindex[axis][2] for axis in keep)
def _extract_special(indexes):
    # Split out the special index objects.  Returns a pair of tuples:
    # (every newaxis entry, every Ellipsis entry), each in original order.
    # Callers in this file only use the lengths of the two tuples.
    if staticlen(indexes) == 0:
        return (), ()
    newaxes, ellipses = _extract_special(indexes[1:])
    head = indexes[0]
    if isinstance(head, type(newaxis)):
        return (head, *newaxes), ellipses
    if isinstance(head, type(Ellipsis)):
        return newaxes, (head, *ellipses)
    return newaxes, ellipses
def _expand_ellipsis(indexes, n: Static[int], k: Static[int] = 1):
    # Replace every Ellipsis entry with `n` full slices.  `k` is the number
    # of Ellipsis entries present; when it is 0 the tuple is returned as is.
    if k == 0:
        return indexes
    if staticlen(indexes) == 0:
        return ()
    expanded_tail = _expand_ellipsis(indexes[1:], n, k)
    head = indexes[0]
    if not isinstance(head, type(Ellipsis)):
        return (head,) + expanded_tail
    full = slice(None, None, None)
    return ((full,) * n) + expanded_tail
def _expand_remainder(indexes, n: Static[int]):
    # Pad `indexes` on the right with `n` full slices (no-op when n <= 0).
    if n <= 0:
        return indexes
    full = slice(None, None, None)
    return indexes + ((full,) * n)
def _expand_newaxis(indexes, shape, strides):
    # Turn each newaxis entry into a full slice over an inserted axis of
    # length 1 and stride 0.  Every other entry consumes the next
    # (shape, stride) pair.  Returns aligned (indexes, shape, strides) tuples.
    if staticlen(indexes) == 0:
        return (), (), ()
    head = indexes[0]
    if isinstance(head, type(newaxis)):
        idx_tail, shape_tail, stride_tail = _expand_newaxis(indexes[1:], shape, strides)
        return ((slice(None, None, None), *idx_tail),
                (1, *shape_tail),
                (0, *stride_tail))
    idx_tail, shape_tail, stride_tail = _expand_newaxis(indexes[1:], shape[1:], strides[1:])
    return ((head, *idx_tail),
            (shape[0], *shape_tail),
            (strides[0], *stride_tail))
def _extract_sequences(indexes):
    # Collect the sequence-like entries (ndarray, List or Tuple) of
    # `indexes`.  _getset switches to advanced indexing when this is non-empty.
    if staticlen(indexes) == 0:
        return ()
    seqs_after = _extract_sequences(indexes[1:])
    head = indexes[0]
    if isinstance(head, ndarray) or isinstance(head, List) or isinstance(head, Tuple):
        return (head,) + seqs_after
    return seqs_after
def _adv_idx_convert(idx):
    # Normalize one advanced-indexing entry.
    # - None, slices and ints are returned unchanged.
    # - ndarray/List/Tuple entries are converted with asarray and must end
    #   up with an int or bool dtype.
    # - Any other type is a compile-time error.
    if idx is None or isinstance(idx, slice) or isinstance(idx, int):
        return idx
    elif isinstance(idx, ndarray) or isinstance(idx, List) or isinstance(idx, Tuple):
        arr = asarray(idx)
        if arr.dtype is not int and arr.dtype is not bool:
            # Boolean arrays are accepted as well, so the message says so
            # (same wording as NumPy's IndexError).
            compile_error("arrays used as indices must be of integer (or boolean) type")
        return arr
    else:
        compile_error("unsupported index type: " + type(idx).__name__)
def _adv_idx_replace_bools(indexes):
    # Replace each boolean ndarray entry with the tuple of integer index
    # arrays from its nonzero().  All other entries are kept as they are.
    if staticlen(indexes) == 0:
        return ()
    tail = _adv_idx_replace_bools(indexes[1:])
    head = indexes[0]
    if not isinstance(head, ndarray):
        return (head,) + tail
    elif head.dtype is bool:
        return head.nonzero() + tail
    else:
        return (head,) + tail
def _adv_idx_length(idx, dim: int):
    # Number of result positions a non-array index contributes along an
    # axis of size `dim`:
    # - None (newaxis): 1
    # - int: 0
    # - slice: its resolved length (last element of adjust_indices)
    if idx is None:
        return 1
    elif isinstance(idx, int):
        return 0
    elif isinstance(idx, slice):
        return idx.adjust_indices(dim)[-1]
    else:
        compile_error("[internal error]: bad input type")
def _adv_idx_iter_non_arrays(indexes, shape):
    # Generate every combination of positions selected by the slice entries
    # in `indexes`, where shape[i] is the size of the axis indexes[i] slices.
    # Each yielded item is a tuple with one (k, i) pair per slice: k is the
    # position in the result along that axis and i is the position in the
    # source array.  The last slice varies fastest.
    if staticlen(indexes) == 0:
        yield ()
    else:
        idx = indexes[0]
        dim = shape[0]
        # slice.adjust_indices returns (start, stop, step, length); the
        # length is not needed because the range is iterated directly.
        start, stop, step, _ = idx.adjust_indices(dim)
        for k, i in enumerate(range(start, stop, step)):
            if staticlen(indexes) == 1:
                yield ((k, i),)
            else:
                for rest in _adv_idx_iter_non_arrays(indexes[1:], shape[1:]):
                    yield ((k, i),) + rest
def _adv_idx_gather_arrays(indexes, k: Static[int] = 0):
    # Return (ndarray entries of `indexes`, their positions).  Positions are
    # counted from the static offset `k`.
    if staticlen(indexes) == 0:
        return (), ()
    arrays, positions = _adv_idx_gather_arrays(indexes[1:], k + 1)
    head = indexes[0]
    if not isinstance(head, ndarray):
        return arrays, positions
    return (head,) + arrays, (k,) + positions
def _adv_idx_gather_non_arrays(indexes, k: Static[int] = 0):
    # Return (non-ndarray entries of `indexes`, their positions).  Positions
    # are counted from the static offset `k`.
    if staticlen(indexes) == 0:
        return (), ()
    others, positions = _adv_idx_gather_non_arrays(indexes[1:], k + 1)
    head = indexes[0]
    if isinstance(head, ndarray):
        return others, positions
    return (head,) + others, (k,) + positions
def _adv_idx_replace_int(idx):
    # Keep None and int entries as they are; any other entry becomes a full
    # slice.
    if idx is None:
        return idx
    if isinstance(idx, int):
        return idx
    return slice(None, None, None)
def _adv_idx_prune_index(indexes):
    # Drop int entries, turn None entries into full slices and keep
    # everything else, preserving order.
    if staticlen(indexes) == 0:
        return ()
    pruned_tail = _adv_idx_prune_index(indexes[1:])
    head = indexes[0]
    if isinstance(head, int):
        return pruned_tail
    if head is None:
        return (slice(None, None, None),) + pruned_tail
    return (head,) + pruned_tail
def _adv_idx_gather_none_and_int(indexes):
    # Collect the None and int entries of `indexes`, in order.
    if staticlen(indexes) == 0:
        return ()
    tail = _adv_idx_gather_none_and_int(indexes[1:])
    head = indexes[0]
    if head is None or isinstance(head, int):
        return (head,) + tail
    return tail
def _adv_idx_eliminate_new_and_used(arr, indexes):
    # Apply the newaxis (None) and integer entries of `indexes` to `arr`
    # first.  Every other entry is replaced by a full slice for that step.
    # Returns (arr indexed that way, remaining indexes): the remaining
    # indexes have ints dropped and None turned into full slices, so they
    # line up with the axes of the new array.
    if staticlen(_adv_idx_gather_none_and_int(indexes)) > 0:
        basic = tuple(_adv_idx_replace_int(entry) for entry in indexes)
        view = arr[basic]
        return view, _adv_idx_prune_index(indexes)
    # nothing to do
    return arr, indexes
def _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arr_shape, saw_array: Static[int] = False):
    # Lay out a result shape/index for the case where all index arrays are
    # adjacent, walking `indexes` left to right:
    # - the first ndarray entry contributes all of `arr_shape`;
    # - later ndarray entries contribute nothing;
    # - each non-array entry contributes the next element of
    #   `shape_from_non_arrays`.
    if staticlen(indexes) == 0:
        return ()
    head = indexes[0]
    rest = indexes[1:]
    if not isinstance(head, ndarray):
        return (shape_from_non_arrays[0],) + _adv_idx_build_for_contig_array(rest, shape_from_non_arrays[1:], arr_shape, saw_array)
    if saw_array:
        return _adv_idx_build_for_contig_array(rest, shape_from_non_arrays, arr_shape, saw_array)
    return arr_shape + _adv_idx_build_for_contig_array(rest, shape_from_non_arrays, arr_shape, True)
def _bool_idx_get_bool_index(indexes):
    # Return (indexes,) when `indexes` is a boolean mask (an ndarray or List
    # with bool dtype), otherwise ().  _getset uses the static length of the
    # result to decide whether to take the boolean-indexing path.
    if isinstance(indexes, List):
        if asarray(indexes).dtype is bool:
            return (indexes,)
        else:
            return ()
    elif isinstance(indexes, ndarray):
        if indexes.dtype is bool:
            return (indexes,)
        else:
            return ()
    else:
        return ()
def _bool_idx_num_true(indexes, sz: int):
    # Count the true entries of mask `indexes`.  `sz` is its total element
    # count and is used only for the linear scan over contiguous storage.
    hits = 0
    if not indexes._is_contig:
        for ix in util.multirange(indexes.shape):
            if indexes._ptr(ix)[0]:
                hits += 1
        return hits
    for j in range(sz):
        if indexes.data[j]:
            hits += 1
    return hits
# adapted from routines
# Compute the broadcast shape of the given shape tuples: shapes are aligned
# on their trailing axes and each axis must either match or be 1.  On a
# mismatch this raises IndexError, with a message listing every input shape
# (this variant is specific to advanced indexing).
def _broadcast_shapes(*args):
    # Return the shape tuple with the most dimensions (the earliest one wins
    # ties, because the comparison below is strict).
    def _largest(args):
        if staticlen(args) == 1:
            return args[0]

        a = args[0]
        b = _largest(args[1:])
        if staticlen(b) > staticlen(a):
            return b
        else:
            return a

    if staticlen(args) == 0:
        return ()

    t = _largest(args)
    N: Static[int] = staticlen(t)
    # The result has N dims.  It is built in place by writing through a raw
    # int pointer to the tuple's storage; tuples themselves are immutable.
    ans = (0,) * N
    p = Ptr[int](__ptr__(ans).as_byte())

    for i in staticrange(N):
        p[i] = t[i]

    for a in args:
        for i in staticrange(staticlen(a)):
            # i-th axis counted from the right, in `a` and in the result
            x = a[len(a) - 1 - i]
            q = p + (len(t) - 1 - i)
            y = q[0]

            if y == 1:
                # a size-1 result axis takes whatever `a` has there
                q[0] = x
            elif x != 1 and x != y:
                msg = _strbuf(capacity=100)
                msg.append("shape mismatch: indexing arrays could not be broadcast together with shapes")
                for sh in args:
                    msg.append(" ")
                    msg.append(str(sh))
                raise IndexError(msg.__str__())

    return ans
def _getset_advanced(arr, indexes, item, dtype: type):
    # Advanced (array-based) indexing.  When item is None this returns a new
    # array holding arr[indexes]; otherwise it writes item into
    # arr[indexes] in place and returns nothing.
    #
    # Outline:
    #   1. Convert sequence entries to int/bool arrays; bool arrays become
    #      the int arrays from nonzero().
    #   2. Expand Ellipsis and missing trailing axes to full slices.
    #   3. Apply newaxis/int entries first by basic indexing.
    #   4. Visit every combination of (broadcast index-array position,
    #      slice position) and copy one element at a time.
    indexes = tuple(_adv_idx_convert(idx) for idx in indexes)
    indexes = _adv_idx_replace_bools(indexes)

    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)
    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")
    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    # The Ellipsis stands for however many axes the explicit entries leave
    # uncovered.  Any axes still uncovered afterwards are padded on the right.
    indexes = _expand_ellipsis(indexes,
                               staticlen(arr.shape)
                               - (staticlen(indexes)
                                  - staticlen(newaxis_tuple)
                                  - staticlen(ellipsis_tuple)),
                               staticlen(ellipsis_tuple))
    indexes = _expand_remainder(indexes,
                                staticlen(arr.shape)
                                - staticlen(indexes)
                                + staticlen(newaxis_tuple))

    # eliminate newaxis and used axes (i.e. integer indices)
    arr, indexes = _adv_idx_eliminate_new_and_used(arr, indexes)
    shape = arr.shape

    # which indices are array-like?
    index_arrays, arrays_where = _adv_idx_gather_arrays(indexes)
    # all index arrays are broadcast against each other
    arrays_bshape = _broadcast_shapes(*tuple(a.shape for a in index_arrays))
    if staticlen(index_arrays) == 0:
        compile_error("[internal error] advanced indexing is not applicable to index")

    # which indices are not array-like?
    index_non_arrays, non_arrays_where = _adv_idx_gather_non_arrays(indexes)

    # If the index arrays are not at consecutive positions, the broadcast
    # dimensions go first in the result.  Otherwise they take the place of
    # the arrays (see _adv_idx_build_for_contig_array).
    # NOTE(review): int entries were already removed above, so e.g.
    # a[:, [0], :, 0] is treated as "adjacent" here.  NumPy counts the int as
    # an advanced index and would put the broadcast dims first; confirm that
    # this difference is intended.
    arrays_at_front = False
    for i in staticrange(1, staticlen(arrays_where)):
        if arrays_where[i] != arrays_where[i - 1] + 1:
            arrays_at_front = True

    # result length contributed by each slice (non-array) entry
    shape_from_non_arrays = tuple(_adv_idx_length(index_non_arrays[i], shape[non_arrays_where[i]])
                                  for i in staticrange(staticlen(index_non_arrays)))
    if arrays_at_front:
        ans_shape = arrays_bshape + shape_from_non_arrays
    else:
        ans_shape = _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arrays_bshape)

    if item is None:
        ans = empty(ans_shape, dtype)
        item_arr = None
    else:
        ans = None
        # the assigned value must broadcast to the shape of the selection
        item_arr = broadcast_to(asarray(item), ans_shape)

    # sizes of the axes that are indexed by slices
    subshape = tuple(shape[i] for i in non_arrays_where)

    # idx walks the broadcast shape of the index arrays; the inner generator
    # yields, per sliced axis, a (result position, source position) pair.
    for idx in util.multirange(arrays_bshape):
        idx_from_arrays = tuple(a._ptr(idx, broadcast=True)[0] for a in index_arrays)
        for idx_from_non_arrays in _adv_idx_iter_non_arrays(index_non_arrays, subshape):
            dst_idx_from_non_arrays = tuple(x[0] for x in idx_from_non_arrays)
            src_idx_from_non_arrays = tuple(x[1] for x in idx_from_non_arrays)
            if arrays_at_front:
                dst_idx = idx + dst_idx_from_non_arrays
            else:
                dst_idx = _adv_idx_build_for_contig_array(indexes, dst_idx_from_non_arrays, idx)
            # Build the full source index in place: array-indexed axes get
            # the looked-up values, sliced axes get the slice positions.
            # Negative/out-of-range values are handled by the arr[...] call.
            src_idx = (0,) * staticlen(arr.shape)
            psrc_idx = Ptr[int](__ptr__(src_idx).as_byte())
            for i in staticrange(staticlen(index_arrays)):
                psrc_idx[arrays_where[i]] = idx_from_arrays[i]
            for i in staticrange(staticlen(index_non_arrays)):
                psrc_idx[non_arrays_where[i]] = src_idx_from_non_arrays[i]
            if item is None:
                ans[dst_idx] = arr[src_idx]
            else:
                arr[src_idx] = util.cast(item_arr[dst_idx], arr.dtype)

    if item is None:
        return ans
def _getset_bool(arr, indexes, item, dtype: type = NoneType):
    # Boolean-mask indexing: arr[mask] when item is None, otherwise
    # arr[mask] = item.  `indexes` is converted with asarray and used as the
    # mask.  Cases:
    #   * mask rank > arr rank: compile-time "too many indices for array".
    #   * 0-d arr: get returns atleast_1d(arr) if the mask is True, else an
    #     empty array; set writes item only when the mask is True.
    #   * mask rank == arr rank: shapes must match exactly (IndexError
    #     otherwise).  Get returns a new 1-d array of the selected elements.
    #     Set writes item (0-d, size 1, or exactly one value per True entry;
    #     ValueError otherwise) into the selected positions.
    #   * mask rank < arr rank: delegated to _getset_advanced with
    #     mask.nonzero().
    # Returns the result when getting and nothing when setting.
    indexes = asarray(indexes)

    if staticlen(indexes.shape) > staticlen(arr.shape):
        compile_error("too many indices for array")
    elif staticlen(arr.shape) == 0:
        if item is None:
            if indexes.item():
                return atleast_1d(arr)
            else:
                return empty(0, arr.dtype)
        else:
            if indexes.item():
                arr_item = asarray(item)
                arr.data[0] = util.cast(arr_item.item(), arr.dtype)
    elif staticlen(indexes.shape) == staticlen(arr.shape):
        sz = 1
        for i in range(len(indexes.shape)):
            arr_dim = arr.shape[i]
            idx_dim = indexes.shape[i]
            if arr_dim != idx_dim:
                raise IndexError(f"boolean index did not match indexed array "
                                 f"along dimension {i}; dimension is {arr_dim} but "
                                 f"corresponding boolean dimension is {idx_dim}")
            sz *= arr_dim

        num_true = 0
        if item is None:
            num_true = _bool_idx_num_true(indexes, sz)
            arr_item = None
            ans = empty(num_true, arr.dtype)
        else:
            arr_item = asarray(item)
            if staticlen(arr_item.shape) == 1:
                num_true = _bool_idx_num_true(indexes, sz)
                if arr_item.size != 1 and arr_item.size != num_true:
                    raise ValueError(f"NumPy boolean array indexing assignment "
                                     f"cannot assign {arr_item.size} input values "
                                     f"to the {num_true} output values where the "
                                     f"mask is true")
            elif staticlen(arr_item.shape) > 1:
                compile_error("NumPy boolean array indexing assignment requires a 0 or 1-dimensional input")
            ans = None

        # Fast path: when both arr and the mask are C-contiguous (assumed to
        # be the first element of _contig), scan both linearly.
        cc1, _ = arr._contig
        cc2, _ = indexes._contig
        k = 0  # number of True entries seen so far (next output slot)

        if cc1 and cc2:
            for i in range(sz):
                if indexes.data[i]:
                    if item is None:
                        ans.data[k] = arr.data[i]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr.data[i] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            # `item` may be a strided or reversed view, so
                            # read it through _ptr, which applies its
                            # strides.  A size-1 item is broadcast.
                            j = 0 if arr_item.size == 1 else k
                            arr.data[i] = util.cast(arr_item._ptr((j,))[0], arr.dtype)
                    k += 1
        else:
            for idx in util.multirange(arr.shape):
                if indexes._ptr(idx)[0]:
                    if item is None:
                        ans.data[k] = arr._ptr(idx)[0]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr._ptr(idx)[0] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            j = 0 if arr_item.size == 1 else k
                            arr._ptr(idx)[0] = util.cast(arr_item._ptr((j,))[0], arr.dtype)
                    k += 1

        if item is None:
            return ans
    else:
        return _getset_advanced(arr, indexes.nonzero(), item, dtype)
def _assign_one(arr, elem):
    # Fill every element of `arr` with `elem` converted to arr.dtype.
    # 0-d arrays, contiguous arrays and strided arrays each take their own
    # path.
    value = util.cast(elem, arr.dtype)
    if staticlen(arr.shape) == 0:
        arr.data[0] = value
        return
    if arr._is_contig:
        buf = arr.data
        for j in range(util.count(arr.shape)):
            buf[j] = value
    else:
        for ix in util.multirange(arr.shape):
            arr._ptr(ix)[0] = value
def _getset(arr, indexes, item, dtype: type = NoneType):
    # Entry point for ndarray.__getitem__ / __setitem__.  When item is None,
    # returns arr[indexes]: a scalar if every axis is indexed by an int,
    # otherwise a view sharing arr's memory.  Otherwise writes item into
    # arr[indexes].
    # Boolean masks go to _getset_bool and sequence entries to
    # _getset_advanced.  Everything else (ints, slices, newaxis, Ellipsis)
    # is handled here as basic indexing by computing a base pointer plus the
    # view's shape and strides.
    if staticlen(_bool_idx_get_bool_index(indexes)) > 0:
        return _getset_bool(arr, indexes, item, dtype)

    if not isinstance(indexes, Tuple):
        return _getset(arr, (indexes,), item, dtype)

    if staticlen(_extract_sequences(indexes)) > 0:
        return _getset_advanced(arr, indexes, item, dtype)

    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)
    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")
    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    # The Ellipsis covers the axes the explicit entries leave uncovered; any
    # axes still uncovered afterwards are padded on the right.
    indexes = _expand_ellipsis(indexes,
                               staticlen(arr.shape)
                               - (staticlen(indexes)
                                  - staticlen(newaxis_tuple)
                                  - staticlen(ellipsis_tuple)),
                               staticlen(ellipsis_tuple))
    indexes = _expand_remainder(indexes,
                                staticlen(arr.shape)
                                - staticlen(indexes)
                                + staticlen(newaxis_tuple))
    indexes, shape, strides = _expand_newaxis(indexes, arr.shape, arr.strides)
    mindex = _multiindex(indexes, shape)
    keep = _keep_axes(indexes)
    # strides are added to a byte pointer, so they are byte strides
    p = Ptr[dtype](arr._data.as_byte() + _base_offset(mindex, strides))

    if staticlen(keep) == 0:
        # Every axis was indexed by an int: a single element.
        if item is None:
            return p[0]
        elif isinstance(item, ndarray) or isinstance(item, List) or isinstance(item, Tuple):
            # Array-like item: handled by the 0-d view path below, which
            # converts it with asarray.
            pass
        else:
            # Scalar store.  Return right away; previously the code fell
            # through and wrote the same element a second time via the view
            # path.
            p[0] = util.cast(item, dtype)
            return

    ans_shape = _new_shape(mindex, keep)
    ans_strides = _new_strides(mindex, strides, keep)
    sub = ndarray(ans_shape, ans_strides, p)

    if item is None:
        return sub
    elif isinstance(item, ndarray) or isinstance(item, List) or isinstance(item, Tuple):
        arr_item = asarray(item)
        if staticlen(arr_item.shape) == 0:
            _assign_one(sub, arr_item.item())
        else:
            arr_bcast = broadcast_to(arr_item, sub.shape)
            if sub._contig_match(arr_bcast):
                # same memory layout: copy linearly
                q = arr_bcast._data
                for i in range(arr_bcast.size):
                    p[i] = util.cast(q[i], dtype)
            else:
                for idx in util.multirange(sub.shape):
                    a = sub._ptr(idx)
                    b = arr_bcast._ptr(idx)
                    a[0] = util.cast(b[0], dtype)
    else:
        _assign_one(sub, item)
@extend
class ndarray:
    # Read self[indexes]; all indexing logic lives in _getset.
    def __getitem__(self, indexes):
        return _getset(self, indexes, item=None, dtype=dtype)

    # Write item into self[indexes]; all indexing logic lives in _getset.
    def __setitem__(self, indexes, item):
        _getset(self, indexes, item=item, dtype=dtype)