# Copyright (C) 2022-2025 Exaloop Inc.

from .ndarray import ndarray
from .routines import array, asarray, empty, zeros
import util

##########
# Common #
##########

# Strict "a sorts before b" predicate shared by every sort in this file.
# For float/float32 a NaN never sorts before a non-NaN value, and a non-NaN
# value always sorts before a NaN. For complex/complex64 the real parts are
# compared first, then the imaginary parts, with the same NaN handling applied
# to each component. Every other type uses plain `<`.
def _less(a: T, b: T, T: type):
    if T is float or T is float32:
        if a < b:
            return True
        # b is NaN while a is not
        return b != b and a == a
    elif T is complex or T is complex64:
        a_re = a.real
        b_re = b.real
        a_im = a.imag
        b_im = b.imag
        a_im_ok = a_im == a_im    # a.imag is not NaN
        b_im_nan = b_im != b_im   # b.imag is NaN
        if a_re < b_re:
            return a_im_ok or b_im_nan
        elif a_re > b_re:
            return b_im_nan and a_im_ok
        elif a_re == b_re or (a_re != a_re and b_re != b_re):
            # real parts equal, or both NaN: decide on the imaginary parts
            return a_im < b_im or (b_im_nan and a_im_ok)
        else:
            # exactly one real part is NaN
            return b_re != b_re
    else:
        return a < b

#############
# Insertion #
#############

# In-place insertion sort of the `n` elements starting at `start`,
# ordered by `_less`.
def insertionsort(start: Ptr[T], n: int, T: type):
    for pos in range(1, n):
        key = start[pos]
        hole = pos
        # shift larger elements one slot right until key's position is found
        while hole > 0 and _less(key, start[hole - 1]):
            start[hole] = start[hole - 1]
            hole -= 1
        start[hole] = key

# Indirect insertion sort: reorders the `n` indices in `tosort` so that the
# values `v[tosort[i]]` are ordered by `_less`; `v` itself is not written.
def ainsertionsort(v: Ptr[T], tosort: Ptr[int], n: int, T: type):
    for pos in range(1, n):
        key_idx = tosort[pos]
        hole = pos
        while hole > 0 and _less(v[key_idx], v[tosort[hole - 1]]):
            tosort[hole] = tosort[hole - 1]
            hole -= 1
        tosort[hole] = key_idx

########
# Heap #
########

# In-place heap sort of the `n` elements starting at `start`.
def heapsort(start: Ptr[T], n: int, T: type):
    # 1-based view of the data so that the children of node i are 2i and 2i+1
    heap = start - 1

    # Phase 1: build a max-heap by sifting down every internal node.
    for root in range(n >> 1, 0, -1):
        held = heap[root]
        parent = root
        child = root << 1
        while child <= n:
            # pick the larger of the two children
            if child < n and _less(heap[child], heap[child + 1]):
                child += 1
            if not _less(held, heap[child]):
                break
            heap[parent] = heap[child]
            parent = child
            child += child
        heap[parent] = held

    # Phase 2: repeatedly move the maximum to the end and restore the heap.
    size = n
    while size > 1:
        held = heap[size]
        heap[size] = heap[1]
        size -= 1
        parent = 1
        child = 2
        while child <= size:
            if child < size and _less(heap[child], heap[child + 1]):
                child += 1
            if not _less(held, heap[child]):
                break
            heap[parent] = heap[child]
            parent = child
            child += child
        heap[parent] = held

# Indirect heap sort: reorders the `n` indices in `tosort` so that the values
# `vv[tosort[i]]` are ordered by `_less`; `vv` itself is not written.
def aheapsort(vv: Ptr[T], tosort: Ptr[int], n: int, T: type):
    # 1-based view of the index array
    heap = tosort - 1

    # Phase 1: build a max-heap of indices keyed by the referenced values.
    for root in range(n >> 1, 0, -1):
        held = heap[root]
        parent = root
        child = root << 1
        while child <= n:
            if child < n and _less(vv[heap[child]], vv[heap[child + 1]]):
                child += 1
            if not _less(vv[held], vv[heap[child]]):
                break
            heap[parent] = heap[child]
            parent = child
            child += child
        heap[parent] = held

    # Phase 2: repeatedly move the maximum to the end and restore the heap.
    size = n
    while size > 1:
        held = heap[size]
        heap[size] = heap[1]
        size -= 1
        parent = 1
        child = 2
        while child <= size:
            if child < size and _less(vv[heap[child]], vv[heap[child + 1]]):
                child += 1
            if not _less(vv[held], vv[heap[child]]):
                break
            heap[parent] = heap[child]
            parent = child
            child += child
        heap[parent] = held

#########
# Merge #
######### PYA_QS_STACK : Static[int] = 100 SMALL_QUICKSORT: Static[int] = 15 SMALL_MERGESORT: Static[int] = 20 SMALL_STRING : Static[int] = 16 def _mergesort0(pl: Ptr[T], pr: Ptr[T], pw: Ptr[T], T: type): if pr - pl > SMALL_MERGESORT: pm = pl + ((pr - pl) >> 1) _mergesort0(pl, pm, pw) _mergesort0(pm, pr, pw) pi = pw pj = pl while pj < pm: pi[0] = pj[0] pi += 1 pj += 1 pi = pw + (pm - pl) pj = pw pk = pl while pj < pi and pm < pr: if _less(pm[0], pj[0]): pk[0] = pm[0] pk += 1 pm += 1 else: pk[0] = pj[0] pk += 1 pj += 1 while pj < pi: pk[0] = pj[0] pk += 1 pj += 1 else: pi = pl + 1 while pi < pr: vp = pi[0] pj = pi pk = pi - 1 while pj > pl and _less(vp, pk[0]): pj[0] = pk[0] pj -= 1 pk -= 1 pj[0] = vp pi += 1 def mergesort(start: Ptr[T], num: int, T: type): pl = start pr = pl + num pw = Ptr[T](num // 2) _mergesort0(pl, pr, pw) util.free(pw) def _amergesort0(pl: Ptr[int], pr: Ptr[int], v: Ptr[T], pw: Ptr[int], T: type): if pr - pl > SMALL_MERGESORT: pm = pl + ((pr - pl) >> 1) _amergesort0(pl, pm, v, pw) _amergesort0(pm, pr, v, pw) pi = pw pj = pl while pj < pm: pi[0] = pj[0] pi += 1 pj += 1 pi = pw + (pm - pl) pj = pw pk = pl while pj < pi and pm < pr: if _less(v[pm[0]], v[pj[0]]): pk[0] = pm[0] pk += 1 pm += 1 else: pk[0] = pj[0] pk += 1 pj += 1 while pj < pi: pk[0] = pj[0] pk += 1 pj += 1 else: pi = pl + 1 while pi < pr: vi = pi[0] vp = v[vi] pj = pi pk = pi - 1 while pj > pl and _less(vp, v[pk[0]]): pj[0] = pk[0] pj -= 1 pk -= 1 pj[0] = vi pi += 1 def amergesort(v: Ptr[T], tosort: Ptr[int], num: int, T: type): pl = tosort pr = pl + num pw = Ptr[int](num // 2) _amergesort0(pl, pr, v, pw) util.free(pw) ############### # Quick (PDQ) # ############### INSERTION_SORT_THRESHOLD : Static[int] = 24 NINTHER_THRESHOLD : Static[int] = 128 PARTIAL_INSERTION_SORT_LIMIT: Static[int] = 8 def _floor_log2(n: int): log = 0 while True: n >>= 1 if n == 0: break log += 1 return log def _partial_insertion_sort(arr: Ptr[T], begin: int, end: int, T: type): if begin == end: return True 
limit = 0 cur = begin + 1 while cur != end: if limit > PARTIAL_INSERTION_SORT_LIMIT: return False sift = cur sift_1 = cur - 1 if _less(arr[sift], arr[sift_1]): tmp = arr[sift] while True: arr[sift] = arr[sift_1] sift -= 1 sift_1 -= 1 if sift == begin or not _less(tmp, arr[sift_1]): break arr[sift] = tmp limit += cur - sift cur += 1 return True def _partition_left(arr: Ptr[T], begin: int, end: int, T: type): pivot = arr[begin] first = begin last = end while True: last -= 1 if not _less(pivot, arr[last]): break if last + 1 == end: while first < last: first += 1 if _less(pivot, arr[first]): break else: while True: first += 1 if _less(pivot, arr[first]): break while first < last: arr[first], arr[last] = arr[last], arr[first] while True: last -= 1 if not _less(pivot, arr[last]): break while True: first += 1 if _less(pivot, arr[first]): break pivot_pos = last arr[begin] = arr[pivot_pos] arr[pivot_pos] = pivot return pivot_pos def _partition_right(arr: Ptr[T], begin: int, end: int, T: type): pivot = arr[begin] first = begin last = end while True: first += 1 if not _less(arr[first], pivot): break if first - 1 == begin: while first < last: last -= 1 if _less(arr[last], pivot): break else: while True: last -= 1 if _less(arr[last], pivot): break already_partitioned = 0 if first >= last: already_partitioned = 1 while first < last: arr[first], arr[last] = arr[last], arr[first] while True: first += 1 if not _less(arr[first], pivot): break while True: last -= 1 if _less(arr[last], pivot): break pivot_pos = first - 1 arr[begin] = arr[pivot_pos] arr[pivot_pos] = pivot return (pivot_pos, already_partitioned) def _sort2(arr: Ptr[T], i: int, j: int, T: type): if _less(arr[j], arr[i]): arr[i], arr[j] = arr[j], arr[i] def _sort3(arr: Ptr[T], i: int, j: int, k: int, T: type): _sort2(arr, i, j) _sort2(arr, j, k) _sort2(arr, i, j) def _pdq_sort( arr: Ptr[T], begin: int, end: int, bad_allowed: int, leftmost: bool, T: type ): while True: size = end - begin if size < INSERTION_SORT_THRESHOLD: 
insertionsort(arr + begin, size) return size_2 = size // 2 if size > NINTHER_THRESHOLD: _sort3(arr, begin, begin + size_2, end - 1) _sort3(arr, begin + 1, begin + (size_2 - 1), end - 2) _sort3(arr, begin + 2, begin + (size_2 + 1), end - 3) _sort3( arr, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1) ) arr[begin], arr[begin + size_2] = arr[begin + size_2], arr[begin] else: _sort3(arr, begin + size_2, begin, end - 1) if not leftmost and not _less(arr[begin - 1], arr[begin]): begin = _partition_left(arr, begin, end) + 1 continue part_result = _partition_right(arr, begin, end) pivot_pos = part_result[0] already_partitioned = part_result[1] == 1 l_size = pivot_pos - begin r_size = end - (pivot_pos + 1) highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8)) if highly_unbalanced: bad_allowed -= 1 if bad_allowed == 0: heapsort(arr + begin, end - begin) return if l_size >= INSERTION_SORT_THRESHOLD: arr[begin], arr[begin + l_size // 4] = ( arr[begin + l_size // 4], arr[begin], ) arr[pivot_pos - 1], arr[pivot_pos - l_size // 4] = ( arr[pivot_pos - l_size // 4], arr[pivot_pos - 1], ) if l_size > NINTHER_THRESHOLD: arr[begin + 1], arr[begin + (l_size // 4 + 1)] = ( arr[begin + (l_size // 4 + 1)], arr[begin + 1], ) arr[begin + 2], arr[begin + (l_size // 4 + 2)] = ( arr[begin + (l_size // 4 + 2)], arr[begin + 2], ) arr[pivot_pos - 2], arr[pivot_pos - (l_size // 4 + 1)] = ( arr[pivot_pos - (l_size // 4 + 1)], arr[pivot_pos - 2], ) arr[pivot_pos - 3], arr[pivot_pos - (l_size // 4 + 2)] = ( arr[pivot_pos - (l_size // 4 + 2)], arr[pivot_pos - 3], ) if r_size >= INSERTION_SORT_THRESHOLD: arr[pivot_pos + 1], arr[pivot_pos + (1 + r_size // 4)] = ( arr[pivot_pos + (1 + r_size // 4)], arr[pivot_pos + 1], ) arr[end - 1], arr[end - r_size // 4] = ( arr[end - r_size // 4], arr[end - 1], ) if r_size > NINTHER_THRESHOLD: arr[pivot_pos + 2], arr[pivot_pos + (2 + r_size // 4)] = ( arr[pivot_pos + (2 + r_size // 4)], arr[pivot_pos + 2], ) arr[pivot_pos + 3], arr[pivot_pos 
+ (3 + r_size // 4)] = ( arr[pivot_pos + (3 + r_size // 4)], arr[pivot_pos + 3], ) arr[end - 2], arr[end - (1 + r_size // 4)] = ( arr[end - (1 + r_size // 4)], arr[end - 2], ) arr[end - 3], arr[end - (2 + r_size // 4)] = ( arr[end - (2 + r_size // 4)], arr[end - 3], ) else: if ( already_partitioned and _partial_insertion_sort(arr, begin, pivot_pos) and _partial_insertion_sort(arr, pivot_pos + 1, end) ): return _pdq_sort(arr, begin, pivot_pos, bad_allowed, leftmost) begin = pivot_pos + 1 leftmost = False # C stubs for vectorized quicksort from Highway from C import cnp_sort_int16(cobj, int) from C import cnp_sort_uint16(cobj, int) from C import cnp_sort_int32(cobj, int) from C import cnp_sort_uint32(cobj, int) from C import cnp_sort_int64(cobj, int) from C import cnp_sort_uint64(cobj, int) from C import cnp_sort_int16(cobj, int) from C import cnp_sort_uint128(cobj, int) from C import cnp_sort_float32(cobj, int) from C import cnp_sort_float64(cobj, int) def quicksort(start: Ptr[T], n: int, T: type): if T is int: cnp_sort_int64(start.as_byte(), n) elif T is i16: cnp_sort_int16(start.as_byte(), n) elif T is u16: cnp_sort_uint16(start.as_byte(), n) elif T is i32: cnp_sort_int32(start.as_byte(), n) elif T is u32: cnp_sort_uint32(start.as_byte(), n) elif T is i64: cnp_sort_int64(start.as_byte(), n) elif T is u64: cnp_sort_uint64(start.as_byte(), n) elif T is float32: cnp_sort_float32(start.as_byte(), n) elif T is float: cnp_sort_float64(start.as_byte(), n) else: _pdq_sort(start, 0, n, _floor_log2(n), True) def _apartial_insertion_sort(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type): if begin == end: return True limit = 0 cur = begin + 1 while cur != end: if limit > PARTIAL_INSERTION_SORT_LIMIT: return False sift = cur sift_1 = cur - 1 if _less(arr[tosort[sift]], arr[tosort[sift_1]]): itmp = tosort[sift] tmp = arr[itmp] while True: tosort[sift] = tosort[sift_1] sift -= 1 sift_1 -= 1 if sift == begin or not _less(tmp, arr[tosort[sift_1]]): break tosort[sift] = 
itmp limit += cur - sift cur += 1 return True def _apartition_left(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type): ipivot = tosort[begin] pivot = arr[ipivot] first = begin last = end while True: last -= 1 if not _less(pivot, arr[tosort[last]]): break if last + 1 == end: while first < last: first += 1 if _less(pivot, arr[tosort[first]]): break else: while True: first += 1 if _less(pivot, arr[tosort[first]]): break while first < last: tosort[first], tosort[last] = tosort[last], tosort[first] while True: last -= 1 if not _less(pivot, arr[tosort[last]]): break while True: first += 1 if _less(pivot, arr[tosort[first]]): break pivot_pos = last tosort[begin] = tosort[pivot_pos] tosort[pivot_pos] = ipivot return pivot_pos def _apartition_right(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type): ipivot = tosort[begin] pivot = arr[ipivot] first = begin last = end while True: first += 1 if not _less(arr[tosort[first]], pivot): break if first - 1 == begin: while first < last: last -= 1 if _less(arr[tosort[last]], pivot): break else: while True: last -= 1 if _less(arr[tosort[last]], pivot): break already_partitioned = 0 if first >= last: already_partitioned = 1 while first < last: tosort[first], tosort[last] = tosort[last], tosort[first] while True: first += 1 if not _less(arr[tosort[first]], pivot): break while True: last -= 1 if _less(arr[tosort[last]], pivot): break pivot_pos = first - 1 tosort[begin] = tosort[pivot_pos] tosort[pivot_pos] = ipivot return (pivot_pos, already_partitioned) def _asort2(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, T: type): if _less(arr[tosort[j]], arr[tosort[i]]): tosort[i], tosort[j] = tosort[j], tosort[i] def _asort3(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, k: int, T: type): _asort2(arr, tosort, i, j) _asort2(arr, tosort, j, k) _asort2(arr, tosort, i, j) def _apdq_sort( arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, bad_allowed: int, leftmost: bool, T: type ): while True: size = end - begin if size < 
INSERTION_SORT_THRESHOLD: ainsertionsort(arr, tosort + begin, size) return size_2 = size // 2 if size > NINTHER_THRESHOLD: _asort3(arr, tosort, begin, begin + size_2, end - 1) _asort3(arr, tosort, begin + 1, begin + (size_2 - 1), end - 2) _asort3(arr, tosort, begin + 2, begin + (size_2 + 1), end - 3) _asort3( arr, tosort, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1) ) tosort[begin], tosort[begin + size_2] = tosort[begin + size_2], tosort[begin] else: _asort3(arr, tosort, begin + size_2, begin, end - 1) if not leftmost and not _less(arr[tosort[begin - 1]], arr[tosort[begin]]): begin = _apartition_left(arr, tosort, begin, end) + 1 continue part_result = _apartition_right(arr, tosort, begin, end) pivot_pos = part_result[0] already_partitioned = part_result[1] == 1 l_size = pivot_pos - begin r_size = end - (pivot_pos + 1) highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8)) if highly_unbalanced: bad_allowed -= 1 if bad_allowed == 0: aheapsort(arr, tosort + begin, end - begin) return if l_size >= INSERTION_SORT_THRESHOLD: tosort[begin], tosort[begin + l_size // 4] = ( tosort[begin + l_size // 4], tosort[begin], ) tosort[pivot_pos - 1], tosort[pivot_pos - l_size // 4] = ( tosort[pivot_pos - l_size // 4], tosort[pivot_pos - 1], ) if l_size > NINTHER_THRESHOLD: tosort[begin + 1], tosort[begin + (l_size // 4 + 1)] = ( tosort[begin + (l_size // 4 + 1)], tosort[begin + 1], ) tosort[begin + 2], tosort[begin + (l_size // 4 + 2)] = ( tosort[begin + (l_size // 4 + 2)], tosort[begin + 2], ) tosort[pivot_pos - 2], tosort[pivot_pos - (l_size // 4 + 1)] = ( tosort[pivot_pos - (l_size // 4 + 1)], tosort[pivot_pos - 2], ) tosort[pivot_pos - 3], tosort[pivot_pos - (l_size // 4 + 2)] = ( tosort[pivot_pos - (l_size // 4 + 2)], tosort[pivot_pos - 3], ) if r_size >= INSERTION_SORT_THRESHOLD: tosort[pivot_pos + 1], tosort[pivot_pos + (1 + r_size // 4)] = ( tosort[pivot_pos + (1 + r_size // 4)], tosort[pivot_pos + 1], ) tosort[end - 1], tosort[end - r_size // 4] 
= ( tosort[end - r_size // 4], tosort[end - 1], ) if r_size > NINTHER_THRESHOLD: tosort[pivot_pos + 2], tosort[pivot_pos + (2 + r_size // 4)] = ( tosort[pivot_pos + (2 + r_size // 4)], tosort[pivot_pos + 2], ) tosort[pivot_pos + 3], tosort[pivot_pos + (3 + r_size // 4)] = ( tosort[pivot_pos + (3 + r_size // 4)], tosort[pivot_pos + 3], ) tosort[end - 2], tosort[end - (1 + r_size // 4)] = ( tosort[end - (1 + r_size // 4)], tosort[end - 2], ) tosort[end - 3], tosort[end - (2 + r_size // 4)] = ( tosort[end - (2 + r_size // 4)], tosort[end - 3], ) else: if ( already_partitioned and _apartial_insertion_sort(arr, tosort, begin, pivot_pos) and _apartial_insertion_sort(arr, tosort, pivot_pos + 1, end) ): return _apdq_sort(arr, tosort, begin, pivot_pos, bad_allowed, leftmost) begin = pivot_pos + 1 leftmost = False def aquicksort(start: Ptr[T], tosort: Ptr[int], n: int, T: type): _apdq_sort(start, tosort, 0, n, _floor_log2(n), True) ######### # Radix # ######### def key_of(x: UT, T: type, UT: type): if T is UT: return x else: return x ^ (util.cast(1, UT) << util.cast(util.sizeof(UT) * 8 - 1, UT)) def nth_byte(key: T, l: int, T: type): return int(key >> util.cast(l << 3, T)) & 0xFF def _radixsort0(start: Ptr[UT], aux: Ptr[UT], num: int, T: type, UT: type): m = util.sizeof(UT) n = (1 << 8) ncnt = m * n cnt = Ptr[int](ncnt) # note: compiler should put this on the stack str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int)) key0 = key_of(start[0], T=T, UT=UT) for i in range(num): k = key_of(start[i], T=T, UT=UT) for l in range(m): cnt[l*n + nth_byte(k, l)] += 1 ncols = 0 cols = Ptr[u8](m) # again, compiler should put on stack for l in range(m): if cnt[l*n + nth_byte(key0, l)] != num: cols[ncols] = u8(l) ncols += 1 for l in range(ncols): a = 0 for i in range(256): b = cnt[int(cols[l])*n + i] cnt[int(cols[l])*n + i] = a a += b for l in range(ncols): for i in range(num): k = key_of(start[i], T=T, UT=UT) q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l]))) dst = q[0] q[0] += 1 
aux[dst] = start[i] temp = aux aux = start start = temp return start def _radixsort(start: Ptr[UT], num: int, T: type, UT: type): if num < 2: return all_sorted = True k1 = key_of(start[0], T=T, UT=UT) for i in range(1, num): k2 = key_of(start[i], T=T, UT=UT) if k1 > k2: all_sorted = False break k1 = k2 if all_sorted: return aux = Ptr[UT](num) sort = _radixsort0(start, aux, num, T=T) if sort != start: str.memcpy(start.as_byte(), sort.as_byte(), num * util.sizeof(UT)) util.free(aux) def radixsort(start: Ptr[T], num: int, T: type): def wrap(start: Ptr, num: int, T: type, UT: type): _radixsort(Ptr[UT](start.as_byte()), num, T=T, UT=UT) if T is int: wrap(start, num, int, u64) elif isinstance(T, Int): wrap(start, num, T, UInt[T.N]) elif isinstance(T, UInt): wrap(start, num, T, T) elif isinstance(T, byte): wrap(start, num, i8, u8) else: raise ValueError("cannot radixsort type '" + T.__name__ + "'") def _aradixsort0(start: Ptr[UT], aux: Ptr[int], tosort: Ptr[int], num: int, T: type, UT: type): m = util.sizeof(UT) n = (1 << 8) ncnt = m * n cnt = Ptr[int](ncnt) # note: compiler should put this on the stack str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int)) key0 = key_of(start[0], T=T, UT=UT) for i in range(num): k = key_of(start[i], T=T, UT=UT) for l in range(m): cnt[l*n + nth_byte(k, l)] += 1 ncols = 0 cols = Ptr[u8](m) # again, compiler should put on stack for l in range(m): if cnt[l*n + nth_byte(key0, l)] != num: cols[ncols] = u8(l) ncols += 1 for l in range(ncols): a = 0 for i in range(256): b = cnt[int(cols[l])*n + i] cnt[int(cols[l])*n + i] = a a += b for l in range(ncols): for i in range(num): k = key_of(start[tosort[i]], T=T, UT=UT) q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l]))) dst = q[0] q[0] += 1 aux[dst] = tosort[i] temp = aux aux = tosort tosort = temp return tosort def _aradixsort(start: Ptr[UT], tosort: Ptr[int], num: int, T: type, UT: type): if num < 2: return all_sorted = True k1 = key_of(start[tosort[0]], T=T, UT=UT) for i in range(1, num): 
k2 = key_of(start[tosort[i]], T=T, UT=UT) if k1 > k2: all_sorted = False break k1 = k2 if all_sorted: return aux = Ptr[int](num) sort = _aradixsort0(start, aux, tosort, num, T=T) if sort != tosort: str.memcpy(tosort.as_byte(), sort.as_byte(), num * util.sizeof(int)) util.free(aux) def aradixsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type): def wrap(start: Ptr, tosort: Ptr[int], num: int, T: type, UT: type): _aradixsort(Ptr[UT](start.as_byte()), tosort, num, T=T, UT=UT) if T is int: wrap(start, tosort, num, int, u64) elif isinstance(T, Int): wrap(start, tosort, num, T, UInt[T.N]) elif isinstance(T, UInt): wrap(start, tosort, num, T, T) elif isinstance(T, byte): wrap(start, tosort, num, i8, u8) else: raise ValueError("cannot radixsort type '" + T.__name__ + "'") ####### # Tim # ####### TIMSORT_STACK_SIZE: Static[int] = 128 def _compute_min_run(num: int): r = 0 while 64 < num: r |= num & 1 num >>= 1 return num + r # run = 2-tuple of (start, length) # buffer = 2-tuple of (pw, size) class Buffer[T]: pw: Ptr[T] size: int def __init__(self): self.pw = Ptr[T]() self.size = 0 def resize(self, new_size: int): buffer_pw = self.pw buffer_size = self.size if new_size <= buffer_size: return elif not buffer_pw: self.pw = Ptr[T](new_size) else: self.pw = util.realloc(buffer_pw, new_size, buffer_size) self.size = new_size def free(self): if self.pw: util.free(self.pw) self.pw = Ptr[T]() def _count_run(arr: Ptr[T], l: int, num: int, minrun: int, T: type): if num - l == 1: return 1 pl = arr + l if not _less(pl[1], pl[0]): pi = pl + 1 while pi < arr + (num - 1) and not _less(pi[1], pi[0]): pi += 1 else: pi = pl + 1 while pi < arr + (num - 1) and _less(pi[1], pi[0]): pi += 1 pj = pl pr = pi while pj < pr: pj[0], pr[0] = pr[0], pj[0] pj += 1 pr -= 1 pi += 1 sz = pi - pl if sz < minrun: if l + minrun < num: sz = minrun else: sz = num - l pr = pl + sz while pi < pr: vc = pi[0] pj = pi while pl < pj and _less(vc, pj[-1]): pj[0] = pj[-1] pj -= 1 pj[0] = vc pi += 1 return sz def 
_merge_left(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type): end = p2 + l2 str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(T)) p1[0] = p2[0] p1 += 1 p2 += 1 while p1 < p2 and p2 < end: if _less(p2[0], p3[0]): p1[0] = p2[0] p1 += 1 p2 += 1 else: p1[0] = p3[0] p1 += 1 p3 += 1 if p1 != p2: str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(T)) def _merge_right(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type): start = p1 - 1 str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(T)) p1 += l1 - 1 p2 += l2 - 1 p3 += l2 - 1 p2[0] = p1[0] p2 -= 1 p1 -= 1 while p1 < p2 and start < p1: if _less(p3[0], p1[0]): p2[0] = p1[0] p2 -= 1 p1 -= 1 else: p2[0] = p3[0] p2 -= 1 p3 -= 1 if p1 != p2: ofs = p2 - start str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(T)) def _gallop_right(arr: Ptr[T], size: int, key: T, T: type): if _less(key, arr[0]): return 0 last_ofs = 0 ofs = 1 while True: if size <= ofs or ofs < 0: ofs = size break if _less(key, arr[ofs]): break else: last_ofs = ofs ofs = (ofs << 1) + 1 while last_ofs + 1 < ofs: m = last_ofs + ((ofs - last_ofs) >> 1) if _less(key, arr[m]): ofs = m else: last_ofs = m return ofs def _gallop_left(arr: Ptr[T], size: int, key: T, T: type): if _less(arr[size - 1], key): return size last_ofs = 0 ofs = 1 while True: if size <= ofs or ofs < 0: ofs = size break if _less(arr[size - ofs - 1], key): break else: last_ofs = ofs ofs = (ofs << 1) + 1 l = size - ofs - 1 r = size - last_ofs - 1 while l + 1 < r: m = l + ((r - l) >> 1) if _less(arr[m], key): l = m else: r = m return r def _merge_at(arr: Ptr[T], stack: Ptr[Tuple[int,int]], at: int, buffer: Buffer[T], T: type): s1, l1 = stack[at] s2, l2 = stack[at + 1] k = _gallop_right(arr + s1, l1, arr[s2]) if l1 == k: return p1 = arr + (s1 + k) l1 -= k p2 = arr + s2 l2 = _gallop_left(arr + s2, l2, arr[s2 - 1]) if l2 < l1: buffer.resize(l2) _merge_right(p1, l1, p2, l2, buffer.pw) else: buffer.resize(l1) _merge_left(p1, l1, p2, 
l2, buffer.pw) def _try_collapse(arr: Ptr[T], stack: Ptr[Tuple[int,int]], stack_ptr: Ptr[int], buffer: Buffer[T], T: type): top = stack_ptr[0] while 1 < top: B = stack[top - 2][1] C = stack[top - 1][1] if ((2 < top and stack[top - 3][1] <= B + C) or (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)): A = stack[top - 3][1] if A <= C: _merge_at(arr, stack, top - 3, buffer) s, l = stack[top - 3] stack[top - 3] = (s, l + B) stack[top - 2] = stack[top - 1] top -= 1 else: _merge_at(arr, stack, top - 2, buffer) s, l = stack[top - 2] stack[top - 2] = (s, l + C) top -= 1 elif 1 < top and B <= C: _merge_at(arr, stack, top - 2, buffer) s, l = stack[top - 2] stack[top - 2] = (s, l + C) top -= 1 else: break stack_ptr[0] = top def _force_collapse(arr: Ptr[T], stack: Ptr[Tuple[int,int]], stack_ptr: Ptr[int], buffer: Buffer[T], T: type): top = stack_ptr[0] while 2 < top: if stack[top - 3][1] <= stack[top - 1][1]: _merge_at(arr, stack, top - 3, buffer) _, l1 = stack[top - 2] s, l2 = stack[top - 3] stack[top - 3] = (s, l1 + l2) stack[top - 2] = stack[top - 1] top -= 1 else: _merge_at(arr, stack, top - 2, buffer) _, l1 = stack[top - 1] s, l2 = stack[top - 2] stack[top - 2] = (s, l1 + l2) top -= 1 if 1 < top: _merge_at(arr, stack, top - 2, buffer) def timsort(start: Ptr[T], num: int, T: type): stack = __array__[Tuple[int,int]](TIMSORT_STACK_SIZE) buffer = Buffer[T]() stack_ptr = 0 minrun = _compute_min_run(num) l = 0 while l < num: n = _count_run(start, l, num, minrun) stack[stack_ptr] = (l, n) stack_ptr += 1 _try_collapse(start, stack.ptr, __ptr__(stack_ptr), buffer) l += n _force_collapse(start, stack.ptr, __ptr__(stack_ptr), buffer) buffer.free() def _acount_run(arr: Ptr[T], tosort: Ptr[int], l: int, num: int, minrun: int, T: type): if num - l == 1: return 1 pl = tosort + l if not _less(arr[pl[1]], arr[pl[0]]): pi = pl + 1 while pi < tosort + (num - 1) and not _less(arr[pi[1]], arr[pi[0]]): pi += 1 else: pi = pl + 1 while pi < tosort + (num - 1) and _less(arr[pi[1]], 
arr[pi[0]]): pi += 1 pj = pl pr = pi while pj < pr: pj[0], pr[0] = pr[0], pj[0] pj += 1 pr -= 1 pi += 1 sz = pi - pl if sz < minrun: if l + minrun < num: sz = minrun else: sz = num - l pr = pl + sz while pi < pr: vi = pi[0] vc = arr[vi] pj = pi while pl < pj and _less(vc, arr[pj[-1]]): pj[0] = pj[-1] pj -= 1 pj[0] = vi pi += 1 return sz def _amerge_left(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type): end = p2 + l2 str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(int)) p1[0] = p2[0] p1 += 1 p2 += 1 while p1 < p2 and p2 < end: if _less(arr[p2[0]], arr[p3[0]]): p1[0] = p2[0] p1 += 1 p2 += 1 else: p1[0] = p3[0] p1 += 1 p3 += 1 if p1 != p2: str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(int)) def _amerge_right(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type): start = p1 - 1 str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(int)) p1 += l1 - 1 p2 += l2 - 1 p3 += l2 - 1 p2[0] = p1[0] p2 -= 1 p1 -= 1 while p1 < p2 and start < p1: if _less(arr[p3[0]], arr[p1[0]]): p2[0] = p1[0] p2 -= 1 p1 -= 1 else: p2[0] = p3[0] p2 -= 1 p3 -= 1 if p1 != p2: ofs = p2 - start str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(int)) def _agallop_right(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type): if _less(key, arr[tosort[0]]): return 0 last_ofs = 0 ofs = 1 while True: if size <= ofs or ofs < 0: ofs = size break if _less(key, arr[tosort[ofs]]): break else: last_ofs = ofs ofs = (ofs << 1) + 1 while last_ofs + 1 < ofs: m = last_ofs + ((ofs - last_ofs) >> 1) if _less(key, arr[tosort[m]]): ofs = m else: last_ofs = m return ofs def _agallop_left(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type): if _less(arr[tosort[size - 1]], key): return size last_ofs = 0 ofs = 1 while True: if size <= ofs or ofs < 0: ofs = size break if _less(arr[tosort[size - ofs - 1]], key): break else: last_ofs = ofs ofs = (ofs << 1) + 1 l = size - ofs - 1 r = size - last_ofs - 1 
while l + 1 < r: m = l + ((r - l) >> 1) if _less(arr[tosort[m]], key): l = m else: r = m return r def _amerge_at(arr: Ptr[T], tosort: Ptr[int], stack: Ptr[Tuple[int,int]], at: int, buffer: Buffer[int], T: type): s1, l1 = stack[at] s2, l2 = stack[at + 1] k = _agallop_right(arr, tosort + s1, l1, arr[tosort[s2]]) if l1 == k: return p1 = tosort + (s1 + k) l1 -= k p2 = tosort + s2 l2 = _agallop_left(arr, tosort + s2, l2, arr[tosort[s2 - 1]]) if l2 < l1: buffer.resize(l2) _amerge_right(arr, p1, l1, p2, l2, buffer.pw) else: buffer.resize(l1) _amerge_left(arr, p1, l1, p2, l2, buffer.pw) def _atry_collapse(arr: Ptr[T], tosort: Ptr[int], stack: Ptr[Tuple[int,int]], stack_ptr: Ptr[int], buffer: Buffer[int], T: type): top = stack_ptr[0] while 1 < top: B = stack[top - 2][1] C = stack[top - 1][1] if ((2 < top and stack[top - 3][1] <= B + C) or (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)): A = stack[top - 3][1] if A <= C: _amerge_at(arr, tosort, stack, top - 3, buffer) s, l = stack[top - 3] stack[top - 3] = (s, l + B) stack[top - 2] = stack[top - 1] top -= 1 else: _amerge_at(arr, tosort, stack, top - 2, buffer) s, l = stack[top - 2] stack[top - 2] = (s, l + C) top -= 1 elif 1 < top and B <= C: _amerge_at(arr, tosort, stack, top - 2, buffer) s, l = stack[top - 2] stack[top - 2] = (s, l + C) top -= 1 else: break stack_ptr[0] = top def _aforce_collapse(arr: Ptr[T], tosort: Ptr[int], stack: Ptr[Tuple[int,int]], stack_ptr: Ptr[int], buffer: Buffer[int], T: type): top = stack_ptr[0] while 2 < top: if stack[top - 3][1] <= stack[top - 1][1]: _amerge_at(arr, tosort, stack, top - 3, buffer) _, l1 = stack[top - 2] s, l2 = stack[top - 3] stack[top - 3] = (s, l1 + l2) stack[top - 2] = stack[top - 1] top -= 1 else: _amerge_at(arr, tosort, stack, top - 2, buffer) _, l1 = stack[top - 1] s, l2 = stack[top - 2] stack[top - 2] = (s, l1 + l2) top -= 1 if 1 < top: _amerge_at(arr, tosort, stack, top - 2, buffer) def atimsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type): stack = 
__array__[Tuple[int,int]](TIMSORT_STACK_SIZE) buffer = Buffer[int]() stack_ptr = 0 minrun = _compute_min_run(num) l = 0 while l < num: n = _acount_run(start, tosort, l, num, minrun) stack[stack_ptr] = (l, n) stack_ptr += 1 _atry_collapse(start, tosort, stack.ptr, __ptr__(stack_ptr), buffer) l += n _aforce_collapse(start, tosort, stack.ptr, __ptr__(stack_ptr), buffer) buffer.free() ########## # Stable # ########## def stablesort(start: Ptr[T], num: int, T: type): if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt): return radixsort(start, num) else: return timsort(start, num) def astablesort(start: Ptr[T], tosort: Ptr[int], num: int, T: type): if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt): return aradixsort(start, tosort, num) else: return atimsort(start, tosort, num) ############# # Selection # ############# MAX_PIVOT_STACK: Static[int] = 50 @tuple class Sortee: v: Ptr[T] T: type def __call__(self, i: int): return self.v[i] def swap(self, i: int, j: int): v = self.v tmp = v[i] v[i] = v[j] v[j] = tmp @tuple class Idx: def __call__(self, i: int): return i @tuple class ArgSortee: tosort: Ptr[int] def __call__(self, i: int): return self.tosort[i] def swap(self, i: int, j: int): v = self.tosort tmp = v[i] v[i] = v[j] v[j] = tmp @tuple class ArgIdx: tosort: Ptr[int] def __call__(self, i: int): return self.tosort[i] def _store_pivot(pivot: int, kth: int, pivots: Ptr[int], npiv: Ptr[int]): if not pivots: return if pivot == kth and npiv[0] == MAX_PIVOT_STACK: pivots[npiv[0] - 1] = pivot elif pivot >= kth and npiv[0] < MAX_PIVOT_STACK: pivots[npiv[0]] = pivot npiv[0] += 1 def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, arg: Static[int], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) else: idx = Idx() sortee = Sortee(v) if _less(v[idx(high)], v[idx(mid)]): sortee.swap(high, mid) if _less(v[idx(high)], v[idx(low)]): sortee.swap(high, low) if _less(v[idx(low)], v[idx(mid)]): 
sortee.swap(low, mid) sortee.swap(mid, low + 1)


def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[int], T: type):
    """
    Run a fixed compare-and-swap network over the five elements at logical
    positions 0..4 and return the position (1, 2 or 3) picked by the final
    comparisons (the median of the five, going by the function name).

    With `arg` set, elements are addressed and swapped through `tosort`
    (ArgIdx/ArgSortee); otherwise `v` itself is swapped (Idx/Sortee).
    NOTE(review): those helper types are defined elsewhere in the file; this
    assumes idx(i) maps a logical position to an index into `v`.
    """
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    if _less(v[idx(1)], v[idx(0)]):
        sortee.swap(1, 0)
    if _less(v[idx(4)], v[idx(3)]):
        sortee.swap(4, 3)
    if _less(v[idx(3)], v[idx(0)]):
        sortee.swap(3, 0)
    if _less(v[idx(4)], v[idx(1)]):
        sortee.swap(4, 1)
    if _less(v[idx(2)], v[idx(1)]):
        sortee.swap(2, 1)
    if _less(v[idx(3)], v[idx(2)]):
        if _less(v[idx(3)], v[idx(1)]):
            return 1
        else:
            return 3
    else:
        return 2


def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int,
                         arg: Static[int], T: type):
    """
    Partition around `pivot`: advance `ll` upwards past elements less than the
    pivot and `hh` downwards past elements greater than it, swapping the two
    stopping positions until the cursors cross. Returns the final `(ll, hh)`.

    There is no bounds check on either scan ("unguarded"); the caller is
    expected to have arranged sentinels so that both scans stop.
    """
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    while True:
        # do-while: move at least one step before testing
        while True:
            ll += 1
            if not _less(v[idx(ll)], pivot):
                break

        while True:
            hh -= 1
            if not _less(pivot, v[idx(hh)]):
                break

        if hh < ll:
            break

        sortee.swap(ll, hh)

    return ll, hh


def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int],
                       npiv: Ptr[int], arg: Static[int], isel, T: type):
    """
    Compute the median of each complete group of five among the first `num`
    elements, move those medians to positions 0..nmed-1, and (when there are
    more than two groups) call `isel` to select the middle one of them.

    `isel` is the selection routine to recurse into; it is invoked with the
    same argument layout as `_introselect`. Returns `nmed // 2`, the position
    of the selected median-of-medians.
    """
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    right = num - 1
    nmed = (right + 1) // 5
    subleft = 0

    for i in range(nmed):
        # In arg mode the index array is offset; otherwise the value array is.
        m = _median5(v + (0 if arg else subleft), tosort + (subleft if arg else 0), arg)
        sortee.swap(subleft + m, i)
        subleft += 5

    if nmed > 2:
        isel(v, tosort, nmed, nmed // 2, pivots, npiv, arg)

    return nmed // 2


def _dumb_select(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, arg: Static[int], T: type):
    """
    Selection-sort the first `kth + 1` positions of a `num`-element range:
    for each position, find the minimum of the remaining elements and swap it
    into place. Cost is proportional to `kth * num`, so it is meant for small
    `kth`.
    """
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    for i in range(kth + 1):
        minidx = i
        minval = v[idx(i)]
        for k in range(i + 1, num):
            if _less(v[idx(k)], minval):
                minidx = k
                minval = v[idx(k)]
        sortee.swap(i, minidx)


def _msb(unum: u64):
    """Return the zero-based position of the highest set bit of `unum` (0 for 0 or 1)."""
    depth_limit = 0
    while unum >> u64(1):
        depth_limit += 1
        unum >>= u64(1)
    return depth_limit


def _introselect(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int],
                 npiv: Ptr[int], arg: Static[int], T: type):
    """
    Rearrange the `num` elements so that the element at position `kth` is the
    one a full sort would place there, using median-of-3 quickselect that
    switches to median-of-medians pivoting once the depth budget is spent.

    `pivots` / `npiv` form an optional stack of pivot positions left by
    earlier calls on the same data; they are used to narrow the search range
    and are updated through `_store_pivot`. A null `npiv` disables the stack.
    With `arg` set, only `tosort` is permuted; otherwise `v` is.
    """
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    low = 0
    high = num - 1

    if not npiv:
        pivots = Ptr[int]()

    # Use previously stored pivots to shrink [low, high] around kth.
    while pivots and npiv[0] > 0:
        if pivots[npiv[0] - 1] > kth:
            high = pivots[npiv[0] - 1] - 1
            break
        elif pivots[npiv[0] - 1] == kth:
            # kth was already placed by an earlier call
            return

        low = pivots[npiv[0] - 1] + 1
        # pop: this pivot lies below kth and is of no further use
        npiv[0] -= 1

    if kth - low < 3:
        # kth is within three positions of the range start: simple selection
        _dumb_select(v + (0 if arg else low), tosort + (low if arg else 0),
                     high - low + 1, kth - low, arg)
        _store_pivot(kth, kth, pivots, npiv)
        return
    elif (T is float or T is float32) and kth == num - 1:
        # Selecting the last element of floating-point data: a single linear
        # scan for the maximum under `_less` is enough.
        maxidx = low
        maxval = v[idx(low)]
        for k in range(low + 1, num):
            if not _less(v[idx(k)], maxval):
                maxidx = k
                maxval = v[idx(k)]
        sortee.swap(kth, maxidx)
        return

    depth_limit = _msb(u64(num)) * 2

    while low + 1 < high:
        ll = low + 1
        hh = high

        if depth_limit > 0 or hh - ll < 5:
            mid = low + (high - low) // 2
            _median3_swap(v, tosort, low, mid, high, arg)
        else:
            # Depth budget exhausted: fall back to median-of-medians pivot.
            mid = ll + _median_of_median5(v + (0 if arg else ll), tosort + (ll if arg else 0),
                                          hh - ll, Ptr[int](), Ptr[int](), arg, _introselect)
            sortee.swap(mid, low)
            # widen the partition range by one on each side for this pivot
            ll -= 1
            hh += 1

        depth_limit -= 1

        ll, hh = _unguarded_partition(v, tosort, v[idx(low)], ll, hh, arg)

        # move the pivot into its final position
        sortee.swap(low, hh)

        if hh != kth:
            _store_pivot(hh, kth, pivots, npiv)

        if hh >= kth:
            high = hh - 1
        if hh <= kth:
            low = ll

    # two elements left
    if high == low + 1:
        if _less(v[idx(high)], v[idx(low)]):
            sortee.swap(high, low)

    _store_pivot(kth, kth, pivots, npiv)


def _partition(v: Ptr[T], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], T: type):
    """In-place selection on `v`: `_introselect` in value mode (no index array)."""
    _introselect(v, Ptr[int](), num, kth, pivots, npiv, arg=False)


def _argpartition(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int],
                  npiv: Ptr[int], T: type):
    """Indirect selection: `_introselect` in arg mode, permuting `tosort` only."""
    _introselect(v, tosort, num, kth, pivots, npiv, arg=True)


def _partition_compact(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    """
    Partition every 1-D lane of `a` along `axis` in place, reading the lane
    directly from the array's memory (the lane must have unit element stride).
    `kth` points to `nkth` selection positions, applied in order per lane.
    """
    n = a.shape[axis]
    pivots = __array__[int](MAX_PIVOT_STACK)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        idx1 = util.tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        # the pivot stack is per lane
        npiv = 0
        for i in range(nkth):
            _partition(p, n, kth[i], pivots.ptr, __ptr__(npiv))


def _partition_buffered(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    """
    Partition every 1-D lane of `a` along `axis` in place for arbitrary
    strides: each lane is gathered into a contiguous scratch buffer,
    partitioned there, and scattered back.
    """
    n = a.shape[axis]
    st = a.strides[axis]
    buf = Ptr[a.dtype](n)
    pivots = __array__[int](MAX_PIVOT_STACK)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        npiv = 0
        idx1 = util.tuple_insert(idx, axis, 0)
        p = a._ptr(idx1).as_byte()
        p0 = p

        # gather
        for i in range(n):
            buf[i] = (Ptr[a.dtype](p))[0]
            p += st

        for i in range(nkth):
            _partition(buf, n, kth[i], pivots.ptr, __ptr__(npiv))

        # scatter
        p = p0
        for i in range(n):
            Ptr[a.dtype](p)[0] = buf[i]
            p += st

    util.free(buf)


def _fix_kth(kth, n: int):
    """
    Normalize the `kth` argument of partition/argpartition for an axis of
    length `n`.

    Accepts an integer scalar or a sequence/array of integers with at most
    one dimension (anything else is a compile error). Negative entries are
    wrapped by adding `n`; an entry still outside `[0, n)` raises ValueError.
    A 1-D `kth` is sorted ascending.

    Returns `(pointer to the normalized entries, number of entries)`.
    """
    if isinstance(kth, ndarray):
        # The entries are rewritten and sorted below. `asarray` would hand
        # back the caller's own array here, so take a private copy to avoid
        # modifying the caller's `kth`.
        kth_arr = kth.copy(order='C')
    else:
        kth_arr = asarray(kth, order='C')

    if kth_arr.dtype is not int:
        compile_error("Partition index must be integer")

    if kth_arr.ndim > 1:
        compile_error("kth array must have dimension <= 1")

    pkth = kth_arr.data
    nkth = kth_arr.size

    for i in range(nkth):
        if pkth[i] < 0:
            pkth[i] += n

        if pkth[i] < 0 or pkth[i] >= n:
            raise ValueError(f"kth(={pkth[i]}) out of bounds ({n})")

    if kth_arr.ndim == 1:
        kth_arr.sort()

    return pkth, nkth


def _partition_helper(a: ndarray, kth, axis, kind: str, force_compact: Static[int] = False):
    """
    Partition `a` in place along `axis`.

    `kind` must be 'introselect' (ValueError otherwise). `force_compact`
    skips the stride check and always uses the direct (unbuffered) path.

    Note: with `axis=None` the work is done on `a.flatten()` and nothing is
    returned, so the result does not reach the caller; callers that want
    flattened semantics should flatten first and pass an integer axis.
    """
    if axis is None:
        _partition_helper(a.flatten(), kth=kth, axis=-1, kind=kind, force_compact=force_compact)
        return

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    pkth, nkth = _fix_kth(kth, n)

    if force_compact:
        _partition_compact(a, pkth, nkth, axis)
    else:
        if a.strides[axis] == a.itemsize:
            _partition_compact(a, pkth, nkth, axis)
        else:
            _partition_buffered(a, pkth, nkth, axis)


def partition(a, kth, axis = -1, kind: str = 'introselect'):
    """
    Return a partitioned copy of `a`: along `axis`, each element at a position
    listed in `kth` is where it would be in sorted order. The input is not
    modified.

    With `axis=None` the input is flattened first and the 1-D result is
    returned. `kind` must be 'introselect'.
    """
    if axis is None:
        # Flatten up front (same approach as `sort`). Delegating axis=None to
        # the in-place helper would partition a temporary and lose the result.
        return partition(asarray(a).flatten(), kth, axis=-1, kind=kind)

    if isinstance(a, ndarray):
        b = a.copy(order='C')
    else:
        b = asarray(a)

    _partition_helper(b, kth=kth, axis=axis, kind=kind, force_compact=(b.ndim == 1))
    return b


def argpartition(a: ndarray, kth, axis = -1, kind: str = 'introselect'):
    """
    Return an integer array of the same shape as `a` holding, along `axis`,
    indices that partition `a` at the positions in `kth`.

    With `axis=None` the flattened array is used. `kind` must be
    'introselect' (ValueError otherwise).
    """
    if axis is None:
        # Recurse into the public function; `_argpartition` is the low-level
        # kernel and takes raw pointers, not axis/kind.
        return argpartition(a.flatten(), kth=kth, axis=-1, kind=kind)

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    pkth, nkth = _fix_kth(kth, n)

    b = empty(a.shape, int)
    a_stride = a.strides[axis]
    b_stride = b.strides[axis]
    a_compact = (a_stride == a.itemsize)
    b_compact = (b_stride == b.itemsize)
    # scratch buffers are only needed for non-unit-stride lanes
    a_buf = Ptr[a.dtype]() if a_compact else Ptr[a.dtype](n)
    b_buf = Ptr[b.dtype]() if b_compact else Ptr[b.dtype](n)
    pivots = __array__[int](MAX_PIVOT_STACK)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        npiv = 0
        idx1 = util.tuple_insert(idx, axis, 0)
        pa = a._ptr(idx1)
        pb = b._ptr(idx1)
        qa = pa.as_byte()
        qb = pb.as_byte()

        if not a_compact:
            for i in range(n):
                a_buf[i] = (Ptr[a.dtype](qa))[0]
                qa += a_stride

        # start from the identity permutation
        if not b_compact:
            for i in range(n):
                b_buf[i] = i
        else:
            for i in range(n):
                pb[i] = i

        start = pa if a_compact else a_buf
        tosort = pb if b_compact else b_buf

        for i in range(nkth):
            _argpartition(start, tosort, n, pkth[i], pivots.ptr, __ptr__(npiv))

        # NOTE(review): arg mode appears to permute only `tosort`, so this
        # write-back of the values looks redundant; confirm before removing.
        if not a_compact:
            qa = pa.as_byte()
            for i in range(n):
                Ptr[a.dtype](qa)[0] = a_buf[i]
                qa += a_stride

        if not b_compact:
            qb = pb.as_byte()
            for i in range(n):
                Ptr[b.dtype](qb)[0] = b_buf[i]
                qb += b_stride

    if not a_compact:
        util.free(a_buf)

    if not b_compact:
        util.free(b_buf)

    return b


###########
# Lexsort #
###########

def _lexsort(sort_keys, n: int, axis: int):
    """
    Indirect stable sort over `n` keys of identical shape along `axis`.

    Keys are applied in order with a stable sort on a shared index
    permutation, so the last key ends up as the primary sort key. Returns an
    int array shaped like the keys (0 for 0-d keys, zeros when a key has at
    most one element).
    """
    sort_keys0 = asarray(sort_keys[0])
    nd: Static[int] = sort_keys0.ndim

    # axis 0 / -1 is tolerated for 0-d keys; everything else is validated
    if nd == 0 and (axis == 0 or axis == -1):
        pass
    else:
        axis = util.normalize_axis_index(axis, nd)

    if nd == 0:
        return 0

    if sort_keys0.size <= 1:
        return zeros(sort_keys[0].shape, int)

    ret = empty(sort_keys0.shape, int)
    rstride = ret.strides[axis]
    maxelsize = sort_keys0.itemsize
    needcopy = (rstride != ret.itemsize)
    N = sort_keys0.shape[axis]

    # Buffering is needed as soon as any key (or the result) is not
    # unit-stride along the sort axis.
    for j in range(n):
        key = sort_keys[j]
        needcopy |= (key.strides[axis] != key.itemsize)
        maxelsize = max(maxelsize, key.itemsize)

    valbuffer = cobj()
    indbuffer = Ptr[int]()

    if needcopy:
        # one value buffer, sized in bytes for the widest key dtype
        valbufsize = N * maxelsize
        if valbufsize == 0:
            valbufsize = 1
        valbuffer = cobj(valbufsize)
        # Ptr[int](k) allocates k ints, and only N index slots are used, so
        # the count is N (not N * itemsize).
        indbuffer = Ptr[int](N)

        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1).as_byte()

            for i in range(N):
                indbuffer[i] = i

            for j in range(n):
                key = sort_keys[j]
                buf = Ptr[key.dtype](valbuffer.as_byte())
                p = key._ptr(idx1).as_byte()
                s = key.strides[axis]

                for i in range(N):
                    buf[i] = (Ptr[key.dtype](p))[0]
                    p += s

                astablesort(buf, indbuffer, N)

            for i in range(N):
                z = Ptr[int](q)
                z[0] = indbuffer[i]
                q += rstride
    else:
        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1)

            for i in range(N):
                q[i] = i

            for j in range(n):
                p = sort_keys[j]._ptr(idx1)
                astablesort(p, q, N)

    if valbuffer:
        util.free(valbuffer)

    if indbuffer:
        util.free(indbuffer)

    return ret


def lexsort(keys, axis: int = -1):
    """
    Indirect stable sort using a sequence of keys (tuple, list or ndarray).

    All keys must have the same shape; a mismatch raises ValueError (or is a
    compile error when the dimensionality differs for tuple keys). An empty
    sequence raises ValueError.
    """
    n = len(keys)
    if n == 0:
        raise ValueError("need sequence of keys with len > 0 in lexsort")

    if isinstance(keys, Tuple):
        sort_keys = tuple(asarray(key) for key in keys)

        for i in staticrange(1, staticlen(sort_keys)):
            if sort_keys[i].ndim != sort_keys[0].ndim:
                compile_error("all keys need to be the same shape")

            if sort_keys[i].shape != sort_keys[0].shape:
                raise ValueError("all keys need to be the same shape")

        return _lexsort(sort_keys, n, axis)
    elif isinstance(keys, List):
        if isinstance(keys[0], ndarray):
            sort_keys = keys
        else:
            sort_keys = [asarray(key) for key in keys]

        for i in range(1, n):
            if sort_keys[i].shape != sort_keys[0].shape:
                raise ValueError("all keys need to be the same shape")

        return _lexsort(sort_keys, n, axis)
    elif isinstance(keys, ndarray):
        # rows of the array act as the keys
        return _lexsort(keys, n, axis)
    else:
        compile_error("keys must be an ndarray or a tuple")


#######
# API #
#######

def _sort_compact(a: ndarray, axis: int, sorter):
    """Apply `sorter(ptr, n)` directly to every lane of `a` along `axis` (unit stride required)."""
    n = a.shape[axis]

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        idx1 = util.tuple_insert(idx, axis, 0)
        p = a._ptr(idx1)
        sorter(p, n)


def _sort_buffered(a: ndarray, axis: int, sorter):
    """
    Apply `sorter(ptr, n)` to every lane of `a` along `axis` for arbitrary
    strides by gathering each lane into a contiguous buffer, sorting it, and
    scattering it back.
    """
    n = a.shape[axis]
    st = a.strides[axis]
    buf = Ptr[a.dtype](n)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        idx1 = util.tuple_insert(idx, axis, 0)
        p = a._ptr(idx1).as_byte()
        p0 = p

        for i in range(n):
            buf[i] = (Ptr[a.dtype](p))[0]
            p += st

        sorter(buf, n)

        p = p0
        for i in range(n):
            Ptr[a.dtype](p)[0] = buf[i]
            p += st

    util.free(buf)


def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[int] = False):
    """
    Sort `a` in place along `axis` with `sorter`, choosing the direct path
    when the axis has unit element stride (or when `force_compact` is set)
    and the buffered path otherwise.
    """
    axis = util.normalize_axis_index(axis, a.ndim)

    if force_compact:
        _sort_compact(a, axis, sorter)
    else:
        if a.strides[axis] == a.itemsize:
            _sort_compact(a, axis, sorter)
        else:
            _sort_buffered(a, axis, sorter)


def _sort(a: ndarray, axis: int, kind: Optional[str], force_compact: Static[int] = False):
    """
    Sort `a` in place along `axis` using the algorithm named by `kind`:
    None/'quicksort'/'quick', 'mergesort'/'merge'/'stable', or
    'heapsort'/'heap'. Any other value raises ValueError.
    """
    if kind is None or kind == 'quicksort' or kind == 'quick':
        _sort_dispatch(a, axis, quicksort, force_compact)
    elif kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        _sort_dispatch(a, axis, stablesort, force_compact)
    elif kind == 'heapsort' or kind == 'heap':
        _sort_dispatch(a, axis, heapsort, force_compact)
    else:
        raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")


def sort(a, axis = -1, kind: Optional[str] = None):
    """
    Return a sorted copy of `a` along `axis`; the input is left untouched.
    With `axis=None` the input is flattened first.
    """
    if axis is None:
        return sort(asarray(a).flatten(), axis=-1, kind=kind)

    if isinstance(a, ndarray):
        b = a.copy(order='C')
    else:
        b = asarray(a)

    _sort(b, axis=axis, kind=kind, force_compact=(b.ndim == 1))
    return b


def _asort(a: ndarray, axis: int, sorter):
    """
    Indirect sort: return an int array shaped like `a` whose lanes along
    `axis` are produced by `sorter(values_ptr, index_ptr, n)` starting from
    the identity permutation. Non-unit-stride lanes go through scratch
    buffers.
    """
    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    b = empty(a.shape, int)
    a_stride = a.strides[axis]
    b_stride = b.strides[axis]
    a_compact = (a_stride == a.itemsize)
    b_compact = (b_stride == b.itemsize)
    a_buf = Ptr[a.dtype]() if a_compact else Ptr[a.dtype](n)
    b_buf = Ptr[b.dtype]() if b_compact else Ptr[b.dtype](n)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        idx1 = util.tuple_insert(idx, axis, 0)
        pa = a._ptr(idx1)
        pb = b._ptr(idx1)
        qa = pa.as_byte()
        qb = pb.as_byte()

        if not a_compact:
            for i in range(n):
                a_buf[i] = (Ptr[a.dtype](qa))[0]
                qa += a_stride

        # identity permutation
        if not b_compact:
            for i in range(n):
                b_buf[i] = i
        else:
            for i in range(n):
                pb[i] = i

        start = pa if a_compact else a_buf
        tosort = pb if b_compact else b_buf
        sorter(start, tosort, n)

        # NOTE(review): values are copied back even though an indirect sorter
        # is expected to leave them unchanged; confirm before removing.
        if not a_compact:
            qa = pa.as_byte()
            for i in range(n):
                Ptr[a.dtype](qa)[0] = a_buf[i]
                qa += a_stride

        if not b_compact:
            qb = pb.as_byte()
            for i in range(n):
                Ptr[b.dtype](qb)[0] = b_buf[i]
                qb += b_stride

    if not a_compact:
        util.free(a_buf)

    if not b_compact:
        util.free(b_buf)

    return b


def argsort(a, axis = -1, kind: Optional[str] = None):
    """
    Return the indices that would sort `a` along `axis`. With `axis=None` the
    flattened input is used. `kind` accepts the same names as `sort`; any
    other value raises ValueError.
    """
    if axis is None:
        return argsort(asarray(a).flatten(), axis=-1, kind=kind)

    a = asarray(a)

    if kind is None or kind == 'quicksort' or kind == 'quick':
        return _asort(a, axis, aquicksort)
    elif kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        return _asort(a, axis, astablesort)
    elif kind == 'heapsort' or kind == 'heap':
        return _asort(a, axis, aheapsort)
    else:
        raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")


def sort_complex(a):
    """
    Return a sorted copy of `a` (last axis) as a complex array. Complex input
    keeps its dtype; byte/i8/u8/i16/u16 input becomes complex64; every other
    dtype becomes complex.
    """
    b = array(a, copy=True)
    b.sort()
    if not (b.dtype is complex or b.dtype is complex64):
        if b.dtype is byte or b.dtype is i8 or b.dtype is u8 or b.dtype is i16 or b.dtype is u16:
            return b.astype(complex64)
        else:
            return b.astype(complex)
    else:
        return b


@extend
class ndarray:
    # In-place sort along `axis`; see `_sort` for accepted `kind` values.
    def sort(self, axis: int = -1, kind: Optional[str] = None):
        _sort(self, axis=axis, kind=kind, force_compact=False)

    # Indices that would sort this array along `axis`.
    def argsort(self, axis: int = -1, kind: Optional[str] = None):
        return argsort(self, axis=axis, kind=kind)

    # In-place partition along `axis` at the position(s) in `kth`.
    def partition(self, kth, axis: int = -1, kind: str = 'introselect'):
        _partition_helper(self, kth=kth, axis=axis, kind=kind, force_compact=False)

    # Indices that would partition this array along `axis`.
    def argpartition(self, kth, axis: int = -1, kind: str = 'introselect'):
        return argpartition(self, kth=kth, axis=axis, kind=kind)

# TODO: support 'order' argument on sort, argsort, ndarray.sort, partition and argpartition