# Copyright (C) 2022-2025 Exaloop Inc.

from ..ndarray import ndarray
from ..routines import empty, zeros, ones, asarray, broadcast_to, moveaxis, ascontiguousarray, rot90, atleast_1d
from ..sorting import sort, lexsort
from ..ndmath import isnan
from ..util import multirange, free, count, normalize_axis_index, coerce, cast


def _unique1d(ar,
              return_index: Static[int] = False,
              return_inverse: Static[int] = False,
              return_counts: Static[int] = False,
              equal_nan: bool = True):
    """
    Find the unique elements of ``ar`` viewed as a flat array.

    The elements are sorted (directly, or indirectly through an argsort
    permutation when indices are needed) and runs of equal neighbours are
    collapsed in two passes: one to count the runs, one to fill the outputs.

    Parameters:
        ar: array whose elements are examined
        return_index: also return, for each unique value, the flat index
            in ``ar`` stored for the first element of its sorted run
        return_inverse: also return, for every flat position of ``ar``,
            the index of its value in the unique array
        return_counts: also return the length of each run
        equal_nan: when true, two NaN values compare as equal

    Returns:
        The unique array alone when no optional output is requested,
        otherwise a tuple ``(unique[, index][, inverse][, counts])``.
    """
    n = ar.size

    if return_index or return_inverse:
        # A stable sort is requested when indices are returned.
        sort_kind = 'mergesort' if return_index else 'quicksort'
        ar_flat = ar.ravel()
        perm = ar_flat.argsort(kind=sort_kind)
    else:
        # No permutation needed: sort a flattened copy in place.
        ar_flat = ar.flatten()
        ar_flat.sort()
        perm = None

    def perm_get(perm, i: int):
        # Position in the unsorted data of the i-th sorted element
        # (0 when there is no permutation).
        if perm is None:
            return 0
        else:
            return perm._ptr((i, ))[0]

    def ar_get(ar, i: int, perm):
        # i-th element in sorted order, going through `perm` if present.
        if perm is None:
            return ar._ptr((i, ))[0]
        else:
            return ar._ptr((perm_get(perm, i), ))[0]

    def is_equal(a, b, equal_nan: bool):
        if equal_nan:
            return (a == b) or (isnan(a) and isnan(b))
        else:
            return a == b

    # Pass 1: count the runs of equal elements.
    num_unique = 0
    i = 0
    j = 0
    while i < n:
        curr = ar_get(ar_flat, i, perm)
        j = i + 1
        while j < n and is_equal(curr, ar_get(ar_flat, j, perm), equal_nan):
            j += 1
        i = j
        num_unique += 1

    ret_unique = empty((num_unique, ), ar.dtype)

    if return_index:
        ret_index = empty((num_unique, ), int)
    else:
        ret_index = None

    if return_inverse:
        ret_inverse = empty((n, ), int)
    else:
        ret_inverse = None

    if return_counts:
        ret_counts = empty((num_unique, ), int)
    else:
        ret_counts = None

    # Pass 2: walk the runs again and fill the outputs; k is the run index.
    i = 0
    j = 0
    k = 0
    while i < n:
        curr = ar_get(ar_flat, i, perm)
        if return_inverse:
            ret_inverse.data[perm_get(perm, i)] = k

        j = i + 1
        while j < n and is_equal(curr, ar_get(ar_flat, j, perm), equal_nan):
            if return_inverse:
                ret_inverse.data[perm_get(perm, j)] = k
            j += 1

        ret_unique.data[k] = curr
        if return_index:
            ret_index.data[k] = perm_get(perm, i)
        if return_counts:
            ret_counts.data[k] = j - i

        i = j
        k += 1

    if not (return_index or return_inverse or return_counts):
        return ret_unique

    # The tuple type changes with each optional output, hence the
    # separate names for every stage.
    ans0 = (ret_unique, )

    if return_index:
        ans1 = ans0 + (ret_index, )
    else:
        ans1 = ans0

    if return_inverse:
        ans2 = ans1 + (ret_inverse, )
    else:
        ans2 = ans1

    if return_counts:
        ans3 = ans2 + (ret_counts, )
    else:
        ans3 = ans2

    return ans3


def unique(ar,
           return_index: Static[int] = False,
           return_inverse: Static[int] = False,
           return_counts: Static[int] = False,
           axis=None,
           equal_nan: bool = True):
    """
    Find the unique elements of an array, or its unique sub-arrays
    along ``axis``.

    Parameters:
        ar: array-like input; converted with ``asarray`` and promoted to
            at least one dimension
        return_index: also return an index into the input (along ``axis``)
            for each unique entry
        return_inverse: also return, for each input entry, the index of
            its value in the unique output
        return_counts: also return how many times each unique entry occurs
        axis: ``None`` to operate on the flattened array, or an int axis;
            any other type is a compile error
        equal_nan: whether NaNs compare equal; only honoured when
            ``axis`` is ``None``

    Returns:
        The unique array alone when no optional output is requested,
        otherwise a tuple ``(unique[, index][, inverse][, counts])``.
    """
    ar = atleast_1d(asarray(ar))

    if axis is None:
        return _unique1d(ar,
                         return_index=return_index,
                         return_inverse=return_inverse,
                         return_counts=return_counts,
                         equal_nan=equal_nan)

    if not isinstance(axis, int):
        compile_error("'axis' must be an int or None")

    axis = normalize_axis_index(axis, ar.ndim)

    # Bring `axis` to the front and flatten the rest so that every
    # sub-array becomes one contiguous row of `ncols` elements.
    ar = moveaxis(ar, axis, 0)
    orig_shape = ar.shape
    nrows = orig_shape[0]
    ncols = count(orig_shape[1:])
    ar = ar.reshape(nrows, ncols)
    ar = ascontiguousarray(ar)
    perm = lexsort(rot90(ar))

    def perm_get(perm, i: int):
        return perm._ptr((i, ))[0]

    def ar_get(ar, i: int, ncols: int, perm):
        # Pointer to the start of the i-th row in sorted order.
        i = perm_get(perm, i)
        return ar.data + (i * ncols)

    def is_equal(a, b, ncols: int, equal_nan: bool):
        # NumPy ignores equal_nan in this case, so we do too
        for i in range(ncols):
            if not (a[i] == b[i]):
                return False
        return True

    # Pass 1: count the runs of equal rows.
    num_unique = 0
    i = 0
    j = 0
    n = nrows
    while i < n:
        curr = ar_get(ar, i, ncols, perm)
        j = i + 1
        while j < n and is_equal(curr, ar_get(ar, j, ncols, perm), ncols, equal_nan):
            j += 1
        i = j
        num_unique += 1

    ret_unique = empty((num_unique, ncols), ar.dtype)

    if return_index:
        ret_index = empty((num_unique, ), int)
    else:
        ret_index = None

    if return_inverse:
        ret_inverse = empty((n, ), int)
    else:
        ret_inverse = None

    if return_counts:
        ret_counts = empty((num_unique, ), int)
    else:
        ret_counts = None

    # Pass 2: walk the runs again and fill the outputs; k is the run index.
    i = 0
    j = 0
    k = 0
    while i < n:
        curr = ar_get(ar, i, ncols, perm)
        if return_inverse:
            ret_inverse.data[perm_get(perm, i)] = k

        j = i + 1
        while j < n and is_equal(curr, ar_get(ar, j, ncols, perm), ncols, equal_nan):
            if return_inverse:
                ret_inverse.data[perm_get(perm, j)] = k
            j += 1

        # Copy the representative row. A row holds `ncols` elements, so
        # the copy must run over `ncols` (not `nrows`) to stay inside the
        # k-th output row.
        p = ret_unique.data + (k * ncols)
        for z in range(ncols):
            p[z] = curr[z]

        if return_index:
            ret_index.data[k] = perm_get(perm, i)
        if return_counts:
            ret_counts.data[k] = j - i

        i = j
        k += 1

    # Restore the trailing dimensions and move the axis back into place.
    ret_unique = ret_unique.reshape(num_unique, *orig_shape[1:])
    ret_unique = moveaxis(ret_unique, 0, axis)

    if not (return_index or return_inverse or return_counts):
        return ret_unique

    # The tuple type changes with each optional output, hence the
    # separate names for every stage.
    ans0 = (ret_unique, )

    if return_index:
        ans1 = ans0 + (ret_index, )
    else:
        ans1 = ans0

    if return_inverse:
        ans2 = ans1 + (ret_inverse, )
    else:
        ans2 = ans1

    if return_counts:
        ans3 = ans2 + (ret_counts, )
    else:
        ans3 = ans2

    return ans3


def _sorted(ar):
    """
    Return a sorted, flat array holding the elements of ``ar``.

    An ``ndarray`` input is flattened with ``flatten()`` before the
    in-place sort; any other input is converted with ``asarray`` and
    ``ravel()`` first.
    """
    if isinstance(ar, ndarray):
        x = ar.flatten()
        x.sort()
        return x
    else:
        x = asarray(ar).ravel()
        x.sort()
        return x


def _next_distinct(p, i: int, n: int, assume_unique: bool = False):
    """
    Return the index following the run of equal values that starts at
    ``p[i]`` in the sorted buffer ``p`` of length ``n``.

    With ``assume_unique`` the run is taken to have length one and
    ``i + 1`` is returned without inspecting the data.
    """
    if assume_unique:
        return i + 1

    while True:
        i += 1
        if not (i < n and p[i - 1] == p[i]):
            break

    return i


def union1d(ar1, ar2):
    """
    Return the sorted union of the elements of ``ar1`` and ``ar2``,
    without duplicates, in the coerced dtype of both inputs.
    """
    v1 = _sorted(ar1)
    v2 = _sorted(ar2)
    dtype = type(coerce(v1.dtype, v2.dtype))

    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data

    # Pass 1: merge the two sorted sequences just to count the output.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)
        count += 1

        if e1 < e2:
            i1 = _next_distinct(p1, i1, n1)
        elif e2 < e1:
            i2 = _next_distinct(p2, i2, n2)
        else:
            i1 = _next_distinct(p1, i1, n1)
            i2 = _next_distinct(p2, i2, n2)

    while i1 < n1:
        count += 1
        i1 = _next_distinct(p1, i1, n1)

    while i2 < n2:
        count += 1
        i2 = _next_distinct(p2, i2, n2)

    # Pass 2: repeat the merge, this time writing the values.
    ans = empty(count, dtype)
    q = ans.data
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            q[j] = e1
            i1 = _next_distinct(p1, i1, n1)
        elif e2 < e1:
            q[j] = e2
            i2 = _next_distinct(p2, i2, n2)
        else:
            q[j] = e1
            i1 = _next_distinct(p1, i1, n1)
            i2 = _next_distinct(p2, i2, n2)

        j += 1

    while i1 < n1:
        e1 = cast(p1[i1], dtype)
        q[j] = e1
        j += 1
        i1 = _next_distinct(p1, i1, n1)

    while i2 < n2:
        e2 = cast(p2[i2], dtype)
        q[j] = e2
        j += 1
        i2 = _next_distinct(p2, i2, n2)

    return ans


def setxor1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values that occur in exactly one of ``ar1`` and
    ``ar2``, in the coerced dtype of both inputs.

    ``assume_unique`` makes the merge advance one element at a time
    instead of skipping runs of equal values.
    """
    v1 = _sorted(ar1)
    v2 = _sorted(ar2)
    dtype = type(coerce(v1.dtype, v2.dtype))

    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data

    # Pass 1: count the values present on only one side.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            count += 1
            i1 = _next_distinct(p1, i1, n1, assume_unique)
        elif e2 < e1:
            count += 1
            i2 = _next_distinct(p2, i2, n2, assume_unique)
        else:
            i1 = _next_distinct(p1, i1, n1, assume_unique)
            i2 = _next_distinct(p2, i2, n2, assume_unique)

    while i1 < n1:
        count += 1
        i1 = _next_distinct(p1, i1, n1, assume_unique)

    while i2 < n2:
        count += 1
        i2 = _next_distinct(p2, i2, n2, assume_unique)

    # Pass 2: repeat the merge, this time writing the values.
    ans = empty(count, dtype)
    q = ans.data
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            q[j] = e1
            j += 1
            i1 = _next_distinct(p1, i1, n1, assume_unique)
        elif e2 < e1:
            q[j] = e2
            j += 1
            i2 = _next_distinct(p2, i2, n2, assume_unique)
        else:
            i1 = _next_distinct(p1, i1, n1, assume_unique)
            i2 = _next_distinct(p2, i2, n2, assume_unique)

    while i1 < n1:
        e1 = cast(p1[i1], dtype)
        q[j] = e1
        j += 1
        i1 = _next_distinct(p1, i1, n1, assume_unique)

    while i2 < n2:
        e2 = cast(p2[i2], dtype)
        q[j] = e2
        j += 1
        i2 = _next_distinct(p2, i2, n2, assume_unique)

    return ans


def setdiff1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values of ``ar1`` that do not occur in ``ar2``,
    in the coerced dtype of both inputs.

    ``assume_unique`` makes the merge advance one element at a time
    instead of skipping runs of equal values.
    """
    v1 = _sorted(ar1)
    v2 = _sorted(ar2)
    dtype = type(coerce(v1.dtype, v2.dtype))

    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data

    # Pass 1: count the values of ar1 with no match in ar2.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            count += 1
            i1 = _next_distinct(p1, i1, n1, assume_unique)
        elif e2 < e1:
            i2 = _next_distinct(p2, i2, n2, assume_unique)
        else:
            i1 = _next_distinct(p1, i1, n1, assume_unique)
            i2 = _next_distinct(p2, i2, n2, assume_unique)

    # Whatever is left of ar1 lies beyond every value of ar2.
    while i1 < n1:
        count += 1
        i1 = _next_distinct(p1, i1, n1, assume_unique)

    # Pass 2: repeat the merge, this time writing the values.
    ans = empty(count, dtype)
    q = ans.data
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            q[j] = e1
            j += 1
            i1 = _next_distinct(p1, i1, n1, assume_unique)
        elif e2 < e1:
            i2 = _next_distinct(p2, i2, n2, assume_unique)
        else:
            i1 = _next_distinct(p1, i1, n1, assume_unique)
            i2 = _next_distinct(p2, i2, n2, assume_unique)

    while i1 < n1:
        e1 = cast(p1[i1], dtype)
        q[j] = e1
        j += 1
        i1 = _next_distinct(p1, i1, n1, assume_unique)

    return ans


def _intersect1d(ar1, ar2, assume_unique: bool = False):
    """
    Return the sorted values common to ``ar1`` and ``ar2``, in the
    coerced dtype of both inputs.

    Unless ``assume_unique`` is set, repeated values of ``ar1`` are
    skipped so that each common value is reported once.
    """
    v1 = _sorted(ar1)
    v2 = _sorted(ar2)
    dtype = type(coerce(v1.dtype, v2.dtype))

    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data

    # Pass 1: count the matches.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        if not assume_unique and i1 > 0 and p1[i1] == p1[i1 - 1]:
            i1 += 1
            continue

        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            count += 1
            i1 += 1
            i2 += 1

    # Pass 2: write the matches. The duplicate-skipping condition must be
    # the same as in pass 1, otherwise the two passes can disagree on the
    # number of matches and leave part of `ans` uninitialized.
    ans = empty(count, dtype)
    q = ans.data
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        if not assume_unique and i1 > 0 and p1[i1] == p1[i1 - 1]:
            i1 += 1
            continue

        e1 = cast(p1[i1], dtype)
        e2 = cast(p2[i2], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            q[j] = e1
            j += 1
            i1 += 1
            i2 += 1

    return ans


def _intersect1d_indices(ar1, ar2, assume_unique: bool = False):
    """
    Return ``(common, idx1, idx2)``: the sorted values common to ``ar1``
    and ``ar2`` plus, for each of them, a flat index into ``ar1`` and
    into ``ar2`` where that value is found.

    The inputs are not sorted themselves; they are traversed in sorted
    order through argsort permutations. Unless ``assume_unique`` is set,
    repeated values of ``ar1`` are skipped.
    """
    v1 = asarray(ar1).ravel()
    v2 = asarray(ar2).ravel()
    perm1 = v1.argsort().data  # argsort return array should always be contiguous
    perm2 = v2.argsort().data
    dtype = type(coerce(v1.dtype, v2.dtype))

    n1 = v1.size
    n2 = v2.size
    p1 = v1.data
    p2 = v2.data

    # Pass 1: count the matches.
    i1 = 0
    i2 = 0
    count = 0

    while i1 < n1 and i2 < n2:
        if not assume_unique and i1 > 0 and p1[perm1[i1]] == p1[perm1[i1 - 1]]:
            i1 += 1
            continue

        e1 = cast(p1[perm1[i1]], dtype)
        e2 = cast(p2[perm2[i2]], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            count += 1
            i1 += 1
            i2 += 1

    # Pass 2: write the matches. The duplicate-skipping condition must be
    # the same as in pass 1, otherwise the two passes can disagree on the
    # number of matches and leave part of the outputs uninitialized.
    ans = empty(count, dtype)
    idx1 = empty(count, int)
    idx2 = empty(count, int)
    q = ans.data
    j = 0
    i1 = 0
    i2 = 0

    while i1 < n1 and i2 < n2:
        if not assume_unique and i1 > 0 and p1[perm1[i1]] == p1[perm1[i1 - 1]]:
            i1 += 1
            continue

        e1 = cast(p1[perm1[i1]], dtype)
        e2 = cast(p2[perm2[i2]], dtype)

        if e1 < e2:
            i1 += 1
        elif e2 < e1:
            i2 += 1
        else:
            q[j] = e1
            idx1.data[j] = perm1[i1]
            idx2.data[j] = perm2[i2]
            j += 1
            i1 += 1
            i2 += 1

    return ans, idx1, idx2


def intersect1d(ar1,
                ar2,
                assume_unique: bool = False,
                return_indices: Static[int] = False):
    """
    Return the sorted values common to ``ar1`` and ``ar2``.

    With ``return_indices`` the result is a tuple ``(common, idx1, idx2)``
    that also carries matching flat indices into each input; otherwise
    only the array of common values is returned.
    """
    if return_indices:
        return _intersect1d_indices(ar1, ar2, assume_unique=assume_unique)
    else:
        return _intersect1d(ar1, ar2, assume_unique=assume_unique)


def isin(element,
         test_elements,
         assume_unique: bool = False,
         invert: bool = False,
         kind: Optional[str] = None):
    """
    Test each entry of ``element`` for membership in ``test_elements``.

    Parameters:
        element: array-like of values to test
        test_elements: array-like of values to test against
        assume_unique: accepted for API compatibility; not used by this
            implementation
        invert: when true, the result is negated (true where the entry is
            *not* found)
        kind: ``None`` (choose automatically), ``'sort'`` or ``'table'``

    Returns:
        A bool array with the shape of ``element`` (at least 1-D).

    Raises:
        ValueError: if ``kind`` is not one of the accepted values, or if
            ``kind='table'`` is requested for a dtype the table method
            does not support
    """

    def table_method_ok(dtype: type):
        # The bitset lookup only applies to bool and integer types of at
        # most 64 bits.
        if dtype is int or dtype is byte or dtype is bool:
            return True

        if isinstance(dtype, Int) or isinstance(dtype, UInt):
            return dtype.N <= 64

        return False

    def bitset_get(x, x_min, x_max, bitset: Ptr[u64]):
        # Values outside [x_min, x_max] have no bit in the table.
        if x < x_min or x > x_max:
            return False

        pos = int(x) - int(x_min)
        offset = pos >> 6
        return bool(bitset[offset] & (u64(1) << (u64(pos) & u64(63))))

    def bitset_set(x, x_min, x_max, bitset: Ptr[u64]):
        pos = int(x) - int(x_min)
        offset = pos >> 6
        bitset[offset] |= (u64(1) << (u64(pos) & u64(63)))

    def array_min_max(ar, dtype: type):
        # Single scan for both extremes; `first` seeds them from the
        # first element visited.
        x_min = dtype()
        x_max = dtype()
        first = True

        for idx in multirange(ar.shape):
            x = cast(ar._ptr(idx)[0], dtype)

            if first or x < x_min:
                x_min = x

            if first or x > x_max:
                x_max = x

            first = False

        return x_min, x_max

    def binsearch(arr, n: int, x: dtype, dtype: type):
        # Binary search for x in the sorted buffer arr[0:n].
        low, high, mid = 0, n - 1, 0

        while low <= high:
            mid = (high + low) >> 1
            if cast(arr[mid], dtype) < x:
                low = mid + 1
            elif cast(arr[mid], dtype) > x:
                high = mid - 1
            else:
                return True

        return False

    ar1 = atleast_1d(asarray(element))
    ar2 = atleast_1d(asarray(test_elements))
    dtype = type(coerce(ar1.dtype, ar2.dtype))

    use_table_method = False
    kind_is_table = False

    if kind is None:
        use_table_method = table_method_ok(dtype)
    elif kind == "table":
        use_table_method = table_method_ok(dtype)
        kind_is_table = True
    elif kind == "sort":
        use_table_method = False
    else:
        raise ValueError(
            f"Invalid kind: '{kind}'. Please use None, 'sort' or 'table'.")

    if ar1.size == 0:
        return empty(ar1.shape, bool)

    if ar2.size == 0:
        # Nothing can be a member of an empty set.
        if invert:
            return ones(ar1.shape, bool)
        else:
            return zeros(ar1.shape, bool)

    if use_table_method:
        ar2_min, ar2_max = array_min_max(ar2, dtype)

        # check for u64 to make sure we don't overflow
        if dtype is u64:
            ar2_range = u64(ar2_max) - u64(ar2_min) + u64(1)
        else:
            ar2_range = int(ar2_max) - int(ar2_min) + 1

        R = type(ar2_range)
        below_memory_constraint = ar2_range <= R(6 * (ar1.size + ar2.size))

        if below_memory_constraint or kind_is_table:
            # One bit per value in [ar2_min, ar2_max], packed in u64 words.
            bitset_size = int(ar2_range >> R(6))
            if ar2_range & R(63):
                bitset_size += 1

            bitset = Ptr[u64](bitset_size)
            for i in range(bitset_size):
                bitset[i] = u64(0)

            for idx in multirange(ar2.shape):
                x2 = cast(ar2._ptr(idx)[0], dtype)
                bitset_set(x2, ar2_min, ar2_max, bitset)

            ans = empty(ar1.shape, bool)
            for idx in multirange(ar1.shape):
                x1 = cast(ar1._ptr(idx)[0], dtype)
                p = ans._ptr(idx)
                found = bitset_get(x1, ar2_min, ar2_max, bitset)
                p[0] = (not found) if invert else found

            free(bitset)
            return ans
        elif kind_is_table:
            # NOTE(review): unreachable as written, since `kind_is_table`
            # already satisfies the condition of the `if` above.
            raise RuntimeError(
                "You have specified kind='table', "
                "but the range of values in `ar2` or `ar1` exceed the "
                "maximum integer of the datatype. "
                "Please set `kind` to None or 'sort'.")
    elif kind_is_table:
        raise ValueError("The 'table' method is only "
                         "supported for boolean or integer arrays. "
                         "Please select 'sort' or None for kind.")

    # For a small ar2 a direct scan is used instead of sorting.
    if ar2.size < 10 * ar1.size**0.145:
        mask = empty(ar1.shape, bool)

        for idx1 in multirange(ar1.shape):
            x1 = cast(ar1._ptr(idx1)[0], dtype)
            found = False

            for idx2 in multirange(ar2.shape):
                x2 = cast(ar2._ptr(idx2)[0], dtype)
                if x1 == x2:
                    found = True
                    break

            mask._ptr(idx1)[0] = (not found) if invert else found

        return mask

    # General case: sort ar2 once, then binary-search every element.
    s = sort(ar2)
    ans = empty(ar1.shape, bool)

    for idx1 in multirange(ar1.shape):
        x1 = cast(ar1._ptr(idx1)[0], dtype)
        found = binsearch(s.data, s.size, x1)
        ans._ptr(idx1)[0] = (not found) if invert else found

    return ans


def in1d(ar1,
         ar2,
         assume_unique: bool = False,
         invert: bool = False,
         kind: Optional[str] = None):
    """
    Flattened variant of ``isin``: test each entry of ``ar1`` for
    membership in ``ar2`` and return the result as a 1-D bool array.
    """
    return isin(ar1,
                ar2,
                assume_unique=assume_unique,
                invert=invert,
                kind=kind).ravel()