mirror of https://github.com/exaloop/codon.git
566 lines
18 KiB
Python
566 lines
18 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
import util
|
|
from .ndarray import *
|
|
from .routines import *
|
|
|
|
def _multiindex(indexes, shape, index: int = 0):
    """
    Turn a tuple of basic indices into a tuple of 4-tuples, one per axis.

    An integer entry is normalized with `util.normalize_index` and encoded
    as `(i, i + 1, 1, 1)`. Any other entry (presumably a slice) contributes
    the result of its `adjust_indices(dim)` call. `index` is the axis
    number of the first entry and is forwarded to `util.normalize_index`.

    `indexes` and `shape` must have the same static length; otherwise a
    compile-time error is emitted.
    """
    if staticlen(indexes) != staticlen(shape):
        compile_error("[internal error] bad multi-index")

    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    dim = shape[0]
    # The remaining axes are processed first, exactly as before.
    tail = _multiindex(indexes[1:], shape[1:], index + 1)

    if not isinstance(head, int):
        return (head.adjust_indices(dim),) + tail

    pos = util.normalize_index(head, index, dim)
    return ((pos, pos + 1, 1, 1),) + tail
|
|
|
|
def _keep_axes(indexes, index: int = 0):
    """
    Return the axis numbers of the entries of `indexes` that are not `int`.

    Axis numbering starts at `index` for the first entry. Integer entries
    are skipped.
    """
    if staticlen(indexes) == 0:
        return ()

    tail = _keep_axes(indexes[1:], index + 1)

    if isinstance(indexes[0], int):
        return tail
    return (index,) + tail
|
|
|
|
def _base_offset(mindices, strides):
    """
    Sum `mindices[i][0] * strides[i]` over all axes.

    `_getset` adds the result to a byte pointer to locate the first
    selected element.
    """
    total = 0
    for axis in staticrange(staticlen(mindices)):
        total += mindices[axis][0] * strides[axis]
    return total
|
|
|
|
def _new_shape(mindex, keep):
|
|
return tuple(mindex[i][3] for i in keep)
|
|
|
|
def _new_strides(mindex, strides, keep):
|
|
return tuple(mindex[i][2] * strides[i] for i in keep)
|
|
|
|
def _extract_special(indexes):
    """
    Split out the special entries of `indexes`.

    Returns a pair `(newaxes, ellipses)`: the entries whose type is that of
    `newaxis`, and the entries whose type is that of `Ellipsis`, each in
    their original order. All other entries are dropped.
    """
    if staticlen(indexes) == 0:
        return (), ()

    head = indexes[0]
    newaxes, ellipses = _extract_special(indexes[1:])

    if isinstance(head, type(newaxis)):
        return (head,) + newaxes, ellipses
    if isinstance(head, type(Ellipsis)):
        return newaxes, (head,) + ellipses
    return newaxes, ellipses
|
|
|
|
def _expand_ellipsis(indexes, n: Static[int], k: Static[int] = 1):
    """
    Replace each `Ellipsis` entry of `indexes` with `n` full slices.

    `k` is the number of ellipsis entries the caller found; when it is 0
    the tuple is returned unchanged.
    """
    if k == 0:
        return indexes

    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    tail = _expand_ellipsis(indexes[1:], n, k)

    if not isinstance(head, type(Ellipsis)):
        return (head,) + tail
    return (slice(None, None, None),) * n + tail
|
|
|
|
def _expand_remainder(indexes, n: Static[int]):
    """Append `n` full slices to `indexes`; return it unchanged when `n <= 0`."""
    if n > 0:
        return indexes + (slice(None, None, None),) * n
    return indexes
|
|
|
|
def _expand_newaxis(indexes, shape, strides):
    """
    Materialize `newaxis` entries as real axes.

    Returns `(indexes, shape, strides)` in which every `newaxis` entry has
    become a full slice over an axis of length 1 and stride 0. Every other
    entry consumes the next element of `shape` and `strides` unchanged.
    """
    if staticlen(indexes) == 0:
        return (), (), ()

    head = indexes[0]

    if not isinstance(head, type(newaxis)):
        tail_idx, tail_shape, tail_strides = _expand_newaxis(indexes[1:], shape[1:], strides[1:])
        return (head,) + tail_idx, (shape[0],) + tail_shape, (strides[0],) + tail_strides

    # A new axis does not consume an entry of shape/strides.
    tail_idx, tail_shape, tail_strides = _expand_newaxis(indexes[1:], shape, strides)
    return (slice(None, None, None),) + tail_idx, (1,) + tail_shape, (0,) + tail_strides
|
|
|
|
def _extract_sequences(indexes):
    """Return the entries of `indexes` that are an `ndarray`, `List` or `Tuple`, in order."""
    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    tail = _extract_sequences(indexes[1:])

    if isinstance(head, ndarray) or isinstance(head, List) or isinstance(head, Tuple):
        return (head,) + tail
    return tail
|
|
|
|
def _adv_idx_convert(idx):
|
|
if idx is None or isinstance(idx, slice) or isinstance(idx, int):
|
|
return idx
|
|
elif isinstance(idx, ndarray) or isinstance(idx, List) or isinstance(idx, Tuple):
|
|
arr = asarray(idx)
|
|
if arr.dtype is not int and arr.dtype is not bool:
|
|
compile_error("advanced indexing requires integer arrays")
|
|
return arr
|
|
else:
|
|
compile_error("unsupported index type: " + type(idx).__name__)
|
|
|
|
def _adv_idx_replace_bools(indexes):
    """
    Replace every boolean `ndarray` entry with the arrays from its `nonzero()`.

    The tuple returned by `nonzero()` is spliced in place of the boolean
    array; all other entries are kept as they are.
    """
    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    tail = _adv_idx_replace_bools(indexes[1:])

    if isinstance(head, ndarray):
        if head.dtype is bool:
            return head.nonzero() + tail
    return (head,) + tail
|
|
|
|
def _adv_idx_length(idx, dim: int):
|
|
if isinstance(idx, slice):
|
|
return idx.adjust_indices(dim)[-1]
|
|
elif isinstance(idx, int):
|
|
return 0
|
|
elif idx is None:
|
|
return 1
|
|
else:
|
|
compile_error("[internal error]: bad input type")
|
|
|
|
def _adv_idx_iter_non_arrays(indexes, shape):
    """
    Iterate over every combination of positions selected by `indexes`.

    Each entry of `indexes` must provide `adjust_indices(dim)` (the caller
    passes the slices left after integer and `None` entries were removed).
    Yields a tuple with one `(k, i)` pair per entry, where `i` is the
    position in the source axis and `k` is the running 0-based position in
    the result axis. An empty `indexes` yields a single empty tuple.
    """
    if staticlen(indexes) == 0:
        yield ()
    else:
        # adjust_indices() is consumed as a 4-tuple elsewhere in this module
        # (element 3 / [-1] is the length), so read the range bounds by
        # position rather than unpacking into exactly three names.
        bounds = indexes[0].adjust_indices(shape[0])
        start = bounds[0]
        stop = bounds[1]
        step = bounds[2]

        k = 0
        for i in range(start, stop, step):
            if staticlen(indexes) == 1:
                yield ((k, i),)
            else:
                for rest in _adv_idx_iter_non_arrays(indexes[1:], shape[1:]):
                    yield ((k, i),) + rest
            k += 1
|
|
|
|
def _adv_idx_gather_arrays(indexes, k: Static[int] = 0):
    """
    Collect the `ndarray` entries of `indexes` together with their positions.

    Returns `(arrays, positions)`. Positions are counted from `k` for the
    first entry.
    """
    if staticlen(indexes) == 0:
        return (), ()

    head = indexes[0]
    arrays, positions = _adv_idx_gather_arrays(indexes[1:], k + 1)

    if not isinstance(head, ndarray):
        return arrays, positions
    return (head,) + arrays, (k,) + positions
|
|
|
|
def _adv_idx_gather_non_arrays(indexes, k: Static[int] = 0):
    """
    Collect the entries of `indexes` that are not `ndarray`, with their positions.

    Returns `(entries, positions)`. Positions are counted from `k` for the
    first entry.
    """
    if staticlen(indexes) == 0:
        return (), ()

    head = indexes[0]
    others, positions = _adv_idx_gather_non_arrays(indexes[1:], k + 1)

    if isinstance(head, ndarray):
        return others, positions
    return (head,) + others, (k,) + positions
|
|
|
|
def _adv_idx_replace_int(idx):
|
|
if idx is None or isinstance(idx, int):
|
|
return idx
|
|
else:
|
|
return slice(None, None, None)
|
|
|
|
def _adv_idx_prune_index(indexes):
    """
    Rewrite `indexes` for an array whose `None`/integer entries were already applied.

    `None` entries become full slices, integer entries are removed, and
    every other entry is kept.
    """
    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    tail = _adv_idx_prune_index(indexes[1:])

    if idx_is_none := head is None:
        return (slice(None, None, None),) + tail
    if isinstance(head, int):
        return tail
    return (head,) + tail
|
|
|
|
def _adv_idx_gather_none_and_int(indexes):
    """Return the entries of `indexes` that are `None` or an integer, in order."""
    if staticlen(indexes) == 0:
        return ()

    head = indexes[0]
    tail = _adv_idx_gather_none_and_int(indexes[1:])

    if head is None or isinstance(head, int):
        return (head,) + tail
    return tail
|
|
|
|
def _adv_idx_eliminate_new_and_used(arr, indexes):
    """
    Apply the `None` and integer entries of `indexes` to `arr` up front.

    Returns `(arr, indexes)`. If there are no such entries both are
    returned unchanged. Otherwise `arr` is indexed with those entries
    (every other entry replaced by a full slice) and `indexes` is pruned
    with `_adv_idx_prune_index` to match the resulting array.
    """
    if staticlen(_adv_idx_gather_none_and_int(indexes)) == 0:
        # nothing to do
        return arr, indexes

    basic_only = tuple(_adv_idx_replace_int(entry) for entry in indexes)
    reduced = arr[basic_only]
    return reduced, _adv_idx_prune_index(indexes)
|
|
|
|
def _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arr_shape, saw_array: Static[int] = False):
    """
    Interleave per-entry values with a block that stands in for the array entries.

    Walking `indexes` left to right, each non-array entry contributes the
    next element of `shape_from_non_arrays`. The first `ndarray` entry
    contributes all of `arr_shape`; later `ndarray` entries contribute
    nothing. `saw_array` records whether an array entry was already seen.
    """
    if staticlen(indexes) == 0:
        return ()

    remaining = indexes[1:]

    if not isinstance(indexes[0], ndarray):
        return (shape_from_non_arrays[0],) + _adv_idx_build_for_contig_array(remaining, shape_from_non_arrays[1:], arr_shape, saw_array)

    if saw_array:
        return _adv_idx_build_for_contig_array(remaining, shape_from_non_arrays, arr_shape, saw_array)

    return arr_shape + _adv_idx_build_for_contig_array(remaining, shape_from_non_arrays, arr_shape, True)
|
|
|
|
def _bool_idx_get_bool_index(indexes):
    """
    Return `(indexes,)` if `indexes` is a boolean mask, else `()`.

    A mask is an `ndarray` with dtype `bool`, or a `List` that `asarray`
    converts to dtype `bool`. The static length of the result tells the
    caller whether boolean indexing applies.
    """
    if isinstance(indexes, ndarray):
        if indexes.dtype is bool:
            return (indexes,)
    elif isinstance(indexes, List):
        if asarray(indexes).dtype is bool:
            return (indexes,)
    return ()
|
|
|
|
def _bool_idx_num_true(indexes, sz: int):
    """
    Count the true elements of the boolean array `indexes`.

    `sz` is the total number of elements. It is only used on the
    contiguous path, which scans `indexes.data` linearly.
    """
    count = 0

    if not indexes._is_contig:
        for pos in util.multirange(indexes.shape):
            if indexes._ptr(pos)[0]:
                count += 1
        return count

    flags = indexes.data
    for i in range(sz):
        if flags[i]:
            count += 1
    return count
|
|
|
|
# adapted from routines
|
|
def _broadcast_shapes(*args):
    """
    Compute the broadcast of the shape tuples in `args`.

    Returns a tuple as long as the longest input shape, or `()` when no
    shapes are given. Raises `IndexError` if two shapes disagree on a
    dimension where neither is 1.
    """
    def _largest(args):
        # Return the shape with the greatest static length; on a tie the
        # earlier one wins.
        if staticlen(args) == 1:
            return args[0]

        a = args[0]
        b = _largest(args[1:])
        if staticlen(b) > staticlen(a):
            return b
        else:
            return a

    if staticlen(args) == 0:
        return ()

    t = _largest(args)
    N: Static[int] = staticlen(t)
    ans = (0,) * N
    # `ans` is filled in place through a raw pointer to its storage.
    p = Ptr[int](__ptr__(ans).as_byte())

    # Start from the longest shape.
    for i in staticrange(N):
        p[i] = t[i]

    for a in args:
        # Shapes are aligned on their trailing dimensions.
        for i in staticrange(staticlen(a)):
            x = a[len(a) - 1 - i]
            q = p + (len(t) - 1 - i)
            y = q[0]

            if y == 1:
                # A length-1 dimension stretches to the other shape's length.
                q[0] = x
            elif x != 1 and x != y:
                msg = _strbuf(capacity=100)
                msg.append("shape mismatch: indexing arrays could not be broadcast together with shapes")
                for sh in args:
                    msg.append(" ")
                    msg.append(str(sh))
                raise IndexError(msg.__str__())

    return ans
|
|
|
|
def _getset_advanced(arr, indexes, item, dtype: type):
    """
    Read or write `arr` through an index containing array-like entries.

    With `item is None` a new array of element type `dtype` holding the
    selected elements is returned. Otherwise `item` is broadcast to the
    result shape and stored into `arr` element by element, and nothing is
    returned.
    """
    # Normalize every entry, then replace boolean arrays by integer arrays.
    indexes = tuple(_adv_idx_convert(idx) for idx in indexes)
    indexes = _adv_idx_replace_bools(indexes)
    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)

    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")

    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    # Pad the index with full slices so that it addresses every axis of arr.
    indexes = _expand_ellipsis(indexes,
                               staticlen(arr.shape)
                                 - (staticlen(indexes)
                                   - staticlen(newaxis_tuple)
                                   - staticlen(ellipsis_tuple)),
                               staticlen(ellipsis_tuple))
    indexes = _expand_remainder(indexes,
                                staticlen(arr.shape)
                                  - staticlen(indexes)
                                  + staticlen(newaxis_tuple))

    # eliminate newaxis and used axes (i.e. integer indices)
    arr, indexes = _adv_idx_eliminate_new_and_used(arr, indexes)
    shape = arr.shape

    # which indices are array-like?
    index_arrays, arrays_where = _adv_idx_gather_arrays(indexes)
    arrays_bshape = _broadcast_shapes(*tuple(a.shape for a in index_arrays))

    if staticlen(index_arrays) == 0:
        compile_error("[internal error] advanced indexing is not applicable to index")

    # which indices are not array-like?
    index_non_arrays, non_arrays_where = _adv_idx_gather_non_arrays(indexes)

    # If the array entries are not all adjacent, the broadcast dimensions
    # are placed at the front of the result.
    arrays_at_front = False
    for i in staticrange(1, staticlen(arrays_where)):
        if arrays_where[i] != arrays_where[i - 1] + 1:
            arrays_at_front = True

    shape_from_non_arrays = tuple(_adv_idx_length(index_non_arrays[i], shape[non_arrays_where[i]])
                                  for i in staticrange(staticlen(index_non_arrays)))

    if arrays_at_front:
        ans_shape = arrays_bshape + shape_from_non_arrays
    else:
        # Otherwise they take the place of the first array entry.
        ans_shape = _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arrays_bshape)

    if item is None:
        ans = empty(ans_shape, dtype)
        item_arr = None
    else:
        ans = None
        item_arr = broadcast_to(asarray(item), ans_shape)

    # Lengths of the axes of arr addressed by the non-array entries.
    subshape = tuple(shape[i] for i in non_arrays_where)

    for idx in util.multirange(arrays_bshape):
        # One source coordinate per index array, read with broadcasting.
        idx_from_arrays = tuple(a._ptr(idx, broadcast=True)[0] for a in index_arrays)

        for idx_from_non_arrays in _adv_idx_iter_non_arrays(index_non_arrays, subshape):
            # Each pair is (position in the result, position in the source).
            dst_idx_from_non_arrays = tuple(x[0] for x in idx_from_non_arrays)
            src_idx_from_non_arrays = tuple(x[1] for x in idx_from_non_arrays)

            if arrays_at_front:
                dst_idx = idx + dst_idx_from_non_arrays
            else:
                dst_idx = _adv_idx_build_for_contig_array(indexes, dst_idx_from_non_arrays, idx)

            # Assemble the full source index in place through a raw pointer.
            src_idx = (0,) * staticlen(arr.shape)
            psrc_idx = Ptr[int](__ptr__(src_idx).as_byte())

            for i in staticrange(staticlen(index_arrays)):
                psrc_idx[arrays_where[i]] = idx_from_arrays[i]

            for i in staticrange(staticlen(index_non_arrays)):
                psrc_idx[non_arrays_where[i]] = src_idx_from_non_arrays[i]

            if item is None:
                ans[dst_idx] = arr[src_idx]
            else:
                arr[src_idx] = util.cast(item_arr[dst_idx], arr.dtype)

    if item is None:
        return ans
|
|
|
|
def _getset_bool(arr, indexes, item, dtype: type = NoneType):
    """
    Read or write `arr` through a boolean mask.

    `indexes` is converted with `asarray`. With `item is None` the selected
    elements are returned as a new 1-D array. Otherwise `item` (a scalar,
    0-d or 1-D value) is stored at every position where the mask is true.

    A mask with fewer dimensions than `arr` is delegated to
    `_getset_advanced` via `indexes.nonzero()`; one with more dimensions is
    a compile-time error.

    Raises `IndexError` if the mask and array shapes differ, and
    `ValueError` if a 1-D `item` has neither one element nor exactly as
    many elements as there are true mask entries.
    """
    indexes = asarray(indexes)

    if staticlen(indexes.shape) > staticlen(arr.shape):
        compile_error("too many indices for array")
    elif staticlen(arr.shape) == 0:
        # 0-d array indexed by a 0-d mask.
        if item is None:
            if indexes.item():
                return atleast_1d(arr)
            else:
                return empty(0, arr.dtype)
        else:
            if indexes.item():
                arr_item = asarray(item)
                arr.data[0] = util.cast(arr_item.item(), arr.dtype)
    elif staticlen(indexes.shape) == staticlen(arr.shape):
        sz = 1
        for i in range(len(indexes.shape)):
            arr_dim = arr.shape[i]
            idx_dim = indexes.shape[i]
            if arr_dim != idx_dim:
                raise IndexError(f"boolean index did not match indexed array "
                                 f"along dimension {i}; dimension is {arr_dim} but "
                                 f"corresponding boolean dimension is {idx_dim}")
            sz *= arr_dim

        num_true = 0

        if item is None:
            num_true = _bool_idx_num_true(indexes, sz)
            arr_item = None
            ans = empty(num_true, arr.dtype)
        else:
            arr_item = asarray(item)

            if staticlen(arr_item.shape) == 1:
                num_true = _bool_idx_num_true(indexes, sz)
                if arr_item.size != 1 and arr_item.size != num_true:
                    raise ValueError(f"NumPy boolean array indexing assignment "
                                     f"cannot assign {arr_item.size} input values "
                                     f"to the {num_true} output values where the "
                                     f"mask is true")
            elif staticlen(arr_item.shape) > 1:
                compile_error("NumPy boolean array indexing assignment requires a 0 or 1-dimensional input")

            ans = None

        cc1, _ = arr._contig
        cc2, _ = indexes._contig
        # k counts the true mask entries seen so far.
        k = 0

        if cc1 and cc2:
            # Both arrays can be scanned linearly through their data pointers.
            for i in range(sz):
                if indexes.data[i]:
                    if item is None:
                        ans.data[k] = arr.data[i]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr.data[i] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            # Read the source through _ptr so that a strided
                            # 1-D item (e.g. a view) is honoured; indexing
                            # .data directly would ignore its stride.
                            elem = arr_item._ptr((0,))[0] if arr_item.size == 1 else arr_item._ptr((k,))[0]
                            arr.data[i] = util.cast(elem, arr.dtype)
                    k += 1
        else:
            for idx in util.multirange(arr.shape):
                if indexes._ptr(idx)[0]:
                    if item is None:
                        ans.data[k] = arr._ptr(idx)[0]
                    else:
                        if staticlen(arr_item.shape) == 0:
                            arr._ptr(idx)[0] = util.cast(arr_item.item(), arr.dtype)
                        else:
                            # Same stride-aware read as on the contiguous path.
                            elem = arr_item._ptr((0,))[0] if arr_item.size == 1 else arr_item._ptr((k,))[0]
                            arr._ptr(idx)[0] = util.cast(elem, arr.dtype)
                    k += 1

        if item is None:
            return ans
    else:
        return _getset_advanced(arr, indexes.nonzero(), item, dtype)
|
|
|
|
def _assign_one(arr, elem):
    """Store `elem`, cast to `arr.dtype`, into every element of `arr`."""
    value = util.cast(elem, arr.dtype)

    if staticlen(arr.shape) == 0:
        arr.data[0] = value
    elif arr._is_contig:
        # Contiguous storage can be filled linearly.
        buf = arr.data
        for i in range(util.count(arr.shape)):
            buf[i] = value
    else:
        for pos in util.multirange(arr.shape):
            arr._ptr(pos)[0] = value
|
|
|
|
def _getset(arr, indexes, item, dtype: type = NoneType):
    """
    Shared implementation of `ndarray.__getitem__` and `ndarray.__setitem__`.

    With `item is None` this is a read: it returns a single element when
    every axis is indexed by an integer, otherwise a view built from the
    selected shape and strides. With any other `item` it is a write and
    returns nothing.

    Boolean masks are dispatched to `_getset_bool`, and indices containing
    an `ndarray`, `List` or `Tuple` entry to `_getset_advanced`. A
    non-tuple index is wrapped in a 1-tuple first.
    """
    if staticlen(_bool_idx_get_bool_index(indexes)) > 0:
        return _getset_bool(arr, indexes, item, dtype)

    if not isinstance(indexes, Tuple):
        return _getset(arr, (indexes,), item, dtype)

    if staticlen(_extract_sequences(indexes)) > 0:
        return _getset_advanced(arr, indexes, item, dtype)

    newaxis_tuple, ellipsis_tuple = _extract_special(indexes)

    if staticlen(ellipsis_tuple) > 1:
        compile_error("an index can only have a single ellipsis ('...')")

    if staticlen(indexes) - staticlen(newaxis_tuple) - staticlen(ellipsis_tuple) > staticlen(arr.shape):
        compile_error("too many indices for array")

    # Pad the index with full slices so that it addresses every axis of arr.
    indexes = _expand_ellipsis(indexes,
                               staticlen(arr.shape)
                                 - (staticlen(indexes)
                                   - staticlen(newaxis_tuple)
                                   - staticlen(ellipsis_tuple)),
                               staticlen(ellipsis_tuple))
    indexes = _expand_remainder(indexes,
                                staticlen(arr.shape)
                                  - staticlen(indexes)
                                  + staticlen(newaxis_tuple))
    indexes, shape, strides = _expand_newaxis(indexes, arr.shape, arr.strides)
    mindex = _multiindex(indexes, shape)
    keep = _keep_axes(indexes)
    # Pointer to the first selected element.
    p = Ptr[dtype](arr._data.as_byte() + _base_offset(mindex, strides))

    if staticlen(keep) == 0:
        # Every axis was indexed by an integer: a single element.
        if item is None:
            return p[0]
        else:
            p[0] = util.cast(item, dtype)
            # The element is already stored; return instead of falling
            # through and assigning it again via a 0-d view.
            return

    ans_shape = _new_shape(mindex, keep)
    ans_strides = _new_strides(mindex, strides, keep)
    sub = ndarray(ans_shape, ans_strides, p)

    if item is None:
        return sub
    elif isinstance(item, ndarray) or isinstance(item, List) or isinstance(item, Tuple):
        arr_item = asarray(item)

        if staticlen(arr_item.shape) == 0:
            _assign_one(sub, arr_item.item())
        else:
            arr_bcast = broadcast_to(arr_item, sub.shape)
            if sub._contig_match(arr_bcast):
                # Matching layouts: copy linearly through the data pointers.
                q = arr_bcast._data
                for i in range(arr_bcast.size):
                    p[i] = util.cast(q[i], dtype)
            else:
                for idx in util.multirange(sub.shape):
                    a = sub._ptr(idx)
                    b = arr_bcast._ptr(idx)
                    a[0] = util.cast(b[0], dtype)
    else:
        _assign_one(sub, item)
|
|
|
|
@extend
class ndarray:
    def __getitem__(self, indexes):
        """Return the element or sub-array selected by `indexes`."""
        return _getset(self, indexes, item=None, dtype=dtype)

    def __setitem__(self, indexes, item):
        """Store `item` at the positions selected by `indexes`."""
        _getset(self, indexes, item, dtype)
|