# Copyright (C) 2022-2025 Exaloop Inc.

import util

from .ndarray import *
from .routines import *

def _multiindex(indexes, shape, index: int = 0):
    """
    Resolve every basic index against the matching dimension.

    Returns one 4-tuple per index. Integer indexes are normalized with
    `util.normalize_index` and become `(idx, idx + 1, 1, 1)`; any other
    index contributes `idx.adjust_indices(n)`. Downstream helpers read
    element 0 as the start, element 2 as the step and element 3 as the
    resulting length. `index` is the axis number of `indexes[0]` and is
    passed to `util.normalize_index`.
    """
    if staticlen(indexes) != staticlen(shape):
        compile_error("[internal error] bad multi-index")

    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    n = shape[0]
    rest = _multiindex(indexes[1:], shape[1:], index + 1)

    if isinstance(idx, int):
        idx = util.normalize_index(idx, index, n)
        return ((idx, idx + 1, 1, 1), *rest)
    else:
        return (idx.adjust_indices(n), *rest)

def _keep_axes(indexes, index: int = 0):
    """
    Return the positions of all non-integer indexes.

    Integer indexes consume their axis, so only the remaining positions
    survive in the result of a basic indexing operation.
    """
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _keep_axes(indexes[1:], index + 1)

    if isinstance(idx, int):
        return rest
    else:
        return (index, *rest)

def _base_offset(mindices, strides):
    """
    Sum `start * stride` over all axes.

    The caller adds the result to a byte pointer, so `strides` are
    expected to be expressed in bytes.
    """
    offset = 0
    for i in staticrange(staticlen(mindices)):
        offset += mindices[i][0] * strides[i]
    return offset

def _new_shape(mindex, keep):
    """Return the length (element 3) of each multi-index entry listed in `keep`."""
    return tuple(mindex[i][3] for i in keep)

def _new_strides(mindex, strides, keep):
    """Return `step * stride` for each axis listed in `keep`."""
    return tuple(mindex[i][2] * strides[i] for i in keep)

def _extract_special(indexes):
    """
    Split out the special entries of an index tuple.

    Returns a pair `(newaxis_entries, ellipsis_entries)`; all other
    entries are dropped. Callers only use the static lengths of the two
    tuples.
    """
    if staticlen(indexes) == 0:
        return (), ()

    idx = indexes[0]
    rest_newaxis, rest_ellipsis = _extract_special(indexes[1:])

    if isinstance(idx, type(newaxis)):
        return (idx, *rest_newaxis), rest_ellipsis
    elif isinstance(idx, type(Ellipsis)):
        return rest_newaxis, (idx, *rest_ellipsis)
    else:
        return rest_newaxis, rest_ellipsis

def _expand_ellipsis(indexes, n: Static[int], k: Static[int] = 1):
    """
    Replace each `Ellipsis` entry by `n` full slices.

    `k` is the number of ellipsis entries present; when it is zero the
    index tuple is returned untouched.
    """
    if k == 0:
        return indexes

    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _expand_ellipsis(indexes[1:], n, k)

    if isinstance(idx, type(Ellipsis)):
        s = slice(None, None, None)
        return ((s,) * n) + rest
    else:
        return (idx,) + rest

def _expand_remainder(indexes, n: Static[int]):
    """Append `n` full slices to `indexes` (no-op when `n <= 0`)."""
    if n > 0:
        s = slice(None, None, None)
        return indexes + ((s,) * n)
    else:
        return indexes

def _expand_newaxis(indexes, shape, strides):
    """
    Materialize new-axis entries as real axes.

    Returns `(indexes, shape, strides)` where every new-axis entry has
    been replaced by a full slice and a matching axis of length 1 and
    stride 0 has been inserted into the shape and strides.
    """
    if staticlen(indexes) == 0:
        return (), (), ()

    idx = indexes[0]
    if isinstance(idx, type(newaxis)):
        # A new axis does not consume a dimension of the array.
        rest_indexes, rest_shape, rest_strides = _expand_newaxis(indexes[1:], shape, strides)
        return (slice(None, None, None), *rest_indexes), (1, *rest_shape), (0, *rest_strides)
    else:
        rest_indexes, rest_shape, rest_strides = _expand_newaxis(indexes[1:], shape[1:], strides[1:])
        return (idx, *rest_indexes), (shape[0], *rest_shape), (strides[0], *rest_strides)

def _extract_sequences(indexes):
    """Return the entries of `indexes` that are an `ndarray`, `List` or `Tuple`."""
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _extract_sequences(indexes[1:])

    if (isinstance(idx, ndarray) or
        isinstance(idx, List) or
        isinstance(idx, Tuple)):
        return (idx,) + rest
    else:
        return rest

def _adv_idx_convert(idx):
    """
    Normalize one entry of an advanced index.

    `None`, slices and integers pass through unchanged. Arrays, lists
    and tuples are converted with `asarray` and must have dtype `int`
    or `bool`. Anything else is a compile-time error.
    """
    if idx is None or isinstance(idx, slice) or isinstance(idx, int):
        return idx
    elif isinstance(idx, ndarray) or isinstance(idx, List) or isinstance(idx, Tuple):
        arr = asarray(idx)
        if arr.dtype is not int and arr.dtype is not bool:
            compile_error("advanced indexing requires integer arrays")
        return arr
    else:
        compile_error("unsupported index type: " + type(idx).__name__)

def _adv_idx_replace_bools(indexes):
    """
    Replace every boolean array index by the arrays from `idx.nonzero()`.

    The tuple returned by `nonzero()` is spliced in place, so a boolean
    array may expand into several index entries.
    """
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _adv_idx_replace_bools(indexes[1:])

    if isinstance(idx, ndarray):
        if idx.dtype is bool:
            return idx.nonzero() + rest
        else:
            return (idx,) + rest
    else:
        return (idx,) + rest

def _adv_idx_length(idx, dim: int):
    """
    Return the extent a non-array index contributes to the result shape.

    Slices give the last element of `adjust_indices(dim)` (the slice
    length), integers give 0 and `None` gives 1.
    """
    if isinstance(idx, slice):
        return idx.adjust_indices(dim)[-1]
    elif isinstance(idx, int):
        return 0
    elif idx is None:
        return 1
    else:
        compile_error("[internal error]: bad input type")

def _adv_idx_iter_non_arrays(indexes, shape):
    """
    Iterate over every combination of positions selected by `indexes`.

    Each yielded value is a tuple with one `(k, i)` pair per index,
    where `k` counts the iterations of that index (the destination
    position) and `i` is the position in the indexed array (the source
    position). Every entry must provide `adjust_indices`. With no
    indexes a single empty tuple is yielded.
    """
    if staticlen(indexes) == 0:
        yield ()
    else:
        idx = indexes[0]
        dim = shape[0]
        # adjust_indices() yields four values (see _multiindex, which reads
        # element 3 as the length); the trailing length is not needed here.
        start, stop, step, _ = idx.adjust_indices(dim)
        k = 0
        for i in range(start, stop, step):
            if staticlen(indexes) == 1:
                yield ((k, i),)
            else:
                for rest in _adv_idx_iter_non_arrays(indexes[1:], shape[1:]):
                    yield ((k, i),) + rest
            k += 1

def _adv_idx_gather_arrays(indexes, k: Static[int] = 0):
    """
    Collect the array entries of `indexes`.

    Returns `(arrays, positions)` where `positions[j]` is the position
    of `arrays[j]` within `indexes` (offset by the initial `k`).
    """
    if staticlen(indexes) == 0:
        return (), ()

    idx = indexes[0]
    rest, rest_where = _adv_idx_gather_arrays(indexes[1:], k + 1)

    if isinstance(idx, ndarray):
        return (idx,) + rest, (k,) + rest_where
    else:
        return rest, rest_where

def _adv_idx_gather_non_arrays(indexes, k: Static[int] = 0):
    """
    Collect the non-array entries of `indexes`.

    Returns `(entries, positions)`, the complement of
    `_adv_idx_gather_arrays`.
    """
    if staticlen(indexes) == 0:
        return (), ()

    idx = indexes[0]
    rest, rest_where = _adv_idx_gather_non_arrays(indexes[1:], k + 1)

    if not isinstance(idx, ndarray):
        return (idx,) + rest, (k,) + rest_where
    else:
        return rest, rest_where

def _adv_idx_replace_int(idx):
    """Keep `None` and integer entries; turn everything else into a full slice."""
    if idx is None or isinstance(idx, int):
        return idx
    else:
        return slice(None, None, None)

def _adv_idx_prune_index(indexes):
    """
    Rewrite `indexes` to match an array whose `None`/integer entries
    have already been applied: `None` becomes a full slice, integers
    are dropped and every other entry is kept.
    """
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _adv_idx_prune_index(indexes[1:])

    if idx is None:
        return (slice(None, None, None),) + rest
    elif isinstance(idx, int):
        return rest
    else:
        return (idx,) + rest

def _adv_idx_gather_none_and_int(indexes):
    """Return the entries of `indexes` that are `None` or an integer."""
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    rest = _adv_idx_gather_none_and_int(indexes[1:])

    if idx is None or isinstance(idx, int):
        return (idx,) + rest
    else:
        return rest

def _adv_idx_eliminate_new_and_used(arr, indexes):
    """
    Apply the `None` and integer entries of `indexes` up front.

    Returns `(arr, indexes)`: `arr` indexed with only those entries
    (everything else replaced by a full slice) and `indexes` pruned
    accordingly. Both are returned unchanged when there is no such
    entry.
    """
    if staticlen(_adv_idx_gather_none_and_int(indexes)) == 0:
        # nothing to do
        return arr, indexes
    else:
        elim_idx = tuple(_adv_idx_replace_int(idx) for idx in indexes)
        arr = arr[elim_idx]
        indexes = _adv_idx_prune_index(indexes)
        return arr, indexes

def _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arr_shape, saw_array: Static[int] = False):
    """
    Interleave array and non-array contributions in index order.

    Walks `indexes` left to right: each non-array entry consumes the
    next element of `shape_from_non_arrays`, the first array entry is
    replaced by all of `arr_shape`, and later array entries contribute
    nothing. Used both for result shapes and for destination indexes.
    """
    if staticlen(indexes) == 0:
        return ()

    idx = indexes[0]
    if isinstance(idx, ndarray):
        if saw_array:
            return _adv_idx_build_for_contig_array(indexes[1:], shape_from_non_arrays, arr_shape, saw_array)
        else:
            return arr_shape + _adv_idx_build_for_contig_array(indexes[1:], shape_from_non_arrays, arr_shape, True)
    else:
        return (shape_from_non_arrays[0],) + _adv_idx_build_for_contig_array(indexes[1:], shape_from_non_arrays[1:], arr_shape, saw_array)

def _bool_idx_get_bool_index(indexes):
    """
    Return `(indexes,)` when `indexes` is a boolean array or a list that
    converts to one, and `()` otherwise.

    Only the static length of the result is inspected by the caller.
    """
    if isinstance(indexes, ndarray):
        if indexes.dtype is bool:
            return (indexes,)
        else:
            return ()
    elif isinstance(indexes, List):
        if asarray(indexes).dtype is bool:
            return (indexes,)
        else:
            return ()
    else:
        return ()

def _bool_idx_num_true(indexes, sz: int):
    """
    Count the true elements of the boolean array `indexes`.

    `sz` is the total number of elements; it is only used on the
    contiguous path, which scans the flat data buffer directly.
    """
    num_true = 0
    if indexes._is_contig:
        for i in range(sz):
            if indexes.data[i]:
                num_true += 1
    else:
        for idx in util.multirange(indexes.shape):
            if indexes._ptr(idx)[0]:
                num_true += 1
    return num_true

# adapted from routines
def _broadcast_shapes(*args):
    """
    Broadcast the given shape tuples against each other.

    Shapes are aligned on their trailing dimensions; a dimension of 1
    stretches to match the other operand. Returns the broadcast shape
    (or `()` when no shapes are given).

    Raises `IndexError` if two dimensions differ and neither is 1.
    """
    def _largest(args):
        # Pick the shape with the most dimensions (first one wins ties).
        if staticlen(args) == 1:
            return args[0]

        a = args[0]
        b = _largest(args[1:])
        if staticlen(b) > staticlen(a):
            return b
        else:
            return a

    if staticlen(args) == 0:
        return ()

    t = _largest(args)
    N: Static[int] = staticlen(t)
    ans = (0,) * N
    # Tuples are immutable, so the result is filled in through a raw pointer.
    p = Ptr[int](__ptr__(ans).as_byte())

    for i in staticrange(N):
        p[i] = t[i]

    for a in args:
        for i in staticrange(staticlen(a)):
            # Walk both shapes from their last dimension backwards.
            x = a[len(a) - 1 - i]
            q = p + (len(t) - 1 - i)
            y = q[0]

            if y == 1:
                q[0] = x
            elif x != 1 and x != y:
                msg = _strbuf(capacity=100)
                msg.append("shape mismatch: indexing arrays could not be broadcast together with shapes")
                for sh in args:
                    msg.append(" ")
                    msg.append(str(sh))
                raise IndexError(msg.__str__())

    return ans

def _getset_advanced(arr, indexes, item, dtype: type):
    """
    Get or set elements of `arr` using advanced (array-based) indexing.

    `indexes` is a tuple that contains at least one array-like entry.
    When `item` is `None` a new array of type `dtype` holding the
    selected elements is returned; otherwise `item` is broadcast to the
    result shape and written into `arr` (nothing is returned).

    Index arrays are broadcast against each other. If they occupy
    adjacent positions, the broadcast dimensions replace them in place;
    otherwise the broadcast dimensions are placed first, followed by
    the dimensions produced by the non-array indexes.
    """
    indexes = tuple(_adv_idx_convert(idx) for idx in indexes)
    indexes = _adv_idx_replace_bools(indexes)
    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)

    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")

    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    indexes = _expand_ellipsis(
        indexes,
        staticlen(arr.shape) - (staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple)),
        staticlen(ellipsis_tuple))
    indexes = _expand_remainder(
        indexes,
        staticlen(arr.shape) - staticlen(indexes) + staticlen(newaxis_tuple))

    # eliminate newaxis and used axes (i.e. integer indices)
    arr, indexes = _adv_idx_eliminate_new_and_used(arr, indexes)
    shape = arr.shape

    # which indices are array-like?
    index_arrays, arrays_where = _adv_idx_gather_arrays(indexes)
    arrays_bshape = _broadcast_shapes(*tuple(a.shape for a in index_arrays))

    if staticlen(index_arrays) == 0:
        compile_error("[internal error] advanced indexing is not applicable to index")

    # which indices are not array-like?
    index_non_arrays, non_arrays_where = _adv_idx_gather_non_arrays(indexes)

    # Index arrays separated by a non-array index push the broadcast
    # dimensions to the front of the result.
    arrays_at_front = False
    for i in staticrange(1, staticlen(arrays_where)):
        if arrays_where[i] != arrays_where[i - 1] + 1:
            arrays_at_front = True

    shape_from_non_arrays = tuple(
        _adv_idx_length(index_non_arrays[i], shape[non_arrays_where[i]])
        for i in staticrange(staticlen(index_non_arrays)))

    if arrays_at_front:
        ans_shape = arrays_bshape + shape_from_non_arrays
    else:
        ans_shape = _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arrays_bshape)

    if item is None:
        ans = empty(ans_shape, dtype)
        item_arr = None
    else:
        ans = None
        item_arr = broadcast_to(asarray(item), ans_shape)

    subshape = tuple(shape[i] for i in non_arrays_where)

    for idx in util.multirange(arrays_bshape):
        # One value from each index array at the current broadcast position.
        idx_from_arrays = tuple(a._ptr(idx, broadcast=True)[0] for a in index_arrays)

        for idx_from_non_arrays in _adv_idx_iter_non_arrays(index_non_arrays, subshape):
            dst_idx_from_non_arrays = tuple(x[0] for x in idx_from_non_arrays)
            src_idx_from_non_arrays = tuple(x[1] for x in idx_from_non_arrays)

            if arrays_at_front:
                dst_idx = idx + dst_idx_from_non_arrays
            else:
                dst_idx = _adv_idx_build_for_contig_array(indexes, dst_idx_from_non_arrays, idx)

            # Assemble the source index in place through a raw pointer.
            src_idx = (0,) * staticlen(arr.shape)
            psrc_idx = Ptr[int](__ptr__(src_idx).as_byte())

            for i in staticrange(staticlen(index_arrays)):
                psrc_idx[arrays_where[i]] = idx_from_arrays[i]

            for i in staticrange(staticlen(index_non_arrays)):
                psrc_idx[non_arrays_where[i]] = src_idx_from_non_arrays[i]

            if item is None:
                ans[dst_idx] = arr[src_idx]
            else:
                arr[src_idx] = util.cast(item_arr[dst_idx], arr.dtype)

    if item is None:
        return ans

def _getset_bool(arr, indexes, item, dtype: type = NoneType):
    """
    Get or set elements of `arr` selected by the boolean mask `indexes`.

    When `item` is `None` the selected elements are returned as a 1-D
    array; otherwise `item` (0- or 1-dimensional) is written to the
    selected positions. A mask with fewer dimensions than `arr` is
    delegated to `_getset_advanced` via `indexes.nonzero()`.

    Raises `IndexError` if a full-rank mask does not match the shape of
    `arr`, and `ValueError` if a 1-D `item` has neither one element nor
    exactly as many elements as the mask has true entries.
    """
    indexes = asarray(indexes)

    if staticlen(indexes.shape) > staticlen(arr.shape):
        compile_error("too many indices for array")
    elif staticlen(arr.shape) == 0:
        # 0-dimensional array indexed by a 0-dimensional mask.
        if item is None:
            if indexes.item():
                return atleast_1d(arr)
            else:
                return empty(0, arr.dtype)
        else:
            if indexes.item():
                arr_item = asarray(item)
                arr.data[0] = util.cast(arr_item.item(), arr.dtype)
    elif staticlen(indexes.shape) == staticlen(arr.shape):
        sz = 1
        for i in range(len(indexes.shape)):
            arr_dim = arr.shape[i]
            idx_dim = indexes.shape[i]
            if arr_dim != idx_dim:
                raise IndexError(f"boolean index did not match indexed array "
                                 f"along dimension {i}; dimension is {arr_dim} but "
                                 f"corresponding boolean dimension is {idx_dim}")
            sz *= arr_dim

        num_true = 0
        if item is None:
            num_true = _bool_idx_num_true(indexes, sz)
            arr_item = None
            ans = empty(num_true, arr.dtype)
        else:
            arr_item = asarray(item)
            if staticlen(arr_item.shape) == 1:
                num_true = _bool_idx_num_true(indexes, sz)
                if arr_item.size != 1 and arr_item.size != num_true:
                    raise ValueError(f"NumPy boolean array indexing assignment "
                                     f"cannot assign {arr_item.size} input values "
                                     f"to the {num_true} output values where the "
                                     f"mask is true")
            elif staticlen(arr_item.shape) > 1:
                compile_error("NumPy boolean array indexing assignment requires a 0 or 1-dimensional input")
            ans = None

        cc1, _ = arr._contig
        cc2, _ = indexes._contig
        k = 0  # number of true mask entries seen so far

        if cc1 and cc2:
            # Both buffers can be scanned linearly in lock-step.
            for i in range(sz):
                if indexes.data[i]:
                    if item is None:
                        ans.data[k] = arr.data[i]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr.data[i] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            elem = arr_item.data[0] if arr_item.size == 1 else arr_item.data[k]
                            arr.data[i] = util.cast(elem, arr.dtype)
                    k += 1
        else:
            for idx in util.multirange(arr.shape):
                if indexes._ptr(idx)[0]:
                    if item is None:
                        ans.data[k] = arr._ptr(idx)[0]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr._ptr(idx)[0] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            elem = arr_item.data[0] if arr_item.size == 1 else arr_item.data[k]
                            arr._ptr(idx)[0] = util.cast(elem, arr.dtype)
                    k += 1

        if item is None:
            return ans
    else:
        return _getset_advanced(arr, indexes.nonzero(), item, dtype)

def _assign_one(arr, elem):
    """Store `elem`, cast to `arr.dtype`, into every element of `arr`."""
    elem = util.cast(elem, arr.dtype)
    if staticlen(arr.shape) == 0:
        arr.data[0] = elem
    else:
        if arr._is_contig:
            p = arr.data
            for i in range(util.count(arr.shape)):
                p[i] = elem
        else:
            for idx in util.multirange(arr.shape):
                a = arr._ptr(idx)
                a[0] = elem

def _getset(arr, indexes, item, dtype: type = NoneType):
    """
    Shared implementation of `ndarray.__getitem__` and `__setitem__`.

    With `item` set to `None` the indexed value is returned: a scalar
    when every axis is consumed by an integer index, otherwise an
    `ndarray` built on the same data pointer. With any other `item` the
    indexed elements are assigned instead.

    Dispatch order: boolean masks go to `_getset_bool`, a non-tuple
    index is wrapped in a 1-tuple, index tuples containing arrays,
    lists or tuples go to `_getset_advanced`, and everything else is
    handled here as basic indexing.
    """
    if staticlen(_bool_idx_get_bool_index(indexes)) > 0:
        return _getset_bool(arr, indexes, item, dtype)

    if not isinstance(indexes, Tuple):
        return _getset(arr, (indexes,), item, dtype)

    if staticlen(_extract_sequences(indexes)) > 0:
        return _getset_advanced(arr, indexes, item, dtype)

    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)

    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")

    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    indexes = _expand_ellipsis(
        indexes,
        staticlen(arr.shape) - (staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple)),
        staticlen(ellipsis_tuple))
    indexes = _expand_remainder(
        indexes,
        staticlen(arr.shape) - staticlen(indexes) + staticlen(newaxis_tuple))
    indexes, shape, strides = _expand_newaxis(indexes, arr.shape, arr.strides)

    mindex = _multiindex(indexes, shape)
    keep = _keep_axes(indexes)
    # Pointer to the first selected element.
    p = Ptr[dtype](arr._data.as_byte() + _base_offset(mindex, strides))

    if staticlen(keep) == 0:
        # Every axis was consumed by an integer index: a single element.
        if item is None:
            return p[0]
        else:
            p[0] = util.cast(item, dtype)

    ans_shape = _new_shape(mindex, keep)
    ans_strides = _new_strides(mindex, strides, keep)
    sub = ndarray(ans_shape, ans_strides, p)

    if item is None:
        return sub
    elif isinstance(item, ndarray) or isinstance(item, List) or isinstance(item, Tuple):
        arr_item = asarray(item)

        if staticlen(arr_item.shape) == 0:
            _assign_one(sub, arr_item.item())
        else:
            arr_bcast = broadcast_to(arr_item, sub.shape)
            if sub._contig_match(arr_bcast):
                # Matching layouts: copy through the flat buffers.
                q = arr_bcast._data
                for i in range(arr_bcast.size):
                    p[i] = util.cast(q[i], dtype)
            else:
                for idx in util.multirange(sub.shape):
                    a = sub._ptr(idx)
                    b = arr_bcast._ptr(idx)
                    a[0] = util.cast(b[0], dtype)
    else:
        _assign_one(sub, item)

@extend
class ndarray:
    def __getitem__(self, indexes):
        """Return the element or sub-array selected by `indexes`."""
        return _getset(self, indexes, None, dtype)

    def __setitem__(self, indexes, item):
        """Assign `item` to the elements selected by `indexes`."""
        _getset(self, indexes, item=item, dtype=dtype)