1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/lib/arraysetops.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

756 lines
18 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from ..ndarray import ndarray
from ..routines import empty, zeros, ones, asarray, broadcast_to, moveaxis, ascontiguousarray, rot90, atleast_1d
from ..sorting import sort, lexsort
from ..ndmath import isnan
from ..util import multirange, free, count, normalize_axis_index, coerce, cast
def _unique1d(ar,
              return_index: Static[int] = False,
              return_inverse: Static[int] = False,
              return_counts: Static[int] = False,
              equal_nan: bool = True):
    """
    Collect the distinct values of ``ar`` in sorted order, treating it as 1-D.

    The elements are sorted (or arg-sorted when positions are needed) and
    then scanned twice: once to count the runs of equal values so that the
    outputs can be allocated, and once to fill them.

    Parameters
    ----------
    ar : array whose elements are examined in flattened order
    return_index : also return, for every unique value, the flat position of
        the element that starts its run in sorted order
    return_inverse : also return, for every input element, the slot of its
        value in the unique array
    return_counts : also return the length of every run
    equal_nan : when true, two NaN values are considered equal

    Returns
    -------
    The unique array alone when no flag is set; otherwise a tuple holding the
    unique array followed by the requested arrays in the order
    index, inverse, counts.
    """
    total = ar.size

    if return_index or return_inverse:
        # Positions are needed, so sort indirectly and leave the data alone.
        if return_index:
            algo = 'mergesort'
        else:
            algo = 'quicksort'
        flat = ar.ravel()
        order = flat.argsort(kind=algo)
    else:
        # Only values are needed: sort a private copy in place.
        flat = ar.flatten()
        flat.sort()
        order = None

    # Position in `flat` of the element at sorted position `pos`
    # (a dummy 0 when no permutation is kept).
    def source_index(order, pos: int):
        if order is None:
            return 0
        else:
            return order._ptr((pos, ))[0]

    # Element at sorted position `pos`.
    def sorted_elem(flat, pos: int, order):
        if order is None:
            return flat._ptr((pos, ))[0]
        else:
            return flat._ptr((source_index(order, pos), ))[0]

    # Equality test that optionally lets NaN match NaN.
    def same_value(a, b, equal_nan: bool):
        if not equal_nan:
            return a == b
        return (a == b) or (isnan(a) and isnan(b))

    # One past the last sorted position whose element equals the one at `start`.
    def run_end(flat, order, start: int, total: int, equal_nan: bool):
        head = sorted_elem(flat, start, order)
        stop = start + 1
        while stop < total and same_value(head, sorted_elem(flat, stop, order),
                                          equal_nan):
            stop += 1
        return stop

    # Pass 1: count the runs.
    num_unique = 0
    start = 0
    while start < total:
        start = run_end(flat, order, start, total, equal_nan)
        num_unique += 1

    out_values = empty((num_unique, ), ar.dtype)

    if return_index:
        out_index = empty((num_unique, ), int)
    else:
        out_index = None

    if return_inverse:
        out_inverse = empty((total, ), int)
    else:
        out_inverse = None

    if return_counts:
        out_counts = empty((num_unique, ), int)
    else:
        out_counts = None

    # Pass 2: fill the outputs, one run per slot.
    start = 0
    slot = 0
    while start < total:
        stop = run_end(flat, order, start, total, equal_nan)
        out_values.data[slot] = sorted_elem(flat, start, order)

        if return_inverse:
            for pos in range(start, stop):
                out_inverse.data[source_index(order, pos)] = slot

        if return_index:
            out_index.data[slot] = source_index(order, start)

        if return_counts:
            out_counts.data[slot] = stop - start

        start = stop
        slot += 1

    if not (return_index or return_inverse or return_counts):
        return out_values

    # The tuple type depends on the static flags, so it is grown through
    # separately named intermediates.
    res0 = (out_values, )

    if return_index:
        res1 = res0 + (out_index, )
    else:
        res1 = res0

    if return_inverse:
        res2 = res1 + (out_inverse, )
    else:
        res2 = res1

    if return_counts:
        res3 = res2 + (out_counts, )
    else:
        res3 = res2

    return res3
def unique(ar,
           return_index: Static[int] = False,
           return_inverse: Static[int] = False,
           return_counts: Static[int] = False,
           axis=None,
           equal_nan: bool = True):
    """
    Find the unique elements of an array, or its unique slices along ``axis``.

    Parameters
    ----------
    ar : array-like input; converted with ``asarray`` and made at least 1-D
    return_index : also return, for every unique entry, the position (flat,
        or along ``axis``) of the element that starts its run in sorted order
    return_inverse : also return, for every input entry, the slot of its
        value in the unique result
    return_counts : also return how many times every unique entry occurs
    axis : ``None`` to work on the flattened input, or an ``int`` axis along
        which whole sub-arrays are compared; any other type is a compile
        error
    equal_nan : whether NaNs compare equal; only honoured when ``axis`` is
        ``None``

    Returns
    -------
    The unique array alone when no flag is set; otherwise a tuple holding the
    unique array followed by the requested arrays in the order
    index, inverse, counts.
    """
    ar = atleast_1d(asarray(ar))

    if axis is None:
        return _unique1d(ar,
                         return_index=return_index,
                         return_inverse=return_inverse,
                         return_counts=return_counts,
                         equal_nan=equal_nan)

    if not isinstance(axis, int):
        compile_error("'axis' must be an int or None")

    # Bring the requested axis to the front and view the data as a 2-D
    # contiguous table: one row per slice, `ncols` elements per row.
    axis = normalize_axis_index(axis, ar.ndim)
    ar = moveaxis(ar, axis, 0)
    orig_shape = ar.shape
    nrows = orig_shape[0]
    ncols = count(orig_shape[1:])
    ar = ar.reshape(nrows, ncols)
    ar = ascontiguousarray(ar)
    # Row permutation from lexsort; presumably rot90 arranges the columns as
    # keys so that rows end up ordered lexicographically.
    perm = lexsort(rot90(ar))

    # Original row number of the row at sorted position `i`.
    def perm_get(perm, i: int):
        return perm._ptr((i, ))[0]

    # Pointer to the first element of the row at sorted position `i`.
    def ar_get(ar, i: int, ncols: int, perm):
        i = perm_get(perm, i)
        return ar.data + (i * ncols)

    # Element-wise comparison of two rows of `ncols` elements.
    def is_equal(a, b, ncols: int, equal_nan: bool):
        # NumPy ignores equal_nan in this case, so we do too
        for i in range(ncols):
            if not (a[i] == b[i]):
                return False
        return True

    # Pass 1: count runs of equal rows so the outputs can be allocated.
    num_unique = 0
    i = 0
    j = 0
    n = nrows

    while i < n:
        curr = ar_get(ar, i, ncols, perm)
        j = i + 1

        while j < n and is_equal(curr, ar_get(ar, j, ncols, perm), ncols,
                                 equal_nan):
            j += 1

        i = j
        num_unique += 1

    ret_unique = empty((num_unique, ncols), ar.dtype)

    if return_index:
        ret_index = empty((num_unique, ), int)
    else:
        ret_index = None

    if return_inverse:
        ret_inverse = empty((n, ), int)
    else:
        ret_inverse = None

    if return_counts:
        ret_counts = empty((num_unique, ), int)
    else:
        ret_counts = None

    # Pass 2: fill the outputs, one run of equal rows per output slot `k`.
    i = 0
    j = 0
    k = 0

    while i < n:
        curr = ar_get(ar, i, ncols, perm)
        if return_inverse:
            ret_inverse.data[perm_get(perm, i)] = k
        j = i + 1

        while j < n and is_equal(curr, ar_get(ar, j, ncols, perm), ncols,
                                 equal_nan):
            if return_inverse:
                ret_inverse.data[perm_get(perm, j)] = k
            j += 1

        # Copy one complete row. A row holds `ncols` elements, so the copy
        # must be bounded by `ncols` (not by the number of rows).
        p = ret_unique.data + (k * ncols)
        for z in range(ncols):
            p[z] = curr[z]

        if return_index:
            ret_index.data[k] = perm_get(perm, i)

        if return_counts:
            ret_counts.data[k] = j - i

        i = j
        k += 1

    # Restore the trailing dimensions and move the axis back in place.
    ret_unique = ret_unique.reshape(num_unique, *orig_shape[1:])
    ret_unique = moveaxis(ret_unique, 0, axis)

    if not (return_index or return_inverse or return_counts):
        return ret_unique

    # The tuple type depends on the static flags, so it is grown through
    # separately named intermediates.
    ans0 = (ret_unique, )

    if return_index:
        ans1 = ans0 + (ret_index, )
    else:
        ans1 = ans0

    if return_inverse:
        ans2 = ans1 + (ret_inverse, )
    else:
        ans2 = ans1

    if return_counts:
        ans3 = ans2 + (ret_counts, )
    else:
        ans3 = ans2

    return ans3
def _sorted(ar):
    """
    Return a sorted 1-D array holding the elements of ``ar``.

    An ``ndarray`` argument is copied with ``flatten`` before the in-place
    sort; any other argument is first converted with ``asarray`` and then
    ravelled.
    """
    if isinstance(ar, ndarray):
        flat = ar.flatten()
    else:
        flat = asarray(ar).ravel()

    flat.sort()
    return flat
def _next_distinct(p, i: int, n: int, assume_unique: bool = False):
if assume_unique:
return i + 1
while True:
i += 1
if not (i < n and p[i - 1] == p[i]):
break
return i
def union1d(ar1, ar2):
    """
    Return the sorted union of the values found in ``ar1`` and ``ar2``.

    Both inputs are flattened and sorted, then merged twice: a first walk
    counts the distinct values so the result can be allocated, and a second
    walk writes them. The result uses the coerced dtype of the two inputs.
    """
    lhs = _sorted(ar1)
    rhs = _sorted(ar2)
    dtype = type(coerce(lhs.dtype, rhs.dtype))
    len_l = lhs.size
    len_r = rhs.size
    ptr_l = lhs.data
    ptr_r = rhs.data

    # Walk 1: count how many distinct values the merge produces.
    a = 0
    b = 0
    total = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)
        total += 1

        if x < y:
            a = _next_distinct(ptr_l, a, len_l)
        elif y < x:
            b = _next_distinct(ptr_r, b, len_r)
        else:
            # Same value on both sides: consume it from both.
            a = _next_distinct(ptr_l, a, len_l)
            b = _next_distinct(ptr_r, b, len_r)

    while a < len_l:
        total += 1
        a = _next_distinct(ptr_l, a, len_l)

    while b < len_r:
        total += 1
        b = _next_distinct(ptr_r, b, len_r)

    merged = empty(total, dtype)
    out = merged.data

    # Walk 2: identical traversal, this time emitting the values.
    a = 0
    b = 0
    w = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)

        if x < y:
            out[w] = x
            a = _next_distinct(ptr_l, a, len_l)
        elif y < x:
            out[w] = y
            b = _next_distinct(ptr_r, b, len_r)
        else:
            out[w] = x
            a = _next_distinct(ptr_l, a, len_l)
            b = _next_distinct(ptr_r, b, len_r)

        w += 1

    while a < len_l:
        out[w] = cast(ptr_l[a], dtype)
        w += 1
        a = _next_distinct(ptr_l, a, len_l)

    while b < len_r:
        out[w] = cast(ptr_r[b], dtype)
        w += 1
        b = _next_distinct(ptr_r, b, len_r)

    return merged
def setxor1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values that occur in exactly one of ``ar1`` and ``ar2``.

    Both inputs are flattened and sorted, then merged twice: the first walk
    counts the values to keep, the second writes them. ``assume_unique``
    is forwarded to ``_next_distinct``, which then steps one element at a
    time instead of skipping runs of equal values. The result uses the
    coerced dtype of the two inputs.
    """
    lhs = _sorted(ar1)
    rhs = _sorted(ar2)
    dtype = type(coerce(lhs.dtype, rhs.dtype))
    len_l = lhs.size
    len_r = rhs.size
    ptr_l = lhs.data
    ptr_r = rhs.data

    # Walk 1: count the values present on one side only.
    a = 0
    b = 0
    total = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)

        if x < y:
            total += 1
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
        elif y < x:
            total += 1
            b = _next_distinct(ptr_r, b, len_r, assume_unique)
        else:
            # Present on both sides: drop it.
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
            b = _next_distinct(ptr_r, b, len_r, assume_unique)

    while a < len_l:
        total += 1
        a = _next_distinct(ptr_l, a, len_l, assume_unique)

    while b < len_r:
        total += 1
        b = _next_distinct(ptr_r, b, len_r, assume_unique)

    result = empty(total, dtype)
    out = result.data

    # Walk 2: identical traversal, this time emitting the values.
    a = 0
    b = 0
    w = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)

        if x < y:
            out[w] = x
            w += 1
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
        elif y < x:
            out[w] = y
            w += 1
            b = _next_distinct(ptr_r, b, len_r, assume_unique)
        else:
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
            b = _next_distinct(ptr_r, b, len_r, assume_unique)

    while a < len_l:
        out[w] = cast(ptr_l[a], dtype)
        w += 1
        a = _next_distinct(ptr_l, a, len_l, assume_unique)

    while b < len_r:
        out[w] = cast(ptr_r[b], dtype)
        w += 1
        b = _next_distinct(ptr_r, b, len_r, assume_unique)

    return result
def setdiff1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values of ``ar1`` that do not occur in ``ar2``.

    Both inputs are flattened and sorted, then merged twice: the first walk
    counts the values to keep, the second writes them. ``assume_unique``
    is forwarded to ``_next_distinct``, which then steps one element at a
    time instead of skipping runs of equal values. The result uses the
    coerced dtype of the two inputs.
    """
    lhs = _sorted(ar1)
    rhs = _sorted(ar2)
    dtype = type(coerce(lhs.dtype, rhs.dtype))
    len_l = lhs.size
    len_r = rhs.size
    ptr_l = lhs.data
    ptr_r = rhs.data

    # Walk 1: count the left-hand values that have no match on the right.
    a = 0
    b = 0
    total = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)

        if x < y:
            total += 1
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
        elif y < x:
            b = _next_distinct(ptr_r, b, len_r, assume_unique)
        else:
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
            b = _next_distinct(ptr_r, b, len_r, assume_unique)

    # Whatever is left on the left-hand side cannot be matched any more.
    while a < len_l:
        total += 1
        a = _next_distinct(ptr_l, a, len_l, assume_unique)

    result = empty(total, dtype)
    out = result.data

    # Walk 2: identical traversal, this time emitting the values.
    a = 0
    b = 0
    w = 0

    while a < len_l and b < len_r:
        x = cast(ptr_l[a], dtype)
        y = cast(ptr_r[b], dtype)

        if x < y:
            out[w] = x
            w += 1
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
        elif y < x:
            b = _next_distinct(ptr_r, b, len_r, assume_unique)
        else:
            a = _next_distinct(ptr_l, a, len_l, assume_unique)
            b = _next_distinct(ptr_r, b, len_r, assume_unique)

    while a < len_l:
        out[w] = cast(ptr_l[a], dtype)
        w += 1
        a = _next_distinct(ptr_l, a, len_l, assume_unique)

    return result
def _intersect1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values that occur in both ``ar1`` and ``ar2``.

    Both inputs are flattened and sorted, then merged twice: the first walk
    counts the common values so the result can be allocated, the second
    writes them. The result uses the coerced dtype of the two inputs.

    Parameters
    ----------
    ar1, ar2 : array-like inputs
    assume_unique : when true, the check that skips repeated values of
        ``ar1`` is disabled

    Returns
    -------
    1-D array of the common values in ascending order.
    """
    v1 = _sorted(ar1)
    v2 = _sorted(ar2)
    dtype = type(coerce(v1.dtype, v2.dtype))
    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data
    # Both walks below must take exactly the same decisions, otherwise the
    # number of values written differs from the number of slots allocated.
    skip_repeats = not assume_unique

    # Walk 1: count the common values.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        if skip_repeats and i1 > 0 and p1[i1] == p1[i1 - 1]:
            i1 += 1
            continue

        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            count += 1
            i1 += 1
            i2 += 1

    ans = empty(count, dtype)
    q = ans.data

    # Walk 2: identical traversal, this time emitting the values.
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        if skip_repeats and i1 > 0 and p1[i1] == p1[i1 - 1]:
            i1 += 1
            continue

        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            q[j] = e1
            j += 1
            i1 += 1
            i2 += 1

    return ans
def _intersect1d_indices(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted common values of ``ar1`` and ``ar2`` with their indices.

    The inputs are flattened and arg-sorted rather than sorted, so that the
    flat position of every common value can be reported. The result values
    use the coerced dtype of the two inputs.

    Parameters
    ----------
    ar1, ar2 : array-like inputs
    assume_unique : when true, the check that skips repeated values of
        ``ar1`` is disabled

    Returns
    -------
    ``(values, idx1, idx2)`` where ``idx1[k]`` and ``idx2[k]`` are flat
    indices into ``ar1`` and ``ar2`` of an element equal to ``values[k]``.
    """
    v1 = asarray(ar1).ravel()
    v2 = asarray(ar2).ravel()
    # A stable sort keeps equal values in input order, so the element picked
    # for a repeated value is its first occurrence (the behaviour NumPy
    # documents for return_indices).
    perm1 = v1.argsort(kind='mergesort').data  # argsort return array should always be contiguous
    perm2 = v2.argsort(kind='mergesort').data
    dtype = type(coerce(v1.dtype, v2.dtype))
    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data
    # Both walks below must take exactly the same decisions, otherwise the
    # number of entries written differs from the number of slots allocated.
    skip_repeats = not assume_unique

    # Walk 1: count the common values.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        if skip_repeats and i1 > 0 and p1[perm1[i1]] == p1[perm1[i1 - 1]]:
            i1 += 1
            continue

        e1 = cast(p1[perm1[i1]], dtype)
        e2 = cast(p2[perm2[i2]], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            count += 1
            i1 += 1
            i2 += 1

    ans = empty(count, dtype)
    idx1 = empty(count, int)
    idx2 = empty(count, int)
    q = ans.data

    # Walk 2: identical traversal, this time emitting values and indices.
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        if skip_repeats and i1 > 0 and p1[perm1[i1]] == p1[perm1[i1 - 1]]:
            i1 += 1
            continue

        e1 = cast(p1[perm1[i1]], dtype)
        e2 = cast(p2[perm2[i2]], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            q[j] = e1
            idx1.data[j] = perm1[i1]
            idx2.data[j] = perm2[i2]
            j += 1
            i1 += 1
            i2 += 1

    return ans, idx1, idx2
def intersect1d(ar1,
                ar2,
                assume_unique: bool = False,
                return_indices: Static[int] = False):
    """
    Return the sorted values common to ``ar1`` and ``ar2``.

    With ``return_indices`` set, the result comes from
    ``_intersect1d_indices`` (values plus two index arrays); otherwise it
    comes from ``_intersect1d`` (values only). ``assume_unique`` is passed
    through unchanged.
    """
    if not return_indices:
        return _intersect1d(ar1, ar2, assume_unique=assume_unique)
    else:
        return _intersect1d_indices(ar1, ar2, assume_unique=assume_unique)
def isin(element,
         test_elements,
         assume_unique: bool = False,
         invert: bool = False,
         kind: Optional[str] = None):
    """
    Test, element by element, whether ``element`` occurs in ``test_elements``.

    Parameters
    ----------
    element : array-like of values to look up
    test_elements : array-like of values to look for
    assume_unique : accepted for API compatibility; not used by this
        implementation
    invert : when true, the result is negated (``True`` where the value is
        *not* found)
    kind : ``None`` (choose automatically), ``'sort'`` or ``'table'``

    Returns
    -------
    Boolean array with the shape of ``element`` (at least 1-D).

    Raises
    ------
    ValueError : ``kind`` is not one of the accepted values, or
        ``kind='table'`` was requested for a dtype the table method does
        not support
    RuntimeError : ``kind='table'`` was requested but the value range of
        ``test_elements`` overflows the integer type used to size the table
    """

    # The lookup table is only usable for bool / integer types up to 64 bits.
    def table_method_ok(dtype: type):
        if dtype is int or dtype is byte or dtype is bool:
            return True

        if isinstance(dtype, Int) or isinstance(dtype, UInt):
            return dtype.N <= 64

        return False

    # Read the bit for `x`; values outside [x_min, x_max] are never present.
    def bitset_get(x, x_min, x_max, bitset: Ptr[u64]):
        if x < x_min or x > x_max:
            return False
        pos = int(x) - int(x_min)
        offset = pos >> 6
        return bool(bitset[offset] & (u64(1) << (u64(pos) & u64(63))))

    # Set the bit for `x` (caller guarantees x_min <= x <= x_max).
    def bitset_set(x, x_min, x_max, bitset: Ptr[u64]):
        pos = int(x) - int(x_min)
        offset = pos >> 6
        bitset[offset] |= (u64(1) << (u64(pos) & u64(63)))

    # Smallest and largest element of `ar` after casting to `dtype`.
    def array_min_max(ar, dtype: type):
        x_min = dtype()
        x_max = dtype()
        first = True

        for idx in multirange(ar.shape):
            x = cast(ar._ptr(idx)[0], dtype)
            if first or x < x_min:
                x_min = x
            if first or x > x_max:
                x_max = x
            first = False

        return x_min, x_max

    # Binary search for `x` in the sorted buffer `arr` of length `n`.
    def binsearch(arr, n: int, x: dtype, dtype: type):
        low, high, mid = 0, n - 1, 0
        while low <= high:
            mid = (high + low) >> 1
            if cast(arr[mid], dtype) < x:
                low = mid + 1
            elif cast(arr[mid], dtype) > x:
                high = mid - 1
            else:
                return True
        return False

    ar1 = atleast_1d(asarray(element))
    ar2 = atleast_1d(asarray(test_elements))
    dtype = type(coerce(ar1.dtype, ar2.dtype))
    use_table_method = False
    kind_is_table = False

    if kind is None:
        use_table_method = table_method_ok(dtype)
    elif kind == "table":
        use_table_method = table_method_ok(dtype)
        kind_is_table = True
    elif kind == "sort":
        use_table_method = False
    else:
        raise ValueError(
            f"Invalid kind: '{kind}'. Please use None, 'sort' or 'table'.")

    if ar1.size == 0:
        return empty(ar1.shape, bool)

    if ar2.size == 0:
        if invert:
            return ones(ar1.shape, bool)
        else:
            return zeros(ar1.shape, bool)

    if use_table_method:
        ar2_min, ar2_max = array_min_max(ar2, dtype)
        # check for u64 to make sure we don't overflow
        if dtype is u64:
            ar2_range = u64(ar2_max) - u64(ar2_min) + u64(1)
        else:
            ar2_range = int(ar2_max) - int(ar2_min) + 1
        R = type(ar2_range)

        # ar2 is non-empty here, so the true range is at least 1. A value
        # that is not positive can only be the result of wrap-around, in
        # which case the table cannot be sized safely.
        range_safe_from_overflow = ar2_range > R(0)
        below_memory_constraint = ar2_range <= R(6 * (ar1.size + ar2.size))

        if range_safe_from_overflow and (below_memory_constraint
                                         or kind_is_table):
            # One bit per value in [ar2_min, ar2_max], rounded up to words.
            bitset_size = int(ar2_range >> R(6))
            if ar2_range & R(63):
                bitset_size += 1
            bitset = Ptr[u64](bitset_size)

            for i in range(bitset_size):
                bitset[i] = u64(0)

            for idx in multirange(ar2.shape):
                x2 = cast(ar2._ptr(idx)[0], dtype)
                bitset_set(x2, ar2_min, ar2_max, bitset)

            ans = empty(ar1.shape, bool)

            for idx in multirange(ar1.shape):
                x1 = cast(ar1._ptr(idx)[0], dtype)
                p = ans._ptr(idx)
                found = bitset_get(x1, ar2_min, ar2_max, bitset)
                p[0] = (not found) if invert else found

            free(bitset)
            return ans
        elif kind_is_table:
            raise RuntimeError(
                "You have specified kind='table', "
                "but the range of values in `ar2` or `ar1` exceed the "
                "maximum integer of the datatype. "
                "Please set `kind` to None or 'sort'.")
    elif kind_is_table:
        raise ValueError("The 'table' method is only "
                         "supported for boolean or integer arrays. "
                         "Please select 'sort' or None for kind.")

    # Few test elements relative to ar1: a direct scan per element.
    if ar2.size < 10 * ar1.size**0.145:
        mask = empty(ar1.shape, bool)

        for idx1 in multirange(ar1.shape):
            x1 = cast(ar1._ptr(idx1)[0], dtype)
            found = False
            for idx2 in multirange(ar2.shape):
                x2 = cast(ar2._ptr(idx2)[0], dtype)
                if x1 == x2:
                    found = True
                    break
            mask._ptr(idx1)[0] = (not found) if invert else found

        return mask

    # General case: sort the test elements once, then binary-search them.
    s = sort(ar2)
    ans = empty(ar1.shape, bool)

    for idx1 in multirange(ar1.shape):
        x1 = cast(ar1._ptr(idx1)[0], dtype)
        found = binsearch(s.data, s.size, x1)
        ans._ptr(idx1)[0] = (not found) if invert else found

    return ans
def in1d(ar1,
         ar2,
         assume_unique: bool = False,
         invert: bool = False,
         kind: Optional[str] = None):
    """
    Flattened variant of ``isin``.

    All arguments are forwarded to ``isin`` unchanged and the resulting
    boolean array is returned after ``ravel()``.
    """
    membership = isin(ar1,
                      ar2,
                      assume_unique=assume_unique,
                      invert=invert,
                      kind=kind)
    return membership.ravel()