1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/sorting.codon

2184 lines
58 KiB
Python
Raw Normal View History

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .ndarray import ndarray
from .routines import array, asarray, empty, zeros
import util
##########
# Common #
##########
# Strict "less than" predicate used by the sorting routines in this file.
#
# The branch is chosen by the element type T:
#   * float / float32: a value with `x != x` (NaN) is never less than anything,
#     and every non-NaN value is less than a NaN, so NaNs order last.
#   * complex / complex64: real parts are compared first, imaginary parts
#     second, with NaN components pushed toward the end (see inline notes).
#   * anything else: the type's own `<`.
def _less(a: T, b: T, T: type):
    if T is float or T is float32:
        # `b != b` is true only when b is NaN; `a == a` only when a is not.
        return a < b or (b != b and a == a)
    elif T is complex or T is complex64:
        if a.real < b.real:
            # Less, unless a has a NaN imaginary part and b does not.
            return a.imag == a.imag or b.imag != b.imag
        elif a.real > b.real:
            # Still less only when b's imaginary part is NaN and a's is not.
            return b.imag != b.imag and a.imag == a.imag
        elif a.real == b.real or (a.real != a.real and b.real != b.real):
            # Real parts tie (equal, or both NaN): decide on the imaginary parts.
            return a.imag < b.imag or (b.imag != b.imag and a.imag == a.imag)
        else:
            # Exactly one real part is NaN: a is less iff the NaN is b's.
            return b.real != b.real
    else:
        return a < b
#############
# Insertion #
#############
# In-place insertion sort of the `n` elements starting at `start`, ordered by
# `_less`.
def insertionsort(start: Ptr[T], n: int, T: type):
    for pos in range(1, n):
        item = start[pos]
        hole = pos
        # Slide every element that `item` is less than one slot to the right.
        while hole > 0 and _less(item, start[hole - 1]):
            start[hole] = start[hole - 1]
            hole -= 1
        start[hole] = item
# Indirect insertion sort: reorders the `n` indices in `tosort` so that the
# values `v[tosort[...]]` are ordered by `_less`. `v` itself is not written.
def ainsertionsort(v: Ptr[T], tosort: Ptr[int], n: int, T: type):
    for pos in range(1, n):
        picked = tosort[pos]
        hole = pos
        # Slide indices whose values `v[picked]` is less than one slot right.
        while hole > 0 and _less(v[picked], v[tosort[hole - 1]]):
            tosort[hole] = tosort[hole - 1]
            hole -= 1
        tosort[hole] = picked
########
# Heap #
########
# In-place heap sort of the `n` elements starting at `start`, ordered by
# `_less`.
def heapsort(start: Ptr[T], n: int, T: type):
    # Shift the base by one so the heap can use 1-based indices
    # (a[1] is start[0]; children of i are 2*i and 2*i + 1).
    a = start - 1
    # Phase 1: build a max-heap by sifting down every parent, last one first.
    for l in range(n >> 1, 0, -1):
        tmp = a[l]
        i = l
        j = l << 1
        while j <= n:
            # Pick the larger of the two children.
            if j < n and _less(a[j], a[j + 1]):
                j += 1
            if _less(tmp, a[j]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
    # Phase 2: move the current maximum (a[1]) to the end of the shrinking
    # heap, then sift the displaced element down from the root.
    while n > 1:
        tmp = a[n]
        a[n] = a[1]
        n -= 1
        i = 1
        j = 2
        while j <= n:
            if j < n and _less(a[j], a[j + 1]):
                j += 1
            if _less(tmp, a[j]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
# Indirect heap sort: reorders the `n` indices in `tosort` so that the values
# `vv[tosort[...]]` are ordered by `_less`. The values are only read.
def aheapsort(vv: Ptr[T], tosort: Ptr[int], n: int, T: type):
    v = vv
    # 1-based view of the index array (a[1] is tosort[0]).
    a = tosort - 1
    # Phase 1: build a max-heap of indices keyed by the values they refer to.
    for l in range(n >> 1, 0, -1):
        tmp = a[l]
        i = l
        j = l << 1
        while j <= n:
            # Pick the child whose value is larger.
            if j < n and _less(v[a[j]], v[a[j + 1]]):
                j += 1
            if _less(v[tmp], v[a[j]]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
    # Phase 2: repeatedly move the root index to the end of the shrinking heap
    # and sift the displaced index down.
    while n > 1:
        tmp = a[n]
        a[n] = a[1]
        n -= 1
        i = 1
        j = 2
        while j <= n:
            if j < n and _less(v[a[j]], v[a[j + 1]]):
                j += 1
            if _less(v[tmp], v[a[j]]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
#########
# Merge #
#########
# Compile-time size thresholds for the sorting routines.
# NOTE(review): within the visible part of this file only SMALL_MERGESORT is
# referenced (by _mergesort0 / _amergesort0); confirm whether the other three
# are used elsewhere.
PYA_QS_STACK : Static[int] = 100
SMALL_QUICKSORT: Static[int] = 15
# Ranges of at most this many elements are insertion-sorted instead of split.
SMALL_MERGESORT: Static[int] = 20
SMALL_STRING : Static[int] = 16
# Recursive merge sort of the half-open range [pl, pr), ordered by `_less`.
# `pw` is scratch space that must hold at least the left half of the range.
def _mergesort0(pl: Ptr[T], pr: Ptr[T], pw: Ptr[T], T: type):
    if pr - pl > SMALL_MERGESORT:
        pm = pl + ((pr - pl) >> 1)
        _mergesort0(pl, pm, pw)
        _mergesort0(pm, pr, pw)
        # Copy the sorted left half [pl, pm) into the scratch buffer.
        pi = pw
        pj = pl
        while pj < pm:
            pi[0] = pj[0]
            pi += 1
            pj += 1
        # Merge scratch [pj, pi) with the right half [pm, pr) back into pl...
        pi = pw + (pm - pl)
        pj = pw
        pk = pl
        while pj < pi and pm < pr:
            # Take from the right half only when strictly smaller, so equal
            # elements keep their left-before-right order.
            if _less(pm[0], pj[0]):
                pk[0] = pm[0]
                pk += 1
                pm += 1
            else:
                pk[0] = pj[0]
                pk += 1
                pj += 1
        # Flush what is left of the scratch copy; leftover right-half elements
        # are already in their final positions.
        while pj < pi:
            pk[0] = pj[0]
            pk += 1
            pj += 1
    else:
        # Small range: plain insertion sort.
        pi = pl + 1
        while pi < pr:
            vp = pi[0]
            pj = pi
            pk = pi - 1
            while pj > pl and _less(vp, pk[0]):
                pj[0] = pk[0]
                pj -= 1
                pk -= 1
            pj[0] = vp
            pi += 1
# Merge sort of the `num` elements starting at `start`. Allocates a scratch
# buffer of `num // 2` elements for the merges and releases it afterwards.
def mergesort(start: Ptr[T], num: int, T: type):
    scratch = Ptr[T](num // 2)
    _mergesort0(start, start + num, scratch)
    util.free(scratch)
# Indirect recursive merge sort: reorders the indices in [pl, pr) so that the
# values `v[index]` are ordered by `_less`. `pw` is scratch space for indices
# and must hold at least the left half of the range.
def _amergesort0(pl: Ptr[int], pr: Ptr[int], v: Ptr[T], pw: Ptr[int], T: type):
    if pr - pl > SMALL_MERGESORT:
        pm = pl + ((pr - pl) >> 1)
        _amergesort0(pl, pm, v, pw)
        _amergesort0(pm, pr, v, pw)
        # Copy the sorted left half of the indices into the scratch buffer.
        pi = pw
        pj = pl
        while pj < pm:
            pi[0] = pj[0]
            pi += 1
            pj += 1
        # Merge scratch [pj, pi) with the right half [pm, pr) back into pl...
        pi = pw + (pm - pl)
        pj = pw
        pk = pl
        while pj < pi and pm < pr:
            # Right-half index wins only when its value is strictly smaller.
            if _less(v[pm[0]], v[pj[0]]):
                pk[0] = pm[0]
                pk += 1
                pm += 1
            else:
                pk[0] = pj[0]
                pk += 1
                pj += 1
        # Flush the remainder of the scratch copy.
        while pj < pi:
            pk[0] = pj[0]
            pk += 1
            pj += 1
    else:
        # Small range: insertion sort on the indices.
        pi = pl + 1
        while pi < pr:
            vi = pi[0]
            vp = v[vi]
            pj = pi
            pk = pi - 1
            while pj > pl and _less(vp, v[pk[0]]):
                pj[0] = pk[0]
                pj -= 1
                pk -= 1
            pj[0] = vi
            pi += 1
# Indirect merge sort of the `num` indices in `tosort`, keyed by the values in
# `v`. Allocates `num // 2` scratch indices and releases them afterwards.
def amergesort(v: Ptr[T], tosort: Ptr[int], num: int, T: type):
    scratch = Ptr[int](num // 2)
    _amergesort0(tosort, tosort + num, v, scratch)
    util.free(scratch)
###############
# Quick (PDQ) #
###############
# Compile-time tuning constants for _pdq_sort / _apdq_sort and their helpers.
# Ranges shorter than this are handed to insertion sort; after an unbalanced
# partition, sides at least this long get a few elements swapped around.
INSERTION_SORT_THRESHOLD : Static[int] = 24
# Ranges longer than this pick the pivot from several _sort3 triples instead
# of a single one.
NINTHER_THRESHOLD : Static[int] = 128
# The partial insertion sort gives up once the accumulated distance elements
# were moved exceeds this value.
PARTIAL_INSERTION_SORT_LIMIT: Static[int] = 8
# Return floor(log2(n)) for n >= 1, and 0 for n <= 1.
#
# The loop condition is `n > 1` rather than "shift until zero": shifting a
# negative integer right never reaches zero, so the shift-until-zero form
# would spin forever on a negative argument. Results for n >= 0 are unchanged.
def _floor_log2(n: int):
    log = 0
    while n > 1:
        n >>= 1
        log += 1
    return log
# Try to finish sorting arr[begin:end] with insertion sort, but give up early
# if too much work is needed.
#
# Returns True when the range ends up fully sorted, False when the running
# total of positions moved exceeded PARTIAL_INSERTION_SORT_LIMIT (the range is
# then left partially processed).
def _partial_insertion_sort(arr: Ptr[T], begin: int, end: int, T: type):
    if begin == end:
        return True
    limit = 0
    cur = begin + 1
    while cur != end:
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False
        sift = cur
        sift_1 = cur - 1
        if _less(arr[sift], arr[sift_1]):
            tmp = arr[sift]
            # Shift larger elements right until tmp's slot is found. The
            # `sift == begin` test comes first so arr[sift_1] is never read
            # below `begin`.
            while True:
                arr[sift] = arr[sift_1]
                sift -= 1
                sift_1 -= 1
                if sift == begin or not _less(tmp, arr[sift_1]):
                    break
            arr[sift] = tmp
            # Charge the distance this element travelled against the budget.
            limit += cur - sift
        cur += 1
    return True
# Partition arr[begin:end] around pivot = arr[begin] so that elements the pivot
# is NOT less than (i.e. <= pivot) end up on the left and elements greater than
# the pivot on the right. Returns the final index of the pivot.
#
# NOTE(review): the unguarded `while True` scans rely on the caller having
# arranged sentinel elements (as _pdq_sort does via its pivot selection and the
# `leftmost` check) — confirm before calling from elsewhere.
def _partition_left(arr: Ptr[T], begin: int, end: int, T: type):
    pivot = arr[begin]
    first = begin
    last = end
    # Find the last element that is not greater than the pivot.
    while True:
        last -= 1
        if not _less(pivot, arr[last]):
            break
    if last + 1 == end:
        # Nothing was skipped on the right, so the forward scan needs a bound.
        while first < last:
            first += 1
            if _less(pivot, arr[first]):
                break
    else:
        while True:
            first += 1
            if _less(pivot, arr[first]):
                break
    # Swap misplaced pairs and keep scanning until the cursors cross.
    while first < last:
        arr[first], arr[last] = arr[last], arr[first]
        while True:
            last -= 1
            if not _less(pivot, arr[last]):
                break
        while True:
            first += 1
            if _less(pivot, arr[first]):
                break
    # Put the pivot into its final slot.
    pivot_pos = last
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot
    return pivot_pos
# Partition arr[begin:end] around pivot = arr[begin] so that elements less than
# the pivot end up on the left and elements not less than it (>= pivot) on the
# right.
#
# Returns (pivot_pos, already_partitioned) where already_partitioned is 1 when
# the initial scans crossed without requiring any swap, else 0.
#
# NOTE(review): the unguarded `while True` scans rely on sentinels arranged by
# the caller (see _pdq_sort) — confirm before calling from elsewhere.
def _partition_right(arr: Ptr[T], begin: int, end: int, T: type):
    pivot = arr[begin]
    first = begin
    last = end
    # Find the first element that is not less than the pivot.
    while True:
        first += 1
        if not _less(arr[first], pivot):
            break
    if first - 1 == begin:
        # Nothing was skipped on the left, so the backward scan needs a bound.
        while first < last:
            last -= 1
            if _less(arr[last], pivot):
                break
    else:
        while True:
            last -= 1
            if _less(arr[last], pivot):
                break
    already_partitioned = 0
    if first >= last:
        already_partitioned = 1
    # Swap misplaced pairs and keep scanning until the cursors cross.
    while first < last:
        arr[first], arr[last] = arr[last], arr[first]
        while True:
            first += 1
            if not _less(arr[first], pivot):
                break
        while True:
            last -= 1
            if _less(arr[last], pivot):
                break
    # Put the pivot into its final slot.
    pivot_pos = first - 1
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot
    return (pivot_pos, already_partitioned)
# Order the pair (arr[i], arr[j]): exchange them when arr[j] is less than
# arr[i], otherwise leave both untouched.
def _sort2(arr: Ptr[T], i: int, j: int, T: type):
    if not _less(arr[j], arr[i]):
        return
    smaller = arr[j]
    arr[j] = arr[i]
    arr[i] = smaller
# Order the three elements at positions i, j, k with three pairwise
# compare-exchanges, leaving the middle value at position j.
def _sort3(arr: Ptr[T], i: int, j: int, k: int, T: type):
    for lo, hi in ((i, j), (j, k), (i, j)):
        _sort2(arr, lo, hi)
# Quicksort driver for arr[begin:end] (see the "Quick (PDQ)" section banner).
#
# Parameters:
#   bad_allowed - budget of highly unbalanced partitions; when it reaches zero
#                 the current range is finished with heapsort.
#   leftmost    - True when nothing precedes `begin` in the overall range; when
#                 False, arr[begin - 1] is read as the predecessor element.
#
# The left part is handled by recursion, the right part by looping.
def _pdq_sort(
    arr: Ptr[T],
    begin: int,
    end: int,
    bad_allowed: int,
    leftmost: bool,
    T: type
):
    while True:
        size = end - begin
        if size < INSERTION_SORT_THRESHOLD:
            insertionsort(arr + begin, size)
            return
        # Pivot selection: the chosen pivot is moved to arr[begin].
        size_2 = size // 2
        if size > NINTHER_THRESHOLD:
            _sort3(arr, begin, begin + size_2, end - 1)
            _sort3(arr, begin + 1, begin + (size_2 - 1), end - 2)
            _sort3(arr, begin + 2, begin + (size_2 + 1), end - 3)
            _sort3(
                arr, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1)
            )
            arr[begin], arr[begin + size_2] = arr[begin + size_2], arr[begin]
        else:
            _sort3(arr, begin + size_2, begin, end - 1)
        # If the predecessor is not less than the pivot, the pivot equals the
        # smallest possible value for this range: split off everything that is
        # <= pivot and continue with the rest only.
        if not leftmost and not _less(arr[begin - 1], arr[begin]):
            begin = _partition_left(arr, begin, end) + 1
            continue
        part_result = _partition_right(arr, begin, end)
        pivot_pos = part_result[0]
        already_partitioned = part_result[1] == 1
        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8))
        if highly_unbalanced:
            bad_allowed -= 1
            if bad_allowed == 0:
                # Too many bad partitions: fall back to heapsort for this range.
                heapsort(arr + begin, end - begin)
                return
            # Swap a few elements at fixed offsets on each sufficiently large
            # side to break up the pattern that caused the bad split.
            if l_size >= INSERTION_SORT_THRESHOLD:
                arr[begin], arr[begin + l_size // 4] = (
                    arr[begin + l_size // 4],
                    arr[begin],
                )
                arr[pivot_pos - 1], arr[pivot_pos - l_size // 4] = (
                    arr[pivot_pos - l_size // 4],
                    arr[pivot_pos - 1],
                )
                if l_size > NINTHER_THRESHOLD:
                    arr[begin + 1], arr[begin + (l_size // 4 + 1)] = (
                        arr[begin + (l_size // 4 + 1)],
                        arr[begin + 1],
                    )
                    arr[begin + 2], arr[begin + (l_size // 4 + 2)] = (
                        arr[begin + (l_size // 4 + 2)],
                        arr[begin + 2],
                    )
                    arr[pivot_pos - 2], arr[pivot_pos - (l_size // 4 + 1)] = (
                        arr[pivot_pos - (l_size // 4 + 1)],
                        arr[pivot_pos - 2],
                    )
                    arr[pivot_pos - 3], arr[pivot_pos - (l_size // 4 + 2)] = (
                        arr[pivot_pos - (l_size // 4 + 2)],
                        arr[pivot_pos - 3],
                    )
            if r_size >= INSERTION_SORT_THRESHOLD:
                arr[pivot_pos + 1], arr[pivot_pos + (1 + r_size // 4)] = (
                    arr[pivot_pos + (1 + r_size // 4)],
                    arr[pivot_pos + 1],
                )
                arr[end - 1], arr[end - r_size // 4] = (
                    arr[end - r_size // 4],
                    arr[end - 1],
                )
                if r_size > NINTHER_THRESHOLD:
                    arr[pivot_pos + 2], arr[pivot_pos + (2 + r_size // 4)] = (
                        arr[pivot_pos + (2 + r_size // 4)],
                        arr[pivot_pos + 2],
                    )
                    arr[pivot_pos + 3], arr[pivot_pos + (3 + r_size // 4)] = (
                        arr[pivot_pos + (3 + r_size // 4)],
                        arr[pivot_pos + 3],
                    )
                    arr[end - 2], arr[end - (1 + r_size // 4)] = (
                        arr[end - (1 + r_size // 4)],
                        arr[end - 2],
                    )
                    arr[end - 3], arr[end - (2 + r_size // 4)] = (
                        arr[end - (2 + r_size // 4)],
                        arr[end - 3],
                    )
        else:
            # Balanced split that needed no swaps: the data may be nearly
            # sorted, so try to finish both sides with a bounded insertion sort.
            if (
                already_partitioned
                and _partial_insertion_sort(arr, begin, pivot_pos)
                and _partial_insertion_sort(arr, pivot_pos + 1, end)
            ):
                return
        # Recurse on the left part, iterate on the right part.
        _pdq_sort(arr, begin, pivot_pos, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False
# C stubs for vectorized quicksort from Highway
from C import cnp_sort_int16(cobj, int)
from C import cnp_sort_uint16(cobj, int)
from C import cnp_sort_int32(cobj, int)
from C import cnp_sort_uint32(cobj, int)
from C import cnp_sort_int64(cobj, int)
from C import cnp_sort_uint64(cobj, int)
from C import cnp_sort_int16(cobj, int)
from C import cnp_sort_uint128(cobj, int)
from C import cnp_sort_float32(cobj, int)
from C import cnp_sort_float64(cobj, int)
# Sort the `n` elements starting at `start` in place.
#
# For the listed integer and floating-point element types the work is handed
# to the matching external `cnp_sort_*` routine (declared via `from C import`
# above this function); both `int` and `i64` use the int64 routine. Every
# other element type is sorted with _pdq_sort, using floor(log2(n)) as the
# budget of bad partitions.
def quicksort(start: Ptr[T], n: int, T: type):
    if T is int:
        cnp_sort_int64(start.as_byte(), n)
    elif T is i16:
        cnp_sort_int16(start.as_byte(), n)
    elif T is u16:
        cnp_sort_uint16(start.as_byte(), n)
    elif T is i32:
        cnp_sort_int32(start.as_byte(), n)
    elif T is u32:
        cnp_sort_uint32(start.as_byte(), n)
    elif T is i64:
        cnp_sort_int64(start.as_byte(), n)
    elif T is u64:
        cnp_sort_uint64(start.as_byte(), n)
    elif T is float32:
        cnp_sort_float32(start.as_byte(), n)
    elif T is float:
        cnp_sort_float64(start.as_byte(), n)
    else:
        _pdq_sort(start, 0, n, _floor_log2(n), True)
# Indirect variant of _partial_insertion_sort: tries to finish sorting the
# indices tosort[begin:end] (keyed by arr[index]) with insertion sort.
#
# Returns True when the range ends up fully sorted, False when the running
# total of positions moved exceeded PARTIAL_INSERTION_SORT_LIMIT.
def _apartial_insertion_sort(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    if begin == end:
        return True
    limit = 0
    cur = begin + 1
    while cur != end:
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False
        sift = cur
        sift_1 = cur - 1
        if _less(arr[tosort[sift]], arr[tosort[sift_1]]):
            itmp = tosort[sift]
            tmp = arr[itmp]
            # Shift indices right until the slot for itmp is found; the
            # `sift == begin` test comes first so tosort[sift_1] is never read
            # below `begin`.
            while True:
                tosort[sift] = tosort[sift_1]
                sift -= 1
                sift_1 -= 1
                if sift == begin or not _less(tmp, arr[tosort[sift_1]]):
                    break
            tosort[sift] = itmp
            # Charge the distance this index travelled against the budget.
            limit += cur - sift
        cur += 1
    return True
# Indirect variant of _partition_left: partitions the indices
# tosort[begin:end] around the pivot index tosort[begin] so that indices whose
# values are <= the pivot value come first. Returns the pivot's final position.
#
# NOTE(review): the unguarded scans rely on sentinels arranged by the caller
# (see _apdq_sort) — confirm before calling from elsewhere.
def _apartition_left(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    ipivot = tosort[begin]
    pivot = arr[ipivot]
    first = begin
    last = end
    # Find the last index whose value is not greater than the pivot.
    while True:
        last -= 1
        if not _less(pivot, arr[tosort[last]]):
            break
    if last + 1 == end:
        # Nothing was skipped on the right, so the forward scan needs a bound.
        while first < last:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break
    else:
        while True:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break
    # Swap misplaced index pairs until the cursors cross.
    while first < last:
        tosort[first], tosort[last] = tosort[last], tosort[first]
        while True:
            last -= 1
            if not _less(pivot, arr[tosort[last]]):
                break
        while True:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break
    # Put the pivot index into its final slot.
    pivot_pos = last
    tosort[begin] = tosort[pivot_pos]
    tosort[pivot_pos] = ipivot
    return pivot_pos
# Indirect variant of _partition_right: partitions the indices
# tosort[begin:end] around the pivot index tosort[begin] so that indices whose
# values are less than the pivot value come first.
#
# Returns (pivot_pos, already_partitioned) where already_partitioned is 1 when
# the initial scans crossed without requiring any swap, else 0.
#
# NOTE(review): the unguarded scans rely on sentinels arranged by the caller
# (see _apdq_sort) — confirm before calling from elsewhere.
def _apartition_right(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    ipivot = tosort[begin]
    pivot = arr[ipivot]
    first = begin
    last = end
    # Find the first index whose value is not less than the pivot.
    while True:
        first += 1
        if not _less(arr[tosort[first]], pivot):
            break
    if first - 1 == begin:
        # Nothing was skipped on the left, so the backward scan needs a bound.
        while first < last:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break
    else:
        while True:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break
    already_partitioned = 0
    if first >= last:
        already_partitioned = 1
    # Swap misplaced index pairs until the cursors cross.
    while first < last:
        tosort[first], tosort[last] = tosort[last], tosort[first]
        while True:
            first += 1
            if not _less(arr[tosort[first]], pivot):
                break
        while True:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break
    # Put the pivot index into its final slot.
    pivot_pos = first - 1
    tosort[begin] = tosort[pivot_pos]
    tosort[pivot_pos] = ipivot
    return (pivot_pos, already_partitioned)
# Order the index pair (tosort[i], tosort[j]) by the values they refer to:
# exchange the two indices when arr[tosort[j]] is less than arr[tosort[i]].
def _asort2(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, T: type):
    if not _less(arr[tosort[j]], arr[tosort[i]]):
        return
    smaller = tosort[j]
    tosort[j] = tosort[i]
    tosort[i] = smaller
# Order the three indices at positions i, j, k of `tosort` by their values
# with three pairwise compare-exchanges; the median ends up at position j.
def _asort3(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, k: int, T: type):
    for lo, hi in ((i, j), (j, k), (i, j)):
        _asort2(arr, tosort, lo, hi)
# Indirect variant of _pdq_sort: sorts the indices tosort[begin:end] by the
# values arr[index]; `arr` is only read.
#
# Parameters:
#   bad_allowed - budget of highly unbalanced partitions; when it reaches zero
#                 the current range is finished with aheapsort.
#   leftmost    - True when nothing precedes `begin` in the overall range; when
#                 False, tosort[begin - 1] is read as the predecessor.
#
# The left part is handled by recursion, the right part by looping.
def _apdq_sort(
    arr: Ptr[T],
    tosort: Ptr[int],
    begin: int,
    end: int,
    bad_allowed: int,
    leftmost: bool,
    T: type
):
    while True:
        size = end - begin
        if size < INSERTION_SORT_THRESHOLD:
            ainsertionsort(arr, tosort + begin, size)
            return
        # Pivot selection: the chosen pivot index is moved to tosort[begin].
        size_2 = size // 2
        if size > NINTHER_THRESHOLD:
            _asort3(arr, tosort, begin, begin + size_2, end - 1)
            _asort3(arr, tosort, begin + 1, begin + (size_2 - 1), end - 2)
            _asort3(arr, tosort, begin + 2, begin + (size_2 + 1), end - 3)
            _asort3(
                arr, tosort, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1)
            )
            tosort[begin], tosort[begin + size_2] = tosort[begin + size_2], tosort[begin]
        else:
            _asort3(arr, tosort, begin + size_2, begin, end - 1)
        # Predecessor not less than the pivot: split off everything <= pivot
        # and continue with the remainder only.
        if not leftmost and not _less(arr[tosort[begin - 1]], arr[tosort[begin]]):
            begin = _apartition_left(arr, tosort, begin, end) + 1
            continue
        part_result = _apartition_right(arr, tosort, begin, end)
        pivot_pos = part_result[0]
        already_partitioned = part_result[1] == 1
        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8))
        if highly_unbalanced:
            bad_allowed -= 1
            if bad_allowed == 0:
                # Too many bad partitions: fall back to heapsort for this range.
                aheapsort(arr, tosort + begin, end - begin)
                return
            # Swap a few indices at fixed offsets on each sufficiently large
            # side to break up the pattern that caused the bad split.
            if l_size >= INSERTION_SORT_THRESHOLD:
                tosort[begin], tosort[begin + l_size // 4] = (
                    tosort[begin + l_size // 4],
                    tosort[begin],
                )
                tosort[pivot_pos - 1], tosort[pivot_pos - l_size // 4] = (
                    tosort[pivot_pos - l_size // 4],
                    tosort[pivot_pos - 1],
                )
                if l_size > NINTHER_THRESHOLD:
                    tosort[begin + 1], tosort[begin + (l_size // 4 + 1)] = (
                        tosort[begin + (l_size // 4 + 1)],
                        tosort[begin + 1],
                    )
                    tosort[begin + 2], tosort[begin + (l_size // 4 + 2)] = (
                        tosort[begin + (l_size // 4 + 2)],
                        tosort[begin + 2],
                    )
                    tosort[pivot_pos - 2], tosort[pivot_pos - (l_size // 4 + 1)] = (
                        tosort[pivot_pos - (l_size // 4 + 1)],
                        tosort[pivot_pos - 2],
                    )
                    tosort[pivot_pos - 3], tosort[pivot_pos - (l_size // 4 + 2)] = (
                        tosort[pivot_pos - (l_size // 4 + 2)],
                        tosort[pivot_pos - 3],
                    )
            if r_size >= INSERTION_SORT_THRESHOLD:
                tosort[pivot_pos + 1], tosort[pivot_pos + (1 + r_size // 4)] = (
                    tosort[pivot_pos + (1 + r_size // 4)],
                    tosort[pivot_pos + 1],
                )
                tosort[end - 1], tosort[end - r_size // 4] = (
                    tosort[end - r_size // 4],
                    tosort[end - 1],
                )
                if r_size > NINTHER_THRESHOLD:
                    tosort[pivot_pos + 2], tosort[pivot_pos + (2 + r_size // 4)] = (
                        tosort[pivot_pos + (2 + r_size // 4)],
                        tosort[pivot_pos + 2],
                    )
                    tosort[pivot_pos + 3], tosort[pivot_pos + (3 + r_size // 4)] = (
                        tosort[pivot_pos + (3 + r_size // 4)],
                        tosort[pivot_pos + 3],
                    )
                    tosort[end - 2], tosort[end - (1 + r_size // 4)] = (
                        tosort[end - (1 + r_size // 4)],
                        tosort[end - 2],
                    )
                    tosort[end - 3], tosort[end - (2 + r_size // 4)] = (
                        tosort[end - (2 + r_size // 4)],
                        tosort[end - 3],
                    )
        else:
            # Balanced split that needed no swaps: try to finish both sides
            # with a bounded insertion sort.
            if (
                already_partitioned
                and _apartial_insertion_sort(arr, tosort, begin, pivot_pos)
                and _apartial_insertion_sort(arr, tosort, pivot_pos + 1, end)
            ):
                return
        # Recurse on the left part, iterate on the right part.
        _apdq_sort(arr, tosort, begin, pivot_pos, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False
# Indirect quicksort entry point: sorts the `n` indices in `tosort` by the
# values in `start`, allowing floor(log2(n)) bad partitions before the
# heapsort fallback kicks in.
def aquicksort(start: Ptr[T], tosort: Ptr[int], n: int, T: type):
    bad_partition_budget = _floor_log2(n)
    _apdq_sort(start, tosort, 0, n, bad_partition_budget, True)
#########
# Radix #
#########
# Map a raw value `x` (already viewed as the unsigned type UT) to its radix
# sort key. When the element type T is UT itself the value is its own key;
# otherwise the most significant bit is flipped, which makes unsigned ordering
# of the keys agree with signed ordering of two's-complement values.
def key_of(x: UT, T: type, UT: type):
    if T is UT:
        return x
    else:
        top_bit = util.sizeof(UT) * 8 - 1
        sign_mask = util.cast(1, UT) << util.cast(top_bit, UT)
        return x ^ sign_mask
# Extract byte number `l` of `key` (byte 0 is the least significant) as an int
# in the range 0..255.
def nth_byte(key: T, l: int, T: type):
    shift = util.cast(l << 3, T)
    shifted = key >> shift
    return int(shifted) & 0xFF
# Core of the byte-wise radix sort over `num` keys.
#
# `start` holds the data viewed as unsigned UT, `aux` is a same-sized scratch
# buffer. The two buffers are swapped after every pass, and the function
# returns whichever buffer holds the sorted result at the end (the caller
# copies it back if it is not `start`). Expects num >= 1, since start[0] is
# read unconditionally.
def _radixsort0(start: Ptr[UT], aux: Ptr[UT], num: int, T: type, UT: type):
    m = util.sizeof(UT)
    n = (1 << 8)
    ncnt = m * n
    cnt = Ptr[int](ncnt) # note: compiler should put this on the stack
    str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int))
    key0 = key_of(start[0], T=T, UT=UT)
    # One histogram of 256 buckets per byte position, filled in a single scan.
    for i in range(num):
        k = key_of(start[i], T=T, UT=UT)
        for l in range(m):
            cnt[l*n + nth_byte(k, l)] += 1
    ncols = 0
    cols = Ptr[u8](m) # again, compiler should put on stack
    # Skip byte positions where every key has the same byte as the first key:
    # a pass over them would not change the order.
    for l in range(m):
        if cnt[l*n + nth_byte(key0, l)] != num:
            cols[ncols] = u8(l)
            ncols += 1
    # Turn each remaining histogram into exclusive prefix sums (bucket starts).
    for l in range(ncols):
        a = 0
        for i in range(256):
            b = cnt[int(cols[l])*n + i]
            cnt[int(cols[l])*n + i] = a
            a += b
    # Distribute pass per remaining byte position, least significant first.
    for l in range(ncols):
        for i in range(num):
            k = key_of(start[i], T=T, UT=UT)
            q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l])))
            dst = q[0]
            q[0] += 1
            aux[dst] = start[i]
        # Output of this pass becomes the input of the next one.
        temp = aux
        aux = start
        start = temp
    return start
# Radix sort of `num` elements at `start` (viewed as unsigned UT; T is the
# original element type, used to derive the keys).
#
# Returns immediately for fewer than two elements or when the keys are already
# in non-decreasing order; otherwise allocates a scratch buffer, sorts, copies
# the result back into `start` if needed, and frees the scratch buffer.
def _radixsort(start: Ptr[UT], num: int, T: type, UT: type):
    if num < 2:
        return
    # Cheap pre-check: stop at the first out-of-order adjacent pair.
    all_sorted = True
    k1 = key_of(start[0], T=T, UT=UT)
    for i in range(1, num):
        k2 = key_of(start[i], T=T, UT=UT)
        if k1 > k2:
            all_sorted = False
            break
        k1 = k2
    if all_sorted:
        return
    aux = Ptr[UT](num)
    sort = _radixsort0(start, aux, num, T=T)
    # The result may have ended up in the scratch buffer.
    if sort != start:
        str.memcpy(start.as_byte(), sort.as_byte(), num * util.sizeof(UT))
    util.free(aux)
# Radix sort entry point for integer-like element types.
#
# Reinterprets the data as the unsigned type of the same width and calls
# _radixsort: `int` uses u64, `Int[N]` uses `UInt[N]`, `UInt` types are used
# as-is, and `byte` is treated as i8 over u8.
#
# Raises ValueError for any other element type.
def radixsort(start: Ptr[T], num: int, T: type):
    # Helper that re-types the pointer as Ptr[UT] before dispatching.
    def wrap(start: Ptr, num: int, T: type, UT: type):
        _radixsort(Ptr[UT](start.as_byte()), num, T=T, UT=UT)
    if T is int:
        wrap(start, num, int, u64)
    elif isinstance(T, Int):
        wrap(start, num, T, UInt[T.N])
    elif isinstance(T, UInt):
        wrap(start, num, T, T)
    elif isinstance(T, byte):
        wrap(start, num, i8, u8)
    else:
        raise ValueError("cannot radixsort type '" + T.__name__ + "'")
# Core of the indirect byte-wise radix sort.
#
# `start` holds the values viewed as unsigned UT (read-only here), `tosort`
# holds `num` indices into it and `aux` is a same-sized scratch index buffer.
# The two index buffers are swapped after every pass; the function returns
# whichever one holds the sorted indices at the end.
#
# NOTE(review): the histograms are built from start[0..num-1] directly, which
# presumably assumes `tosort` is a permutation of 0..num-1 — confirm.
def _aradixsort0(start: Ptr[UT], aux: Ptr[int], tosort: Ptr[int], num: int, T: type, UT: type):
    m = util.sizeof(UT)
    n = (1 << 8)
    ncnt = m * n
    cnt = Ptr[int](ncnt) # note: compiler should put this on the stack
    str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int))
    key0 = key_of(start[0], T=T, UT=UT)
    # One histogram of 256 buckets per byte position, filled in a single scan.
    for i in range(num):
        k = key_of(start[i], T=T, UT=UT)
        for l in range(m):
            cnt[l*n + nth_byte(k, l)] += 1
    ncols = 0
    cols = Ptr[u8](m) # again, compiler should put on stack
    # Skip byte positions where every key shares the first key's byte.
    for l in range(m):
        if cnt[l*n + nth_byte(key0, l)] != num:
            cols[ncols] = u8(l)
            ncols += 1
    # Turn each remaining histogram into exclusive prefix sums (bucket starts).
    for l in range(ncols):
        a = 0
        for i in range(256):
            b = cnt[int(cols[l])*n + i]
            cnt[int(cols[l])*n + i] = a
            a += b
    # Distribute the indices per remaining byte position, least significant
    # first.
    for l in range(ncols):
        for i in range(num):
            k = key_of(start[tosort[i]], T=T, UT=UT)
            q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l])))
            dst = q[0]
            q[0] += 1
            aux[dst] = tosort[i]
        # Output of this pass becomes the input of the next one.
        temp = aux
        aux = tosort
        tosort = temp
    return tosort
# Indirect radix sort of the `num` indices in `tosort`, keyed by the values in
# `start` (viewed as unsigned UT; T is the original element type).
#
# Returns immediately for fewer than two indices or when the referenced keys
# are already in non-decreasing order; otherwise allocates a scratch index
# buffer, sorts, copies the result back into `tosort` if needed, and frees the
# scratch buffer.
def _aradixsort(start: Ptr[UT], tosort: Ptr[int], num: int, T: type, UT: type):
    if num < 2:
        return
    # Cheap pre-check: stop at the first out-of-order adjacent pair.
    all_sorted = True
    k1 = key_of(start[tosort[0]], T=T, UT=UT)
    for i in range(1, num):
        k2 = key_of(start[tosort[i]], T=T, UT=UT)
        if k1 > k2:
            all_sorted = False
            break
        k1 = k2
    if all_sorted:
        return
    aux = Ptr[int](num)
    sort = _aradixsort0(start, aux, tosort, num, T=T)
    # The result may have ended up in the scratch buffer.
    if sort != tosort:
        str.memcpy(tosort.as_byte(), sort.as_byte(), num * util.sizeof(int))
    util.free(aux)
# Indirect radix sort entry point for integer-like element types.
#
# Reinterprets the values as the unsigned type of the same width and calls
# _aradixsort: `int` uses u64, `Int[N]` uses `UInt[N]`, `UInt` types are used
# as-is, and `byte` is treated as i8 over u8.
#
# Raises ValueError for any other element type.
def aradixsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    # Helper that re-types the value pointer as Ptr[UT] before dispatching.
    def wrap(start: Ptr, tosort: Ptr[int], num: int, T: type, UT: type):
        _aradixsort(Ptr[UT](start.as_byte()), tosort, num, T=T, UT=UT)
    if T is int:
        wrap(start, tosort, num, int, u64)
    elif isinstance(T, Int):
        wrap(start, tosort, num, T, UInt[T.N])
    elif isinstance(T, UInt):
        wrap(start, tosort, num, T, T)
    elif isinstance(T, byte):
        wrap(start, tosort, num, i8, u8)
    else:
        raise ValueError("cannot radixsort type '" + T.__name__ + "'")
#######
# Tim #
#######
# Capacity of the run stack allocated by timsort / atimsort.
TIMSORT_STACK_SIZE: Static[int] = 128
# Compute the minimum run length for timsort: halve `num` until it is at most
# 64, and add 1 if any bit shifted out along the way was set.
def _compute_min_run(num: int):
    dropped_bit = 0
    remaining = num
    while remaining > 64:
        dropped_bit = dropped_bit | (remaining & 1)
        remaining = remaining >> 1
    return remaining + dropped_bit
# run = 2-tuple of (start, length)
# buffer = 2-tuple of (pw, size)
# Growable scratch buffer used by the timsort merge routines.
#
# Fields:
#   pw   - pointer to the storage (null until the first resize)
#   size - number of elements the storage currently holds
class Buffer[T]:
    pw: Ptr[T]
    size: int

    # Start out empty: null pointer, zero capacity.
    def __init__(self):
        self.pw = Ptr[T]()
        self.size = 0

    # Ensure the buffer holds at least `new_size` elements. Never shrinks; a
    # request that already fits is a no-op.
    def resize(self, new_size: int):
        buffer_pw = self.pw
        buffer_size = self.size
        if new_size <= buffer_size:
            return
        elif not buffer_pw:
            self.pw = Ptr[T](new_size)
        else:
            self.pw = util.realloc(buffer_pw, new_size, buffer_size)
        self.size = new_size

    # Release the storage and return to the empty state.
    def free(self):
        if self.pw:
            util.free(self.pw)
            self.pw = Ptr[T]()
        # Reset the recorded capacity together with the pointer; otherwise a
        # later resize() to a size within the old capacity would return early
        # and leave `pw` null.
        self.size = 0
# Identify (and if needed create) the next sorted run starting at index `l` of
# arr[0:num] and return its length.
#
# A non-descending run is taken as is; a strictly descending run is reversed in
# place. If the natural run is shorter than `minrun`, it is extended to
# `minrun` elements (or to the end of the array) with insertion sort.
def _count_run(arr: Ptr[T], l: int, num: int, minrun: int, T: type):
    if num - l == 1:
        return 1
    pl = arr + l
    if not _less(pl[1], pl[0]):
        # Non-descending: advance while each next element is not smaller.
        pi = pl + 1
        while pi < arr + (num - 1) and not _less(pi[1], pi[0]):
            pi += 1
    else:
        # Strictly descending: advance while each next element is smaller,
        # then reverse [pl, pi] in place.
        pi = pl + 1
        while pi < arr + (num - 1) and _less(pi[1], pi[0]):
            pi += 1
        pj = pl
        pr = pi
        while pj < pr:
            pj[0], pr[0] = pr[0], pj[0]
            pj += 1
            pr -= 1
    # pi pointed at the last element of the run; make it one-past-the-end.
    pi += 1
    sz = pi - pl
    if sz < minrun:
        if l + minrun < num:
            sz = minrun
        else:
            sz = num - l
        pr = pl + sz
        # Insert the extra elements into the already sorted prefix.
        while pi < pr:
            vc = pi[0]
            pj = pi
            while pl < pj and _less(vc, pj[-1]):
                pj[0] = pj[-1]
                pj -= 1
            pj[0] = vc
            pi += 1
    return sz
# Merge two adjacent sorted runs, [p1, p1 + l1) followed by [p2, p2 + l2),
# front to back, using `p3` as scratch for a copy of the first run (it must
# hold at least l1 elements).
#
# The first output element is taken from the second run unconditionally, so
# the caller must guarantee p2[0] is less than p1[0] (as _merge_at does via
# its gallop step).
def _merge_left(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type):
    end = p2 + l2
    str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(T))
    p1[0] = p2[0]
    p1 += 1
    p2 += 1
    # p1 is the write cursor; it catches up with p2 exactly when the scratch
    # copy of the first run has been used up.
    while p1 < p2 and p2 < end:
        if _less(p2[0], p3[0]):
            p1[0] = p2[0]
            p1 += 1
            p2 += 1
        else:
            p1[0] = p3[0]
            p1 += 1
            p3 += 1
    # Second run exhausted first: copy the rest of the scratch into the gap.
    if p1 != p2:
        str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(T))
# Merge two adjacent sorted runs, [p1, p1 + l1) followed by [p2, p2 + l2),
# back to front, using `p3` as scratch for a copy of the second run (it must
# hold at least l2 elements).
#
# The last output element is taken from the first run unconditionally, so the
# caller must guarantee the last element of the first run is greater than the
# last element of the second (as _merge_at does via its gallop step).
def _merge_right(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type):
    start = p1 - 1
    str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(T))
    # Move all three cursors to the last element of their ranges.
    p1 += l1 - 1
    p2 += l2 - 1
    p3 += l2 - 1
    p2[0] = p1[0]
    p2 -= 1
    p1 -= 1
    # p2 is the write cursor; it meets p1 exactly when the scratch copy of the
    # second run has been used up.
    while p1 < p2 and start < p1:
        if _less(p3[0], p1[0]):
            p2[0] = p1[0]
            p2 -= 1
            p1 -= 1
        else:
            p2[0] = p3[0]
            p2 -= 1
            p3 -= 1
    # First run exhausted first: copy the rest of the scratch to the front.
    if p1 != p2:
        ofs = p2 - start
        str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(T))
# For a sorted arr[0:size], return the index of the first element that `key`
# is less than (or `size` if there is none).
#
# Probes at exponentially growing offsets from the left, then binary-searches
# the bracket that was found.
def _gallop_right(arr: Ptr[T], size: int, key: T, T: type):
    if _less(key, arr[0]):
        return 0
    last_ofs = 0
    ofs = 1
    while True:
        # `ofs < 0` catches overflow of the doubling step.
        if size <= ofs or ofs < 0:
            ofs = size
            break
        if _less(key, arr[ofs]):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1
    # Invariant: key is not less than arr[last_ofs], and ofs is either `size`
    # or an index with key < arr[ofs].
    while last_ofs + 1 < ofs:
        m = last_ofs + ((ofs - last_ofs) >> 1)
        if _less(key, arr[m]):
            ofs = m
        else:
            last_ofs = m
    return ofs
# For a sorted arr[0:size], return the index of the first element that is not
# less than `key` (or `size` if every element is less than `key`).
#
# Probes at exponentially growing offsets from the right end, then
# binary-searches the bracket that was found.
def _gallop_left(arr: Ptr[T], size: int, key: T, T: type):
    if _less(arr[size - 1], key):
        return size
    last_ofs = 0
    ofs = 1
    while True:
        # `ofs < 0` catches overflow of the doubling step.
        if size <= ofs or ofs < 0:
            ofs = size
            break
        if _less(arr[size - ofs - 1], key):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1
    # Convert the offsets from the right end into a (l, r] index bracket:
    # l is -1 or an index with arr[l] < key; arr[r] is not less than key.
    l = size - ofs - 1
    r = size - last_ofs - 1
    while l + 1 < r:
        m = l + ((r - l) >> 1)
        if _less(arr[m], key):
            l = m
        else:
            r = m
    return r
# Merge the two adjacent runs stored at stack[at] and stack[at + 1] (each a
# (start, length) tuple) in place within `arr`. The stack entries themselves
# are not updated here; callers adjust them afterwards.
def _merge_at(arr: Ptr[T],
              stack: Ptr[Tuple[int,int]],
              at: int,
              buffer: Buffer[T],
              T: type):
    s1, l1 = stack[at]
    s2, l2 = stack[at + 1]
    # Leading elements of run 1 that are not greater than run 2's first
    # element are already in their final place.
    k = _gallop_right(arr + s1, l1, arr[s2])
    if l1 == k:
        # Run 1 is entirely in place, so the runs are already in order.
        return
    p1 = arr + (s1 + k)
    l1 -= k
    p2 = arr + s2
    # Trailing elements of run 2 that are not less than run 1's last element
    # are already in their final place.
    l2 = _gallop_left(arr + s2, l2, arr[s2 - 1])
    # Copy the shorter remaining run into the scratch buffer.
    if l2 < l1:
        buffer.resize(l2)
        _merge_right(p1, l1, p2, l2, buffer.pw)
    else:
        buffer.resize(l1)
        _merge_left(p1, l1, p2, l2, buffer.pw)
# Merge runs from the top of the run stack while the length conditions below
# call for it. With A, B, C the lengths of the three topmost runs (C on top),
# merging continues while A <= B + C (or the same relation holds one level
# deeper), or while B <= C.
#
# `stack_ptr` points at the number of runs on the stack; it is updated on exit.
def _try_collapse(arr: Ptr[T],
                  stack: Ptr[Tuple[int,int]],
                  stack_ptr: Ptr[int],
                  buffer: Buffer[T],
                  T: type):
    top = stack_ptr[0]
    while 1 < top:
        B = stack[top - 2][1]
        C = stack[top - 1][1]
        if ((2 < top and stack[top - 3][1] <= B + C) or
            (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)):
            A = stack[top - 3][1]
            if A <= C:
                # Merge A with B, then slide C down one slot.
                _merge_at(arr, stack, top - 3, buffer)
                s, l = stack[top - 3]
                stack[top - 3] = (s, l + B)
                stack[top - 2] = stack[top - 1]
                top -= 1
            else:
                # Merge B with C.
                _merge_at(arr, stack, top - 2, buffer)
                s, l = stack[top - 2]
                stack[top - 2] = (s, l + C)
                top -= 1
        elif 1 < top and B <= C:
            _merge_at(arr, stack, top - 2, buffer)
            s, l = stack[top - 2]
            stack[top - 2] = (s, l + C)
            top -= 1
        else:
            break
    stack_ptr[0] = top
# Merge every run remaining on the stack so that the runs end up merged into
# one. While more than two runs remain, the middle run is merged with the
# shorter of its neighbours. The final two-run merge does not update the stack
# entry, and `stack_ptr[0]` is read but not written back.
def _force_collapse(arr: Ptr[T],
                    stack: Ptr[Tuple[int,int]],
                    stack_ptr: Ptr[int],
                    buffer: Buffer[T],
                    T: type):
    top = stack_ptr[0]
    while 2 < top:
        if stack[top - 3][1] <= stack[top - 1][1]:
            # Third-from-top is the shorter neighbour: merge it with the
            # middle run and slide the top run down.
            _merge_at(arr, stack, top - 3, buffer)
            _, l1 = stack[top - 2]
            s, l2 = stack[top - 3]
            stack[top - 3] = (s, l1 + l2)
            stack[top - 2] = stack[top - 1]
            top -= 1
        else:
            # Merge the two topmost runs.
            _merge_at(arr, stack, top - 2, buffer)
            _, l1 = stack[top - 1]
            s, l2 = stack[top - 2]
            stack[top - 2] = (s, l1 + l2)
            top -= 1
    if 1 < top:
        _merge_at(arr, stack, top - 2, buffer)
# Sort the `num` elements starting at `start` in place by finding sorted runs
# and merging them.
#
# Runs are found left to right with _count_run, pushed on a fixed-size stack
# of (start, length) tuples, and merged as dictated by _try_collapse; whatever
# remains is merged by _force_collapse. The merge scratch buffer is released
# before returning.
def timsort(start: Ptr[T], num: int, T: type):
    stack = __array__[Tuple[int,int]](TIMSORT_STACK_SIZE)
    buffer = Buffer[T]()
    stack_ptr = 0
    minrun = _compute_min_run(num)
    l = 0
    while l < num:
        n = _count_run(start, l, num, minrun)
        stack[stack_ptr] = (l, n)
        stack_ptr += 1
        # stack_ptr is passed by address so the helper can pop merged runs.
        _try_collapse(start, stack.ptr, __ptr__(stack_ptr), buffer)
        l += n
    _force_collapse(start, stack.ptr, __ptr__(stack_ptr), buffer)
    buffer.free()
# Indirect variant of _count_run: identifies (and if needed creates) the next
# sorted run of indices starting at position `l` of tosort[0:num], keyed by
# arr[index], and returns its length.
#
# A non-descending run is taken as is; a strictly descending run is reversed in
# place. A run shorter than `minrun` is extended to `minrun` indices (or to the
# end) with insertion sort.
def _acount_run(arr: Ptr[T], tosort: Ptr[int], l: int, num: int, minrun: int, T: type):
    if num - l == 1:
        return 1
    pl = tosort + l
    if not _less(arr[pl[1]], arr[pl[0]]):
        # Non-descending: advance while each next value is not smaller.
        pi = pl + 1
        while pi < tosort + (num - 1) and not _less(arr[pi[1]], arr[pi[0]]):
            pi += 1
    else:
        # Strictly descending: advance, then reverse the indices in [pl, pi].
        pi = pl + 1
        while pi < tosort + (num - 1) and _less(arr[pi[1]], arr[pi[0]]):
            pi += 1
        pj = pl
        pr = pi
        while pj < pr:
            pj[0], pr[0] = pr[0], pj[0]
            pj += 1
            pr -= 1
    # pi pointed at the last index of the run; make it one-past-the-end.
    pi += 1
    sz = pi - pl
    if sz < minrun:
        if l + minrun < num:
            sz = minrun
        else:
            sz = num - l
        pr = pl + sz
        # Insert the extra indices into the already sorted prefix.
        while pi < pr:
            vi = pi[0]
            vc = arr[vi]
            pj = pi
            while pl < pj and _less(vc, arr[pj[-1]]):
                pj[0] = pj[-1]
                pj -= 1
            pj[0] = vi
            pi += 1
    return sz
# Indirect variant of _merge_left: merges two adjacent sorted index runs,
# [p1, p1 + l1) followed by [p2, p2 + l2), front to back, comparing the values
# arr[index]. `p3` is scratch for a copy of the first run (at least l1 ints).
#
# The first output index is taken from the second run unconditionally, so the
# caller must guarantee arr[p2[0]] is less than arr[p1[0]].
def _amerge_left(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type):
    end = p2 + l2
    str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(int))
    p1[0] = p2[0]
    p1 += 1
    p2 += 1
    # p1 is the write cursor; it catches up with p2 exactly when the scratch
    # copy of the first run has been used up.
    while p1 < p2 and p2 < end:
        if _less(arr[p2[0]], arr[p3[0]]):
            p1[0] = p2[0]
            p1 += 1
            p2 += 1
        else:
            p1[0] = p3[0]
            p1 += 1
            p3 += 1
    # Second run exhausted first: copy the rest of the scratch into the gap.
    if p1 != p2:
        str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(int))
# Indirect variant of _merge_right: merges two adjacent sorted index runs,
# [p1, p1 + l1) followed by [p2, p2 + l2), back to front, comparing the values
# arr[index]. `p3` is scratch for a copy of the second run (at least l2 ints).
#
# The last output index is taken from the first run unconditionally, so the
# caller must guarantee the first run's last value is greater than the second
# run's last value.
def _amerge_right(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type):
    start = p1 - 1
    str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(int))
    # Move all three cursors to the last element of their ranges.
    p1 += l1 - 1
    p2 += l2 - 1
    p3 += l2 - 1
    p2[0] = p1[0]
    p2 -= 1
    p1 -= 1
    # p2 is the write cursor; it meets p1 exactly when the scratch copy of the
    # second run has been used up.
    while p1 < p2 and start < p1:
        if _less(arr[p3[0]], arr[p1[0]]):
            p2[0] = p1[0]
            p2 -= 1
            p1 -= 1
        else:
            p2[0] = p3[0]
            p2 -= 1
            p3 -= 1
    # First run exhausted first: copy the rest of the scratch to the front.
    if p1 != p2:
        ofs = p2 - start
        str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(int))
# Indirect variant of _gallop_right: for indices tosort[0:size] whose values
# arr[index] are sorted, return the position of the first index whose value
# `key` is less than (or `size` if there is none).
def _agallop_right(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type):
    if _less(key, arr[tosort[0]]):
        return 0
    last_ofs = 0
    ofs = 1
    # Probe at exponentially growing offsets from the left.
    while True:
        # `ofs < 0` catches overflow of the doubling step.
        if size <= ofs or ofs < 0:
            ofs = size
            break
        if _less(key, arr[tosort[ofs]]):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1
    # Binary search inside the bracket (last_ofs, ofs].
    while last_ofs + 1 < ofs:
        m = last_ofs + ((ofs - last_ofs) >> 1)
        if _less(key, arr[tosort[m]]):
            ofs = m
        else:
            last_ofs = m
    return ofs
# Indirect variant of _gallop_left: for indices tosort[0:size] whose values
# arr[index] are sorted, return the position of the first index whose value is
# not less than `key` (or `size` if every value is less than `key`).
def _agallop_left(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type):
    if _less(arr[tosort[size - 1]], key):
        return size
    last_ofs = 0
    ofs = 1
    # Probe at exponentially growing offsets from the right end.
    while True:
        # `ofs < 0` catches overflow of the doubling step.
        if size <= ofs or ofs < 0:
            ofs = size
            break
        if _less(arr[tosort[size - ofs - 1]], key):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1
    # Convert the offsets from the right end into a (l, r] position bracket
    # and binary-search it.
    l = size - ofs - 1
    r = size - last_ofs - 1
    while l + 1 < r:
        m = l + ((r - l) >> 1)
        if _less(arr[tosort[m]], key):
            l = m
        else:
            r = m
    return r
# Indirect variant of _merge_at: merges the two adjacent index runs stored at
# stack[at] and stack[at + 1] (each a (start, length) tuple) in place within
# `tosort`, comparing the values arr[index]. The stack entries are not updated
# here.
def _amerge_at(arr: Ptr[T],
               tosort: Ptr[int],
               stack: Ptr[Tuple[int,int]],
               at: int,
               buffer: Buffer[int],
               T: type):
    s1, l1 = stack[at]
    s2, l2 = stack[at + 1]
    # Leading indices of run 1 whose values are not greater than run 2's first
    # value are already in their final place.
    k = _agallop_right(arr, tosort + s1, l1, arr[tosort[s2]])
    if l1 == k:
        # Run 1 is entirely in place, so the runs are already in order.
        return
    p1 = tosort + (s1 + k)
    l1 -= k
    p2 = tosort + s2
    # Trailing indices of run 2 whose values are not less than run 1's last
    # value are already in their final place.
    l2 = _agallop_left(arr, tosort + s2, l2, arr[tosort[s2 - 1]])
    # Copy the shorter remaining run into the scratch buffer.
    if l2 < l1:
        buffer.resize(l2)
        _amerge_right(arr, p1, l1, p2, l2, buffer.pw)
    else:
        buffer.resize(l1)
        _amerge_left(arr, p1, l1, p2, l2, buffer.pw)
# Indirect variant of _try_collapse: merges index runs from the top of the run
# stack while the length conditions call for it. With A, B, C the lengths of
# the three topmost runs (C on top), merging continues while A <= B + C (or
# the same relation holds one level deeper), or while B <= C.
#
# `stack_ptr` points at the number of runs on the stack; it is updated on exit.
def _atry_collapse(arr: Ptr[T],
                   tosort: Ptr[int],
                   stack: Ptr[Tuple[int,int]],
                   stack_ptr: Ptr[int],
                   buffer: Buffer[int],
                   T: type):
    top = stack_ptr[0]
    while 1 < top:
        B = stack[top - 2][1]
        C = stack[top - 1][1]
        if ((2 < top and stack[top - 3][1] <= B + C) or
            (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)):
            A = stack[top - 3][1]
            if A <= C:
                # Merge A with B, then slide C down one slot.
                _amerge_at(arr, tosort, stack, top - 3, buffer)
                s, l = stack[top - 3]
                stack[top - 3] = (s, l + B)
                stack[top - 2] = stack[top - 1]
                top -= 1
            else:
                # Merge B with C.
                _amerge_at(arr, tosort, stack, top - 2, buffer)
                s, l = stack[top - 2]
                stack[top - 2] = (s, l + C)
                top -= 1
        elif 1 < top and B <= C:
            _amerge_at(arr, tosort, stack, top - 2, buffer)
            s, l = stack[top - 2]
            stack[top - 2] = (s, l + C)
            top -= 1
        else:
            break
    stack_ptr[0] = top
# Indirect variant of _force_collapse: merges every index run remaining on the
# stack so that the runs end up merged into one. While more than two runs
# remain, the middle run is merged with the shorter of its neighbours. The
# final two-run merge does not update the stack entry, and `stack_ptr[0]` is
# read but not written back.
def _aforce_collapse(arr: Ptr[T],
                     tosort: Ptr[int],
                     stack: Ptr[Tuple[int,int]],
                     stack_ptr: Ptr[int],
                     buffer: Buffer[int],
                     T: type):
    top = stack_ptr[0]
    while 2 < top:
        if stack[top - 3][1] <= stack[top - 1][1]:
            # Third-from-top is the shorter neighbour: merge it with the
            # middle run and slide the top run down.
            _amerge_at(arr, tosort, stack, top - 3, buffer)
            _, l1 = stack[top - 2]
            s, l2 = stack[top - 3]
            stack[top - 3] = (s, l1 + l2)
            stack[top - 2] = stack[top - 1]
            top -= 1
        else:
            # Merge the two topmost runs.
            _amerge_at(arr, tosort, stack, top - 2, buffer)
            _, l1 = stack[top - 1]
            s, l2 = stack[top - 2]
            stack[top - 2] = (s, l1 + l2)
            top -= 1
    if 1 < top:
        _amerge_at(arr, tosort, stack, top - 2, buffer)
# Indirect run-merging sort: reorders the `num` indices in `tosort` so that
# the values start[index] are sorted; `start` is only read.
#
# Runs are found left to right with _acount_run, pushed on a fixed-size stack
# of (start, length) tuples, and merged as dictated by _atry_collapse;
# whatever remains is merged by _aforce_collapse. The index scratch buffer is
# released before returning.
def atimsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    stack = __array__[Tuple[int,int]](TIMSORT_STACK_SIZE)
    buffer = Buffer[int]()
    stack_ptr = 0
    minrun = _compute_min_run(num)
    l = 0
    while l < num:
        n = _acount_run(start, tosort, l, num, minrun)
        stack[stack_ptr] = (l, n)
        stack_ptr += 1
        # stack_ptr is passed by address so the helper can pop merged runs.
        _atry_collapse(start, tosort, stack.ptr, __ptr__(stack_ptr), buffer)
        l += n
    _aforce_collapse(start, tosort, stack.ptr, __ptr__(stack_ptr), buffer)
    buffer.free()
##########
# Stable #
##########
# Entry point for the "Stable" sort kind: integer-like element types (`int`,
# `byte`, `Int[N]`, `UInt[N]`) go to radixsort, everything else to timsort.
def stablesort(start: Ptr[T], num: int, T: type):
    if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt):
        return radixsort(start, num)
    else:
        return timsort(start, num)
# Indirect counterpart of stablesort: integer-like element types (`int`,
# `byte`, `Int[N]`, `UInt[N]`) go to aradixsort, everything else to atimsort.
def astablesort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt):
        return aradixsort(start, tosort, num)
    else:
        return atimsort(start, tosort, num)
#############
# Selection #
#############
# Maximum number of pivots that _store_pivot keeps in the caller's pivot array.
MAX_PIVOT_STACK: Static[int] = 50
# Accessor used by the selection routines when elements are moved directly:
# calling it reads element `i`, and `swap` exchanges two elements in place.
@tuple
class Sortee:
    v: Ptr[T]
    T: type

    # Read the element at position `i`.
    def __call__(self, i: int):
        return self.v[i]

    # Exchange the elements at positions `i` and `j`.
    def swap(self, i: int, j: int):
        data = self.v
        data[i], data[j] = data[j], data[i]
# Identity index mapping, used when sorting values directly: position `i`
# refers to element `i`.
@tuple
class Idx:
    def __call__(self, i: int):
        return i
# Accessor used by the selection routines in indirect (arg) mode: calling it
# reads index slot `i` of `tosort`, and `swap` exchanges two index slots,
# leaving the underlying values untouched.
@tuple
class ArgSortee:
    tosort: Ptr[int]

    # Read the index stored at slot `i`.
    def __call__(self, i: int):
        return self.tosort[i]

    # Exchange the indices stored at slots `i` and `j`.
    def swap(self, i: int, j: int):
        order = self.tosort
        order[i], order[j] = order[j], order[i]
# Index mapping for indirect (arg) mode: position `i` refers to the element
# whose index is stored in tosort[i].
@tuple
class ArgIdx:
    tosort: Ptr[int]
    def __call__(self, i: int):
        return self.tosort[i]
# Record `pivot` in the caller-provided pivot array, if there is one.
#
# `npiv[0]` is the number of pivots stored so far. A pivot at or beyond `kth`
# is appended while there is room (fewer than MAX_PIVOT_STACK entries). When
# the array is full, only a pivot equal to `kth` is stored, overwriting the
# last entry. Does nothing when `pivots` is null.
def _store_pivot(pivot: int, kth: int, pivots: Ptr[int], npiv: Ptr[int]):
    if pivots:
        if pivot == kth and npiv[0] == MAX_PIVOT_STACK:
            # Full: make sure kth itself is remembered.
            pivots[npiv[0] - 1] = pivot
        elif pivot >= kth and npiv[0] < MAX_PIVOT_STACK:
            pivots[npiv[0]] = pivot
            npiv[0] = npiv[0] + 1
# Median-of-three pivot preparation on positions low, mid, high.
#
# After the three compare-exchanges the median of the three values sits at
# `low`, the largest at `high` and the smallest at `mid`; the smallest is then
# moved to `low + 1`. With `arg` set, indices in `tosort` are moved instead of
# the values in `v`.
def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, arg: Static[int], T: type):
    # `arg` is a static flag, so the accessor types are fixed per instantiation.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)
    if _less(v[idx(high)], v[idx(mid)]):
        sortee.swap(high, mid)
    if _less(v[idx(high)], v[idx(low)]):
        sortee.swap(high, low)
    # Move the median (the pivot) to `low`.
    if _less(v[idx(low)], v[idx(mid)]):
        sortee.swap(low, mid)
    # Move the smallest of the three next to the pivot.
    sortee.swap(mid, low + 1)
# Partially order the five elements at positions 0..4 with a fixed sequence of
# compare-exchanges and return the position (1, 2 or 3) that the final
# comparisons select as the median. With `arg` set, indices in `tosort` are
# moved instead of the values in `v`.
def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[int], T: type):
    # `arg` is a static flag, so the accessor types are fixed per instantiation.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)
    if _less(v[idx(1)], v[idx(0)]):
        sortee.swap(1, 0)
    if _less(v[idx(4)], v[idx(3)]):
        sortee.swap(4, 3)
    if _less(v[idx(3)], v[idx(0)]):
        sortee.swap(3, 0)
    if _less(v[idx(4)], v[idx(1)]):
        sortee.swap(4, 1)
    if _less(v[idx(2)], v[idx(1)]):
        sortee.swap(2, 1)
    # Choose among positions 1, 2 and 3.
    if _less(v[idx(3)], v[idx(2)]):
        if _less(v[idx(3)], v[idx(1)]):
            return 1
        else:
            return 3
    else:
        return 2
# Partition around the value `pivot`, scanning inward from `ll` and `hh`.
#
# Both cursors are advanced before being tested, and neither scan has a bounds
# check, so the caller must have placed sentinel elements that stop them.
# Returns the final (ll, hh) pair, with hh < ll.
def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int, arg: Static[int], T: type):
    # `arg` is a static flag, so the accessor types are fixed per instantiation.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)
    while True:
        # Advance ll to the next element that is not less than the pivot.
        while True:
            ll += 1
            if not _less(v[idx(ll)], pivot):
                break
        # Retreat hh to the next element that the pivot is not less than.
        while True:
            hh -= 1
            if not _less(pivot, v[idx(hh)]):
                break
        if hh < ll:
            break
        sortee.swap(ll, hh)
    return ll, hh
# Median-of-medians pivot selection over the first `num` positions.
#
# The range is cut into num // 5 groups of five; each group's median (as
# chosen by _median5) is swapped to the front, so the medians occupy positions
# 0 .. nmed-1. If there are more than two medians, `isel` (the selection
# routine passed in by the caller) is run on them to place their median at
# position nmed // 2. Returns nmed // 2.
#
# With `arg` set, the groups are taken from `tosort` and indices are moved
# instead of the values in `v`.
def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], isel, T: type):
    # Only the swapping accessor is needed here; element reads happen inside
    # _median5 and `isel`.
    if arg:
        sortee = ArgSortee(tosort)
    else:
        sortee = Sortee(v)
    nmed = num // 5
    subleft = 0
    for i in range(nmed):
        # Offset whichever array is actually being permuted to the group start.
        m = _median5(v + (0 if arg else subleft), tosort + (subleft if arg else 0), arg)
        sortee.swap(subleft + m, i)
        subleft += 5
    if nmed > 2:
        isel(v, tosort, nmed, nmed // 2, pivots, npiv, arg)
    return nmed // 2
# Selection by repeated minimum search: for each position 0..kth, find the
# smallest remaining element among positions [position, num) and swap it into
# place. Afterwards positions 0..kth hold the kth+1 smallest elements in order.
# With `arg` set, indices in `tosort` are moved instead of the values in `v`.
def _dumb_select(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, arg: Static[int], T: type):
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)
    for slot in range(kth + 1):
        best = slot
        best_val = v[idx(slot)]
        for cand in range(slot + 1, num):
            cand_val = v[idx(cand)]
            if _less(cand_val, best_val):
                best = cand
                best_val = cand_val
        sortee.swap(slot, best)
def _msb(unum: u64):
    """
    Return the zero-based position of the most significant set bit of
    `unum` (i.e. floor(log2(unum))); returns 0 for inputs 0 and 1.
    """
    one = u64(1)
    position = 0
    rest = unum >> one
    while rest:
        position += 1
        rest >>= one
    return position
def _introselect(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], T: type):
    # Selection routine used by _partition / _argpartition: rearranges the
    # `num` elements (directly in `v`, or indirectly through `tosort` when
    # `arg` is set) around position `kth` by repeated partitioning.
    #
    # `pivots` / `npiv` form an optional stack of pivot positions left over
    # from earlier calls on the same data; it is consulted to narrow the
    # search range and updated through _store_pivot (defined elsewhere in
    # this file). Passing a null `npiv` disables the pivot stack.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    low = 0
    high = num - 1

    # Without a counter there is no usable pivot stack.
    if not npiv:
        pivots = Ptr[int]()

    # Pop stored pivots to shrink [low, high]: a pivot above kth bounds the
    # range from the right, a pivot equal to kth means there is nothing to
    # do, and a pivot below kth bounds the range from the left.
    while pivots and npiv[0] > 0:
        if pivots[npiv[0] - 1] > kth:
            high = pivots[npiv[0] - 1] - 1
            break
        elif pivots[npiv[0] - 1] == kth:
            return

        low = pivots[npiv[0] - 1] + 1
        npiv[0] -= 1

    if kth - low < 3:
        # kth is within three positions of the range start: plain repeated
        # minimum search is used instead of partitioning.
        _dumb_select(v + (0 if arg else low), tosort + (low if arg else 0), high - low + 1, kth - low, arg)
        _store_pivot(kth, kth, pivots, npiv)
        return
    elif (T is float or T is float32) and kth == num - 1:
        # Selecting the last position of a float array: a single scan for
        # the maximum suffices. `not _less(...)` is used so that, with
        # _less ordering NaN after every number, a NaN (if present) is the
        # element moved to kth.
        maxidx = low
        maxval = v[idx(low)]
        for k in range(low + 1, num):
            if not _less(v[idx(k)], maxval):
                maxidx = k
                maxval = v[idx(k)]
        sortee.swap(kth, maxidx)
        return

    # Budget of median-of-3 rounds: twice the bit length position of num.
    depth_limit = _msb(u64(num)) * 2

    while low + 1 < high:
        ll = low + 1
        hh = high

        if depth_limit > 0 or hh - ll < 5:
            # Median-of-3 pivot; _median3_swap arranges low/mid/high so the
            # partition scan below can start at low + 1 and high.
            mid = low + (high - low) // 2
            _median3_swap(v, tosort, low, mid, high, arg)
        else:
            # Budget exhausted and at least five elements remain: fall back
            # to the median-of-medians pivot, moved to `low`.
            mid = ll + _median_of_median5(v + (0 if arg else ll), tosort + (ll if arg else 0), hh - ll, Ptr[int](), Ptr[int](), arg, _introselect)
            sortee.swap(mid, low)
            # Widen the scan by one on each side because, unlike the
            # median-of-3 path, the ends have not been pre-arranged.
            ll -= 1
            hh += 1

        depth_limit -= 1

        # Partition around the pivot stored at `low`, then move the pivot
        # to its final position `hh`.
        ll, hh = _unguarded_partition(v, tosort, v[idx(low)], ll, hh, arg)
        sortee.swap(low, hh)

        # Remember the pivot position for later calls (kth itself is stored
        # once, at the end).
        if hh != kth:
            _store_pivot(hh, kth, pivots, npiv)

        # Continue in the side that contains kth; when hh == kth both
        # bounds move and the loop ends.
        if hh >= kth:
            high = hh - 1
        if hh <= kth:
            low = ll

    # Two elements left: order them directly.
    if high == low + 1:
        if _less(v[idx(high)], v[idx(low)]):
            sortee.swap(high, low)

    _store_pivot(kth, kth, pivots, npiv)
def _partition(v: Ptr[T], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], T: type):
    """
    Direct (value-moving) selection: runs `_introselect` on the `num`
    elements at `v` for position `kth`, with a null index array.
    """
    no_indices = Ptr[int]()
    _introselect(v, no_indices, num, kth, pivots, npiv, arg=False)
def _argpartition(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], T: type):
    """
    Indirect selection: runs `_introselect` in index mode, permuting the
    `num` entries of `tosort` for position `kth` using the values at `v`
    as keys.
    """
    _introselect(
        v,
        tosort,
        num,
        kth,
        pivots,
        npiv,
        arg=True,
    )
def _partition_compact(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    """
    Partition every 1-D lane of `a` along `axis` in place, addressing the
    lane elements as consecutive items from the lane's first element.

    `kth` points at `nkth` selection positions, applied in order to each
    lane. The pivot stack is shared across lanes but reset (count = 0) at
    the start of every lane.
    """
    lane_len = a.shape[axis]
    pivots = __array__[int](MAX_PIVOT_STACK)
    outer_shape = util.tuple_delete(a.shape, axis)

    for outer in util.multirange(outer_shape):
        lane = a._ptr(util.tuple_insert(outer, axis, 0))
        npiv = 0
        for k in range(nkth):
            _partition(lane, lane_len, kth[k], pivots.ptr, __ptr__(npiv))
def _partition_buffered(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    """
    Partition every 1-D lane of `a` along `axis` via a scratch buffer.

    Each lane is gathered (stepping by the axis stride, in bytes) into a
    buffer of `a.shape[axis]` items, partitioned there for each of the
    `nkth` positions in `kth`, and scattered back to the array. The
    scratch buffer is released before returning.
    """
    lane_len = a.shape[axis]
    step = a.strides[axis]
    scratch = Ptr[a.dtype](lane_len)
    pivots = __array__[int](MAX_PIVOT_STACK)
    outer_shape = util.tuple_delete(a.shape, axis)

    for outer in util.multirange(outer_shape):
        npiv = 0
        base = a._ptr(util.tuple_insert(outer, axis, 0)).as_byte()

        # Gather the strided lane into contiguous scratch storage.
        cursor = base
        for i in range(lane_len):
            scratch[i] = (Ptr[a.dtype](cursor))[0]
            cursor += step

        for k in range(nkth):
            _partition(scratch, lane_len, kth[k], pivots.ptr, __ptr__(npiv))

        # Scatter the partitioned values back to the lane.
        cursor = base
        for i in range(lane_len):
            Ptr[a.dtype](cursor)[0] = scratch[i]
            cursor += step

    util.free(scratch)
def _fix_kth(kth, n: int):
    """
    Normalise the `kth` argument of partition/argpartition.

    `kth` may be an integer or a 0-/1-dimensional integer array-like.
    Negative entries are wrapped by adding `n`, every entry is range
    checked against `[0, n)`, and a 1-D `kth` is sorted ascending.

    The work is done on a private copy: `asarray` may return the caller's
    own array when `kth` already is an ndarray, and wrapping/sorting that
    in place would silently modify the caller's data.

    Returns `(pointer to the normalised positions, number of positions)`.
    Raises ValueError if any position is out of bounds; non-integer or
    more-than-1-D `kth` is rejected at compile time.
    """
    kth_arr = asarray(kth, order='C').copy(order='C')

    if kth_arr.dtype is not int:
        compile_error("Partition index must be integer")

    if kth_arr.ndim > 1:
        compile_error("kth array must have dimension <= 1")

    pkth = kth_arr.data
    nkth = kth_arr.size

    for i in range(nkth):
        if pkth[i] < 0:
            pkth[i] += n

        if pkth[i] < 0 or pkth[i] >= n:
            raise ValueError(f"kth(={pkth[i]}) out of bounds ({n})")

    # Ascending order lets successive selections reuse earlier pivots.
    if kth_arr.ndim == 1:
        kth_arr.sort()

    return pkth, nkth
def _partition_helper(a: ndarray, kth, axis, kind: str, force_compact: Static[int] = False):
    """
    Validate arguments and partition `a` in place along `axis`.

    `kind` must be 'introselect' (ValueError otherwise). `kth` is
    normalised by `_fix_kth`. Lanes are processed directly when
    `force_compact` is set or when the axis stride equals the item size;
    otherwise they go through a scratch buffer.

    With `axis=None` the work is performed on `a.flatten()`.
    NOTE(review): if `flatten()` returns a copy (as in NumPy), the result
    of that branch is not visible through `a` -- confirm.
    """
    if axis is None:
        flat = a.flatten()
        _partition_helper(flat, kth=kth, axis=-1, kind=kind, force_compact=force_compact)
        return

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    pkth, nkth = _fix_kth(kth, a.shape[axis])

    if force_compact:
        _partition_compact(a, pkth, nkth, axis)
    else:
        lane_is_contiguous = (a.strides[axis] == a.itemsize)
        if lane_is_contiguous:
            _partition_compact(a, pkth, nkth, axis)
        else:
            _partition_buffered(a, pkth, nkth, axis)
def partition(a, kth, axis = -1, kind: str = 'introselect'):
    """
    Return a partitioned copy of `a`.

    Parameters:
      a    -- ndarray or array-like input; it is not modified.
      kth  -- integer or 1-D integer sequence of positions to partition by;
              negative values count from the end of the axis.
      axis -- axis to partition along (default: last). `None` partitions
              the flattened array and returns a 1-D result.
      kind -- selection algorithm; only 'introselect' is supported.

    Raises ValueError for an unsupported `kind` or an out-of-bounds `kth`.
    """
    # Flatten up front so the partitioned 1-D array is what gets returned;
    # delegating axis=None to the in-place helper would partition a
    # temporary and hand back an unpartitioned copy. Mirrors `sort`.
    if axis is None:
        return partition(asarray(a).flatten(), kth=kth, axis=-1, kind=kind)

    if isinstance(a, ndarray):
        b = a.copy(order='C')
    else:
        b = asarray(a)
    _partition_helper(b, kth=kth, axis=axis, kind=kind, force_compact=(b.ndim == 1))
    return b
def argpartition(a: ndarray, kth, axis = -1, kind: str = 'introselect'):
    """
    Return an integer array of indices that partition `a` along `axis`.

    Parameters:
      a    -- input ndarray.
      kth  -- integer or 1-D integer sequence of positions to partition by;
              negative values count from the end of the axis.
      axis -- axis to partition along (default: last). `None` operates on
              the flattened array.
      kind -- selection algorithm; only 'introselect' is supported.

    Raises ValueError for an unsupported `kind` or an out-of-bounds `kth`.
    """
    if axis is None:
        # Recurse into the public function; the low-level `_argpartition`
        # works on raw pointers and does not accept axis/kind.
        return argpartition(a.flatten(), kth=kth, axis=-1, kind=kind)

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    pkth, nkth = _fix_kth(kth, n)

    b = empty(a.shape, int)
    a_stride = a.strides[axis]
    b_stride = b.strides[axis]
    # A lane can be used in place only if its elements are consecutive.
    a_compact = (a_stride == a.itemsize)
    b_compact = (b_stride == b.itemsize)
    a_buf = Ptr[a.dtype]() if a_compact else Ptr[a.dtype](n)
    b_buf = Ptr[b.dtype]() if b_compact else Ptr[b.dtype](n)
    pivots = __array__[int](MAX_PIVOT_STACK)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        # Pivot stack is reset for every lane.
        npiv = 0
        idx1 = util.tuple_insert(idx, axis, 0)
        pa = a._ptr(idx1)
        pb = b._ptr(idx1)
        qa = pa.as_byte()
        qb = pb.as_byte()

        # Gather strided input values into the scratch buffer.
        if not a_compact:
            for i in range(n):
                a_buf[i] = (Ptr[a.dtype](qa))[0]
                qa += a_stride

        # Start from the identity permutation.
        if not b_compact:
            for i in range(n):
                b_buf[i] = i
        else:
            for i in range(n):
                pb[i] = i

        start = pa if a_compact else a_buf
        tosort = pb if b_compact else b_buf

        for i in range(nkth):
            _argpartition(start, tosort, n, pkth[i], pivots.ptr, __ptr__(npiv))

        # Scatter buffered data back to the strided lanes.
        if not a_compact:
            qa = pa.as_byte()
            for i in range(n):
                Ptr[a.dtype](qa)[0] = a_buf[i]
                qa += a_stride

        if not b_compact:
            qb = pb.as_byte()
            for i in range(n):
                Ptr[b.dtype](qb)[0] = b_buf[i]
                qb += b_stride

    if not a_compact:
        util.free(a_buf)

    if not b_compact:
        util.free(b_buf)

    return b
###########
# Lexsort #
###########
def _lexsort(sort_keys, n: int, axis: int):
    """
    Core of `lexsort`: indirect sort along `axis` using `n` keys.

    `sort_keys` is an indexable collection of equally shaped arrays. For
    every lane along `axis`, an index permutation is initialised to the
    identity and then `astablesort` is applied with each key in turn
    (key 0 first, key n-1 last), so later keys take precedence.

    Returns 0 for 0-dimensional keys, an all-zero index array when the keys
    hold at most one element, and otherwise an int array of the keys' shape.
    """
    sort_keys0 = asarray(sort_keys[0])
    nd: Static[int] = sort_keys0.ndim

    # For 0-d keys, axis 0 and -1 are accepted as-is; everything else is
    # validated/normalised against the key dimensionality.
    if nd == 0 and (axis == 0 or axis == -1):
        pass
    else:
        axis = util.normalize_axis_index(axis, nd)

    if nd == 0:
        return 0

    if sort_keys0.size <= 1:
        return zeros(sort_keys[0].shape, int)

    ret = empty(sort_keys0.shape, int)
    rstride = ret.strides[axis]
    maxelsize = sort_keys0.itemsize
    needcopy = (rstride != ret.itemsize)
    N = sort_keys0.shape[axis]

    # Buffering is needed as soon as the result or any key is not
    # contiguous along the sort axis.
    for j in range(n):
        key = sort_keys[j]
        needcopy |= (key.strides[axis] != key.itemsize)
        maxelsize = max(maxelsize, key.itemsize)

    valbuffer = cobj()
    indbuffer = Ptr[int]()

    if needcopy:
        # Byte buffer large enough for one lane of the widest key.
        valbufsize = N * maxelsize
        if valbufsize == 0:
            valbufsize = 1
        valbuffer = cobj(valbufsize)
        # Ptr[int](count) takes an element count, not a byte count, so N
        # (rather than N * itemsize) is the correct size for N indices.
        indbuffer = Ptr[int](N)

        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1).as_byte()

            for i in range(N):
                indbuffer[i] = i

            for j in range(n):
                key = sort_keys[j]
                buf = Ptr[key.dtype](valbuffer.as_byte())
                p = key._ptr(idx1).as_byte()
                s = key.strides[axis]

                # Gather this key's lane into the contiguous value buffer.
                for i in range(N):
                    buf[i] = (Ptr[key.dtype](p))[0]
                    p += s

                astablesort(buf, indbuffer, N)

            # Scatter the final permutation into the (strided) result lane.
            for i in range(N):
                z = Ptr[int](q)
                z[0] = indbuffer[i]
                q += rstride
    else:
        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1)

            for i in range(N):
                q[i] = i

            for j in range(n):
                p = sort_keys[j]._ptr(idx1)
                astablesort(p, q, N)

    if valbuffer:
        util.free(valbuffer)

    if indbuffer:
        util.free(indbuffer)

    return ret
def lexsort(keys, axis: int = -1):
    """
    Indirect stable sort using a sequence of keys.

    `keys` may be a tuple or list of array-likes, or an ndarray whose first
    axis enumerates the keys. All keys must share one shape; a mismatch
    raises ValueError (a dimensionality mismatch between tuple elements is
    a compile-time error). An empty `keys` raises ValueError. Returns the
    index array computed by `_lexsort` along `axis`.
    """
    nkeys = len(keys)
    if nkeys == 0:
        raise ValueError("need sequence of keys with len > 0 in lexsort")

    if isinstance(keys, Tuple):
        key_arrays = tuple(asarray(key) for key in keys)
        # Tuple elements can differ in type, so they are checked with a
        # compile-time unrolled loop.
        for i in staticrange(1, staticlen(key_arrays)):
            if key_arrays[i].ndim != key_arrays[0].ndim:
                compile_error("all keys need to be the same shape")
            if key_arrays[i].shape != key_arrays[0].shape:
                raise ValueError("all keys need to be the same shape")
        return _lexsort(key_arrays, nkeys, axis)
    elif isinstance(keys, List):
        # A list of ndarrays is used as-is; anything else is converted.
        if isinstance(keys[0], ndarray):
            key_list = keys
        else:
            key_list = [asarray(key) for key in keys]

        for i in range(1, nkeys):
            if key_list[i].shape != key_list[0].shape:
                raise ValueError("all keys need to be the same shape")
        return _lexsort(key_list, nkeys, axis)
    elif isinstance(keys, ndarray):
        return _lexsort(keys, nkeys, axis)
    else:
        compile_error("keys must be an ndarray or a tuple")
#######
# API #
#######
def _sort_compact(a: ndarray, axis: int, sorter):
    """
    Apply `sorter(ptr, count)` in place to every 1-D lane of `a` along
    `axis`, passing a pointer to the lane's first element and the lane
    length. The lane is addressed as consecutive items.
    """
    lane_len = a.shape[axis]
    outer_shape = util.tuple_delete(a.shape, axis)
    for outer in util.multirange(outer_shape):
        first = util.tuple_insert(outer, axis, 0)
        sorter(a._ptr(first), lane_len)
def _sort_buffered(a: ndarray, axis: int, sorter):
    """
    Apply `sorter(ptr, count)` to every 1-D lane of `a` along `axis` via a
    scratch buffer: the lane is gathered (stepping by the axis stride in
    bytes), sorted in the buffer, and scattered back. The buffer is
    released before returning.
    """
    lane_len = a.shape[axis]
    step = a.strides[axis]
    scratch = Ptr[a.dtype](lane_len)
    outer_shape = util.tuple_delete(a.shape, axis)

    for outer in util.multirange(outer_shape):
        base = a._ptr(util.tuple_insert(outer, axis, 0)).as_byte()

        # Gather.
        cursor = base
        for i in range(lane_len):
            scratch[i] = (Ptr[a.dtype](cursor))[0]
            cursor += step

        sorter(scratch, lane_len)

        # Scatter.
        cursor = base
        for i in range(lane_len):
            Ptr[a.dtype](cursor)[0] = scratch[i]
            cursor += step

    util.free(scratch)
def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[int] = False):
    """
    Sort `a` in place along `axis` with `sorter`, choosing the lane access
    strategy: direct when `force_compact` is set or the axis stride equals
    the item size, buffered otherwise. `axis` is normalised against
    `a.ndim` first.
    """
    axis = util.normalize_axis_index(axis, a.ndim)
    if force_compact:
        _sort_compact(a, axis, sorter)
    else:
        lane_is_contiguous = (a.strides[axis] == a.itemsize)
        if lane_is_contiguous:
            _sort_compact(a, axis, sorter)
        else:
            _sort_buffered(a, axis, sorter)
def _sort(a: ndarray, axis: int, kind: Optional[str], force_compact: Static[int] = False):
    """
    Sort `a` in place along `axis` using the algorithm named by `kind`:
    None/'quicksort'/'quick', 'mergesort'/'merge'/'stable', or
    'heapsort'/'heap'. Any other name raises ValueError.
    """
    if kind is None or kind == 'quicksort' or kind == 'quick':
        _sort_dispatch(a, axis, quicksort, force_compact)
        return

    if kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        _sort_dispatch(a, axis, stablesort, force_compact)
        return

    if kind == 'heapsort' or kind == 'heap':
        _sort_dispatch(a, axis, heapsort, force_compact)
        return

    raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")
def sort(a, axis = -1, kind: Optional[str] = None):
    """
    Return a sorted copy of `a`.

    `axis` selects the axis to sort along (default: last); `None` sorts
    the flattened array. `kind` names the algorithm (see `_sort`). An
    ndarray input is copied in C order before sorting; other inputs are
    converted with `asarray`.
    """
    if axis is None:
        flat = asarray(a).flatten()
        return sort(flat, axis=-1, kind=kind)

    if isinstance(a, ndarray):
        out = a.copy(order='C')
    else:
        out = asarray(a)

    _sort(out, axis=axis, kind=kind, force_compact=(out.ndim == 1))
    return out
def _asort(a: ndarray, axis: int, sorter):
    """
    Indirect sort of `a` along `axis`.

    Allocates an int array of `a`'s shape, fills every lane with the
    identity permutation and calls `sorter(values, indices, count)` on it.
    Lanes of the input or the result that are not contiguous along `axis`
    are routed through scratch buffers, which are copied back afterwards
    and released at the end. Returns the index array.
    """
    axis = util.normalize_axis_index(axis, a.ndim)
    lane_len = a.shape[axis]
    b = empty(a.shape, int)

    src_step = a.strides[axis]
    dst_step = b.strides[axis]
    src_direct = (src_step == a.itemsize)
    dst_direct = (dst_step == b.itemsize)
    src_buf = Ptr[a.dtype]() if src_direct else Ptr[a.dtype](lane_len)
    dst_buf = Ptr[b.dtype]() if dst_direct else Ptr[b.dtype](lane_len)

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        first = util.tuple_insert(outer, axis, 0)
        src = a._ptr(first)
        dst = b._ptr(first)

        # Gather strided input values.
        if not src_direct:
            cursor = src.as_byte()
            for i in range(lane_len):
                src_buf[i] = (Ptr[a.dtype](cursor))[0]
                cursor += src_step

        # Identity permutation, written to whichever index storage is used.
        values = src if src_direct else src_buf
        indices = dst if dst_direct else dst_buf
        for i in range(lane_len):
            indices[i] = i

        sorter(values, indices, lane_len)

        # Copy buffered data back to the strided lanes.
        if not src_direct:
            cursor = src.as_byte()
            for i in range(lane_len):
                Ptr[a.dtype](cursor)[0] = src_buf[i]
                cursor += src_step

        if not dst_direct:
            out_cursor = dst.as_byte()
            for i in range(lane_len):
                Ptr[b.dtype](out_cursor)[0] = dst_buf[i]
                out_cursor += dst_step

    if not src_direct:
        util.free(src_buf)

    if not dst_direct:
        util.free(dst_buf)

    return b
def argsort(a, axis = -1, kind: Optional[str] = None):
    """
    Return the indices that would sort `a`.

    `axis` selects the axis (default: last); `None` operates on the
    flattened array. `kind` is None/'quicksort'/'quick',
    'mergesort'/'merge'/'stable', or 'heapsort'/'heap'; any other name
    raises ValueError.
    """
    if axis is None:
        flat = asarray(a).flatten()
        return argsort(flat, axis=-1, kind=kind)

    arr = asarray(a)

    if kind is None or kind == 'quicksort' or kind == 'quick':
        return _asort(arr, axis, aquicksort)

    if kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        return _asort(arr, axis, astablesort)

    if kind == 'heapsort' or kind == 'heap':
        return _asort(arr, axis, aheapsort)

    raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")
def sort_complex(a):
    """
    Sort a copy of `a` and return it as a complex array.

    Input that is already complex/complex64 keeps its dtype. 8- and 16-bit
    integer types (and `byte`) are converted to complex64; every other
    dtype is converted to complex.
    """
    b = array(a, copy=True)
    b.sort()

    if b.dtype is complex or b.dtype is complex64:
        return b
    else:
        if b.dtype is byte or b.dtype is i8 or b.dtype is u8 or b.dtype is i16 or b.dtype is u16:
            return b.astype(complex64)
        else:
            return b.astype(complex)
@extend
class ndarray:
    def sort(self, axis: int = -1, kind: Optional[str] = None):
        """Sort this array in place along `axis` using algorithm `kind`."""
        _sort(self, axis, kind)

    def argsort(self, axis: int = -1, kind: Optional[str] = None):
        """Return the indices that would sort this array along `axis`."""
        return argsort(self, axis, kind)

    def partition(self, kth, axis: int = -1, kind: str = 'introselect'):
        """Partition this array in place along `axis` around position(s) `kth`."""
        _partition_helper(self, kth, axis, kind)

    def argpartition(self, kth, axis: int = -1, kind: str = 'introselect'):
        """Return the indices that would partition this array along `axis`."""
        return argpartition(self, kth, axis, kind)
# TODO: support 'order' argument on sort, argsort, ndarray.sort, partition and argpartition