mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
2184 lines
58 KiB
Python
2184 lines
58 KiB
Python
|
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||
|
|
||
|
from .ndarray import ndarray
|
||
|
from .routines import array, asarray, empty, zeros
|
||
|
import util
|
||
|
|
||
|
##########
|
||
|
# Common #
|
||
|
##########
|
||
|
|
||
|
def _less(a: T, b: T, T: type):
    """NaN-aware strict "less than" used by every sort routine in this file.

    For ``float``/``float32`` a NaN compares greater than every non-NaN
    value, so NaNs collect at the end of an ascending sort. For
    ``complex``/``complex64`` values are compared by real part and then by
    imaginary part, with NaN components treated as larger than non-NaN
    ones (see the branches below). Any other type uses plain ``<``.
    """
    if T is float or T is float32:
        # ``x != x`` holds only for NaN: a < b, or b is NaN while a is not.
        return a < b or (b != b and a == a)
    elif T is complex or T is complex64:
        if a.real < b.real:
            return a.imag == a.imag or b.imag != b.imag
        elif a.real > b.real:
            return b.imag != b.imag and a.imag == a.imag
        elif a.real == b.real or (a.real != a.real and b.real != b.real):
            # Real parts tie (or are both NaN): decide on imaginary parts.
            return a.imag < b.imag or (b.imag != b.imag and a.imag == a.imag)
        else:
            # Exactly one real part is NaN; a is "less" only if b's is.
            return b.real != b.real
    else:
        return a < b
|
||
|
|
||
|
#############
|
||
|
# Insertion #
|
||
|
#############
|
||
|
|
||
|
def insertionsort(start: Ptr[T], n: int, T: type):
    """Sort ``start[0:n]`` in place by straight insertion (stable)."""
    for i in range(1, n):
        item = start[i]
        pos = i
        # Shift strictly-greater elements one slot right to open a gap.
        while pos > 0 and _less(item, start[pos - 1]):
            start[pos] = start[pos - 1]
            pos -= 1
        start[pos] = item
|
||
|
|
||
|
def ainsertionsort(v: Ptr[T], tosort: Ptr[int], n: int, T: type):
    """Insertion argsort: reorder the indices ``tosort[0:n]`` so that
    ``v[tosort[0]], v[tosort[1]], ...`` is non-decreasing (stable)."""
    for i in range(1, n):
        idx = tosort[i]
        key = v[idx]
        pos = i
        # Shift indices whose values are strictly greater than key.
        while pos > 0 and _less(key, v[tosort[pos - 1]]):
            tosort[pos] = tosort[pos - 1]
            pos -= 1
        tosort[pos] = idx
|
||
|
|
||
|
########
|
||
|
# Heap #
|
||
|
########
|
||
|
|
||
|
def heapsort(start: Ptr[T], n: int, T: type):
    """Sort ``start[0:n]`` in place with heapsort, ordered by ``_less``."""
    # 1-based view: a[1] .. a[n] alias start[0] .. start[n - 1].
    a = start - 1

    # Build a max-heap by sifting down every internal node, last to first.
    for l in range(n >> 1, 0, -1):
        tmp = a[l]
        i = l
        j = l << 1
        while j <= n:
            # Pick the larger of the two children.
            if j < n and _less(a[j], a[j + 1]):
                j += 1
            if _less(tmp, a[j]):
                a[i] = a[j]
                i = j
                j += j  # left child of the new position (2 * i)
            else:
                break
        a[i] = tmp

    # Repeatedly move the maximum a[1] to the end, then sift the displaced
    # last element down from the root of the shrunken heap.
    while n > 1:
        tmp = a[n]
        a[n] = a[1]
        n -= 1
        i = 1
        j = 2
        while j <= n:
            if j < n and _less(a[j], a[j + 1]):
                j += 1
            if _less(tmp, a[j]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
|
||
|
|
||
|
def aheapsort(vv: Ptr[T], tosort: Ptr[int], n: int, T: type):
    """Heapsort argsort: reorder the indices ``tosort[0:n]`` so they index
    ``vv`` in non-decreasing order. ``vv`` itself is not modified."""
    v = vv
    # 1-based view over the index array: a[1] .. a[n].
    a = tosort - 1

    # Build a max-heap (keyed by v[index]) over the indices.
    for l in range(n >> 1, 0, -1):
        tmp = a[l]
        i = l
        j = l << 1
        while j <= n:
            if j < n and _less(v[a[j]], v[a[j + 1]]):
                j += 1
            if _less(v[tmp], v[a[j]]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp

    # Pop the index of the maximum to the end and re-heapify the prefix.
    while n > 1:
        tmp = a[n]
        a[n] = a[1]
        n -= 1
        i = 1
        j = 2
        while j <= n:
            if j < n and _less(v[a[j]], v[a[j + 1]]):
                j += 1
            if _less(v[tmp], v[a[j]]):
                a[i] = a[j]
                i = j
                j += j
            else:
                break
        a[i] = tmp
|
||
|
|
||
|
#########
|
||
|
# Merge #
|
||
|
#########
|
||
|
|
||
|
# NOTE(review): PYA_QS_STACK, SMALL_QUICKSORT and SMALL_STRING are not
# referenced in this part of the file.
PYA_QS_STACK : Static[int] = 100
SMALL_QUICKSORT: Static[int] = 15
# _mergesort0/_amergesort0 insertion-sort ranges of at most this many items.
SMALL_MERGESORT: Static[int] = 20
SMALL_STRING : Static[int] = 16
|
||
|
|
||
|
def _mergesort0(pl: Ptr[T], pr: Ptr[T], pw: Ptr[T], T: type):
    """Recursive top-down merge sort of the half-open range [pl, pr).

    ``pw`` is scratch space; only the left half of a range,
    ``(pr - pl) >> 1`` elements, is ever copied into it. Ranges of at most
    SMALL_MERGESORT elements are insertion-sorted instead.
    """
    if pr - pl > SMALL_MERGESORT:
        pm = pl + ((pr - pl) >> 1)
        _mergesort0(pl, pm, pw)
        _mergesort0(pm, pr, pw)

        # Copy the sorted left half [pl, pm) into scratch.
        pi = pw
        pj = pl
        while pj < pm:
            pi[0] = pj[0]
            pi += 1
            pj += 1

        # Merge scratch [pw, pi) with the right half [pm, pr) into [pl, ...).
        pi = pw + (pm - pl)
        pj = pw
        pk = pl

        while pj < pi and pm < pr:
            # Take from the right only when strictly smaller (keeps stability).
            if _less(pm[0], pj[0]):
                pk[0] = pm[0]
                pk += 1
                pm += 1
            else:
                pk[0] = pj[0]
                pk += 1
                pj += 1

        # Leftover right-half elements are already in place; copy back the
        # remaining left-half elements from scratch.
        while pj < pi:
            pk[0] = pj[0]
            pk += 1
            pj += 1
    else:
        # Small range: straight insertion sort.
        pi = pl + 1
        while pi < pr:
            vp = pi[0]
            pj = pi
            pk = pi - 1
            while pj > pl and _less(vp, pk[0]):
                pj[0] = pk[0]
                pj -= 1
                pk -= 1
            pj[0] = vp
            pi += 1
|
||
|
|
||
|
def mergesort(start: Ptr[T], num: int, T: type):
    """Stable merge sort of ``start[0:num]`` in place."""
    # _mergesort0 only ever buffers the left half of a range.
    scratch = Ptr[T](num // 2)
    _mergesort0(start, start + num, scratch)
    util.free(scratch)
|
||
|
|
||
|
def _amergesort0(pl: Ptr[int], pr: Ptr[int], v: Ptr[T], pw: Ptr[int], T: type):
    """Recursive merge argsort of the index range [pl, pr), keyed by ``v``.

    ``pw`` is index scratch space holding at least ``(pr - pl) >> 1``
    entries. Ranges of at most SMALL_MERGESORT indices are insertion-sorted.
    """
    if pr - pl > SMALL_MERGESORT:
        pm = pl + ((pr - pl) >> 1)
        _amergesort0(pl, pm, v, pw)
        _amergesort0(pm, pr, v, pw)

        # Copy the sorted left half of the indices into scratch.
        pi = pw
        pj = pl
        while pj < pm:
            pi[0] = pj[0]
            pi += 1
            pj += 1

        # Merge scratch indices with the right half back into [pl, ...).
        pi = pw + (pm - pl)
        pj = pw
        pk = pl

        while pj < pi and pm < pr:
            # Strict comparison keeps equal keys in their original order.
            if _less(v[pm[0]], v[pj[0]]):
                pk[0] = pm[0]
                pk += 1
                pm += 1
            else:
                pk[0] = pj[0]
                pk += 1
                pj += 1

        while pj < pi:
            pk[0] = pj[0]
            pk += 1
            pj += 1
    else:
        # Small range: insertion sort of the indices.
        pi = pl + 1
        while pi < pr:
            vi = pi[0]
            vp = v[vi]
            pj = pi
            pk = pi - 1
            while pj > pl and _less(vp, v[pk[0]]):
                pj[0] = pk[0]
                pj -= 1
                pk -= 1
            pj[0] = vi
            pi += 1
|
||
|
|
||
|
def amergesort(v: Ptr[T], tosort: Ptr[int], num: int, T: type):
    """Stable merge argsort: reorder ``tosort[0:num]`` so it indexes ``v``
    in non-decreasing order."""
    # Index scratch for the left half of each merge.
    scratch = Ptr[int](num // 2)
    _amergesort0(tosort, tosort + num, v, scratch)
    util.free(scratch)
|
||
|
|
||
|
###############
|
||
|
# Quick (PDQ) #
|
||
|
###############
|
||
|
|
||
|
# Ranges smaller than this are insertion-sorted by _pdq_sort/_apdq_sort.
INSERTION_SORT_THRESHOLD : Static[int] = 24
# Above this size the pivot is a pseudo-median of nine (four _sort3 calls)
# and more elements are shuffled after an unbalanced partition.
NINTHER_THRESHOLD : Static[int] = 128
# _partial_insertion_sort gives up once it has shifted more than this many
# elements in total.
PARTIAL_INSERTION_SORT_LIMIT: Static[int] = 8
|
||
|
|
||
|
def _floor_log2(n: int):
|
||
|
log = 0
|
||
|
while True:
|
||
|
n >>= 1
|
||
|
if n == 0:
|
||
|
break
|
||
|
log += 1
|
||
|
return log
|
||
|
|
||
|
def _partial_insertion_sort(arr: Ptr[T], begin: int, end: int, T: type):
    """Try to insertion-sort ``arr[begin:end]``, bailing out early.

    Returns True if the range ended up fully sorted. Returns False as soon
    as more than PARTIAL_INSERTION_SORT_LIMIT element shifts have been made.
    In that case the range is left partially sorted; it is still a
    permutation of the input.
    """
    if begin == end:
        return True

    limit = 0
    cur = begin + 1
    while cur != end:
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False

        sift = cur
        sift_1 = cur - 1

        # Only do work when arr[cur] is out of order with its predecessor.
        if _less(arr[sift], arr[sift_1]):
            tmp = arr[sift]

            while True:
                arr[sift] = arr[sift_1]
                sift -= 1
                sift_1 -= 1
                if sift == begin or not _less(tmp, arr[sift_1]):
                    break

            arr[sift] = tmp
            # Count how many positions the element moved.
            limit += cur - sift

        cur += 1

    return True
|
||
|
|
||
|
def _partition_left(arr: Ptr[T], begin: int, end: int, T: type):
    """Partition ``arr[begin:end]`` around the pivot ``arr[begin]``.

    Elements not greater than the pivot end up on its left and strictly
    greater elements on its right. Returns the pivot's final index.
    _pdq_sort calls this when the element just before the range is not less
    than the pivot.
    """
    pivot = arr[begin]
    first = begin
    last = end

    # Scan right-to-left for an element not greater than the pivot; the
    # pivot itself at arr[begin] stops this scan.
    while True:
        last -= 1
        if not _less(pivot, arr[last]):
            break

    # If the right scan stopped immediately, bound the left scan by `last`;
    # otherwise the element at `last` acts as a sentinel.
    if last + 1 == end:
        while first < last:
            first += 1
            if _less(pivot, arr[first]):
                break
    else:
        while True:
            first += 1
            if _less(pivot, arr[first]):
                break

    # Swap misplaced pairs until the two scans cross.
    while first < last:
        arr[first], arr[last] = arr[last], arr[first]
        while True:
            last -= 1
            if not _less(pivot, arr[last]):
                break
        while True:
            first += 1
            if _less(pivot, arr[first]):
                break

    # Move the pivot into its final slot.
    pivot_pos = last
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return pivot_pos
|
||
|
|
||
|
def _partition_right(arr: Ptr[T], begin: int, end: int, T: type):
    """Partition ``arr[begin:end]`` around the pivot ``arr[begin]``.

    Elements strictly less than the pivot end up on its left and elements
    not less than it on its right. Returns ``(pivot_pos, already_partitioned)``.
    ``already_partitioned`` is 1 when the initial scans had already crossed
    (no swaps were needed), else 0.
    """
    pivot = arr[begin]
    first = begin
    last = end

    # Scan left-to-right for the first element not less than the pivot.
    # NOTE(review): unbounded; relies on the caller's pivot selection having
    # placed a not-less element later in the range.
    while True:
        first += 1
        if not _less(arr[first], pivot):
            break

    # If the left scan stopped immediately, bound the right scan by `first`;
    # otherwise an element less than the pivot was already passed on the
    # left, which guarantees the right scan stops.
    if first - 1 == begin:
        while first < last:
            last -= 1
            if _less(arr[last], pivot):
                break
    else:
        while True:
            last -= 1
            if _less(arr[last], pivot):
                break

    already_partitioned = 0
    if first >= last:
        already_partitioned = 1

    # Swap misplaced pairs until the two scans cross.
    while first < last:
        arr[first], arr[last] = arr[last], arr[first]

        while True:
            first += 1
            if not _less(arr[first], pivot):
                break

        while True:
            last -= 1
            if _less(arr[last], pivot):
                break

    # Move the pivot into its final slot.
    pivot_pos = first - 1
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return (pivot_pos, already_partitioned)
|
||
|
|
||
|
def _sort2(arr: Ptr[T], i: int, j: int, T: type):
    """Compare-exchange: afterwards ``arr[j]`` is not less than ``arr[i]``."""
    lo = arr[i]
    hi = arr[j]
    if _less(hi, lo):
        arr[i] = hi
        arr[j] = lo
|
||
|
|
||
|
def _sort3(arr: Ptr[T], i: int, j: int, k: int, T: type):
    """Order ``arr[i], arr[j], arr[k]`` with three compare-exchanges."""
    # Order the first pair.
    if _less(arr[j], arr[i]):
        arr[i], arr[j] = arr[j], arr[i]
    # Push the maximum of the three into slot k.
    if _less(arr[k], arr[j]):
        arr[j], arr[k] = arr[k], arr[j]
    # Re-order the first pair after the possible swap into j.
    if _less(arr[j], arr[i]):
        arr[i], arr[j] = arr[j], arr[i]
|
||
|
|
||
|
def _pdq_sort(
    arr: Ptr[T],
    begin: int,
    end: int,
    bad_allowed: int,
    leftmost: bool,
    T: type
):
    """Pattern-defeating quicksort of ``arr[begin:end]`` in place.

    ``bad_allowed`` is how many highly unbalanced partitions are tolerated
    before the range is handed to heapsort. ``leftmost`` is True when
    ``arr[begin]`` starts the whole region being sorted; in that case
    ``arr[begin - 1]`` is never read.
    """
    # Recurse on the left partition and loop on the right one.
    while True:
        size = end - begin
        if size < INSERTION_SORT_THRESHOLD:
            insertionsort(arr + begin, size)
            return

        # Choose a pivot and leave it at arr[begin]: pseudo-median of nine
        # for large ranges, median of three otherwise.
        size_2 = size // 2
        if size > NINTHER_THRESHOLD:
            _sort3(arr, begin, begin + size_2, end - 1)
            _sort3(arr, begin + 1, begin + (size_2 - 1), end - 2)
            _sort3(arr, begin + 2, begin + (size_2 + 1), end - 3)
            _sort3(
                arr, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1)
            )
            arr[begin], arr[begin + size_2] = arr[begin + size_2], arr[begin]
        else:
            _sort3(arr, begin + size_2, begin, end - 1)

        # If the preceding element is not less than the pivot, group all
        # elements equal to the pivot on the left and continue past them.
        if not leftmost and not _less(arr[begin - 1], arr[begin]):
            begin = _partition_left(arr, begin, end) + 1
            continue

        part_result = _partition_right(arr, begin, end)
        pivot_pos = part_result[0]
        already_partitioned = part_result[1] == 1

        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8))

        if highly_unbalanced:
            # Out of budget for bad partitions: fall back to heapsort.
            bad_allowed -= 1
            if bad_allowed == 0:
                heapsort(arr + begin, end - begin)
                return

            # Swap a few elements on each side to break up the pattern that
            # produced the imbalance.
            if l_size >= INSERTION_SORT_THRESHOLD:
                arr[begin], arr[begin + l_size // 4] = (
                    arr[begin + l_size // 4],
                    arr[begin],
                )
                arr[pivot_pos - 1], arr[pivot_pos - l_size // 4] = (
                    arr[pivot_pos - l_size // 4],
                    arr[pivot_pos - 1],
                )

                if l_size > NINTHER_THRESHOLD:
                    arr[begin + 1], arr[begin + (l_size // 4 + 1)] = (
                        arr[begin + (l_size // 4 + 1)],
                        arr[begin + 1],
                    )
                    arr[begin + 2], arr[begin + (l_size // 4 + 2)] = (
                        arr[begin + (l_size // 4 + 2)],
                        arr[begin + 2],
                    )
                    arr[pivot_pos - 2], arr[pivot_pos - (l_size // 4 + 1)] = (
                        arr[pivot_pos - (l_size // 4 + 1)],
                        arr[pivot_pos - 2],
                    )
                    arr[pivot_pos - 3], arr[pivot_pos - (l_size // 4 + 2)] = (
                        arr[pivot_pos - (l_size // 4 + 2)],
                        arr[pivot_pos - 3],
                    )

            if r_size >= INSERTION_SORT_THRESHOLD:
                arr[pivot_pos + 1], arr[pivot_pos + (1 + r_size // 4)] = (
                    arr[pivot_pos + (1 + r_size // 4)],
                    arr[pivot_pos + 1],
                )
                arr[end - 1], arr[end - r_size // 4] = (
                    arr[end - r_size // 4],
                    arr[end - 1],
                )

                if r_size > NINTHER_THRESHOLD:
                    arr[pivot_pos + 2], arr[pivot_pos + (2 + r_size // 4)] = (
                        arr[pivot_pos + (2 + r_size // 4)],
                        arr[pivot_pos + 2],
                    )
                    arr[pivot_pos + 3], arr[pivot_pos + (3 + r_size // 4)] = (
                        arr[pivot_pos + (3 + r_size // 4)],
                        arr[pivot_pos + 3],
                    )
                    arr[end - 2], arr[end - (1 + r_size // 4)] = (
                        arr[end - (1 + r_size // 4)],
                        arr[end - 2],
                    )
                    arr[end - 3], arr[end - (2 + r_size // 4)] = (
                        arr[end - (2 + r_size // 4)],
                        arr[end - 3],
                    )
        else:
            # A balanced partition that needed no swaps hints that the input
            # is nearly sorted: try cheap bounded insertion sorts first.
            if (
                already_partitioned
                and _partial_insertion_sort(arr, begin, pivot_pos)
                and _partial_insertion_sort(arr, pivot_pos + 1, end)
            ):
                return

        _pdq_sort(arr, begin, pivot_pos, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False
|
||
|
|
||
|
# C stubs for vectorized quicksort from Highway
|
||
|
from C import cnp_sort_int16(cobj, int)
|
||
|
from C import cnp_sort_uint16(cobj, int)
|
||
|
from C import cnp_sort_int32(cobj, int)
|
||
|
from C import cnp_sort_uint32(cobj, int)
|
||
|
from C import cnp_sort_int64(cobj, int)
|
||
|
from C import cnp_sort_uint64(cobj, int)
|
||
|
from C import cnp_sort_int16(cobj, int)
|
||
|
from C import cnp_sort_uint128(cobj, int)
|
||
|
from C import cnp_sort_float32(cobj, int)
|
||
|
from C import cnp_sort_float64(cobj, int)
|
||
|
|
||
|
def quicksort(start: Ptr[T], n: int, T: type):
    """Sort ``start[0:n]`` in place.

    Fixed-width integer and float element types go to the C sort routines
    declared above. Every other type uses the pattern-defeating quicksort
    in ``_pdq_sort``.
    """
    raw = start.as_byte()
    if T is int or T is i64:
        cnp_sort_int64(raw, n)
    elif T is i16:
        cnp_sort_int16(raw, n)
    elif T is u16:
        cnp_sort_uint16(raw, n)
    elif T is i32:
        cnp_sort_int32(raw, n)
    elif T is u32:
        cnp_sort_uint32(raw, n)
    elif T is u64:
        cnp_sort_uint64(raw, n)
    elif T is float32:
        cnp_sort_float32(raw, n)
    elif T is float:
        cnp_sort_float64(raw, n)
    else:
        depth_budget = _floor_log2(n)
        _pdq_sort(start, 0, n, depth_budget, True)
|
||
|
|
||
|
def _apartial_insertion_sort(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    """Argsort counterpart of _partial_insertion_sort.

    Try to insertion-sort the indices ``tosort[begin:end]`` by ``arr``
    value. Returns True if fully sorted, or False once more than
    PARTIAL_INSERTION_SORT_LIMIT index shifts have been made.
    """
    if begin == end:
        return True

    limit = 0
    cur = begin + 1
    while cur != end:
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False

        sift = cur
        sift_1 = cur - 1

        if _less(arr[tosort[sift]], arr[tosort[sift_1]]):
            itmp = tosort[sift]  # index being inserted
            tmp = arr[itmp]      # its key, cached for comparisons

            while True:
                tosort[sift] = tosort[sift_1]
                sift -= 1
                sift_1 -= 1
                if sift == begin or not _less(tmp, arr[tosort[sift_1]]):
                    break

            tosort[sift] = itmp
            limit += cur - sift

        cur += 1

    return True
|
||
|
|
||
|
def _apartition_left(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    """Argsort counterpart of _partition_left.

    Partition the indices ``tosort[begin:end]`` around the pivot
    ``arr[tosort[begin]]``. Indices whose value is not greater than the pivot
    go left; the rest go right. Returns the pivot's final position in
    ``tosort``.
    """
    ipivot = tosort[begin]
    pivot = arr[ipivot]
    first = begin
    last = end

    # Right-to-left scan; the pivot at tosort[begin] stops it.
    while True:
        last -= 1
        if not _less(pivot, arr[tosort[last]]):
            break

    # Bounded left scan only if the right scan stopped immediately.
    if last + 1 == end:
        while first < last:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break
    else:
        while True:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break

    # Swap misplaced index pairs until the scans cross.
    while first < last:
        tosort[first], tosort[last] = tosort[last], tosort[first]
        while True:
            last -= 1
            if not _less(pivot, arr[tosort[last]]):
                break
        while True:
            first += 1
            if _less(pivot, arr[tosort[first]]):
                break

    pivot_pos = last
    tosort[begin] = tosort[pivot_pos]
    tosort[pivot_pos] = ipivot

    return pivot_pos
|
||
|
|
||
|
def _apartition_right(arr: Ptr[T], tosort: Ptr[int], begin: int, end: int, T: type):
    """Argsort counterpart of _partition_right.

    Partition the indices ``tosort[begin:end]`` around the pivot
    ``arr[tosort[begin]]``. Indices whose value is strictly less go left;
    the rest go right. Returns ``(pivot_pos, already_partitioned)``, where
    ``already_partitioned`` is 1 if no swaps were needed, else 0.
    """
    ipivot = tosort[begin]
    pivot = arr[ipivot]
    first = begin
    last = end

    # Left-to-right scan for the first value not less than the pivot.
    while True:
        first += 1
        if not _less(arr[tosort[first]], pivot):
            break

    # Bound the right scan only if the left scan stopped immediately;
    # otherwise a smaller value was already passed and acts as a sentinel.
    if first - 1 == begin:
        while first < last:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break
    else:
        while True:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break

    already_partitioned = 0
    if first >= last:
        already_partitioned = 1

    while first < last:
        tosort[first], tosort[last] = tosort[last], tosort[first]

        while True:
            first += 1
            if not _less(arr[tosort[first]], pivot):
                break

        while True:
            last -= 1
            if _less(arr[tosort[last]], pivot):
                break

    pivot_pos = first - 1
    tosort[begin] = tosort[pivot_pos]
    tosort[pivot_pos] = ipivot

    return (pivot_pos, already_partitioned)
|
||
|
|
||
|
def _asort2(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, T: type):
    """Index compare-exchange: afterwards ``arr[tosort[j]]`` is not less
    than ``arr[tosort[i]]``."""
    a = tosort[i]
    b = tosort[j]
    if _less(arr[b], arr[a]):
        tosort[i] = b
        tosort[j] = a
|
||
|
|
||
|
def _asort3(arr: Ptr[T], tosort: Ptr[int], i: int, j: int, k: int, T: type):
    """Order the indices at positions ``i``, ``j``, ``k`` by their values."""
    # Order the first pair.
    if _less(arr[tosort[j]], arr[tosort[i]]):
        tosort[i], tosort[j] = tosort[j], tosort[i]
    # Push the index of the maximum into slot k.
    if _less(arr[tosort[k]], arr[tosort[j]]):
        tosort[j], tosort[k] = tosort[k], tosort[j]
    # Re-order the first pair.
    if _less(arr[tosort[j]], arr[tosort[i]]):
        tosort[i], tosort[j] = tosort[j], tosort[i]
|
||
|
|
||
|
def _apdq_sort(
    arr: Ptr[T],
    tosort: Ptr[int],
    begin: int,
    end: int,
    bad_allowed: int,
    leftmost: bool,
    T: type
):
    """Argsort counterpart of _pdq_sort.

    Reorder the indices ``tosort[begin:end]`` so they index ``arr`` in
    non-decreasing order. ``arr`` itself is never written. ``bad_allowed``
    and ``leftmost`` have the same meaning as in _pdq_sort.
    """
    while True:
        size = end - begin
        if size < INSERTION_SORT_THRESHOLD:
            ainsertionsort(arr, tosort + begin, size)
            return

        # Pivot selection (pseudo-median of nine or median of three); the
        # pivot index ends up at tosort[begin].
        size_2 = size // 2
        if size > NINTHER_THRESHOLD:
            _asort3(arr, tosort, begin, begin + size_2, end - 1)
            _asort3(arr, tosort, begin + 1, begin + (size_2 - 1), end - 2)
            _asort3(arr, tosort, begin + 2, begin + (size_2 + 1), end - 3)
            _asort3(
                arr, tosort, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1)
            )
            tosort[begin], tosort[begin + size_2] = tosort[begin + size_2], tosort[begin]
        else:
            _asort3(arr, tosort, begin + size_2, begin, end - 1)

        # Predecessor not less than the pivot: group pivot-equal values left.
        if not leftmost and not _less(arr[tosort[begin - 1]], arr[tosort[begin]]):
            begin = _apartition_left(arr, tosort, begin, end) + 1
            continue

        part_result = _apartition_right(arr, tosort, begin, end)
        pivot_pos = part_result[0]
        already_partitioned = part_result[1] == 1

        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8))

        if highly_unbalanced:
            # Out of budget for bad partitions: fall back to heapsort.
            bad_allowed -= 1
            if bad_allowed == 0:
                aheapsort(arr, tosort + begin, end - begin)
                return

            # Shuffle a few indices on each side to break up the pattern.
            if l_size >= INSERTION_SORT_THRESHOLD:
                tosort[begin], tosort[begin + l_size // 4] = (
                    tosort[begin + l_size // 4],
                    tosort[begin],
                )
                tosort[pivot_pos - 1], tosort[pivot_pos - l_size // 4] = (
                    tosort[pivot_pos - l_size // 4],
                    tosort[pivot_pos - 1],
                )

                if l_size > NINTHER_THRESHOLD:
                    tosort[begin + 1], tosort[begin + (l_size // 4 + 1)] = (
                        tosort[begin + (l_size // 4 + 1)],
                        tosort[begin + 1],
                    )
                    tosort[begin + 2], tosort[begin + (l_size // 4 + 2)] = (
                        tosort[begin + (l_size // 4 + 2)],
                        tosort[begin + 2],
                    )
                    tosort[pivot_pos - 2], tosort[pivot_pos - (l_size // 4 + 1)] = (
                        tosort[pivot_pos - (l_size // 4 + 1)],
                        tosort[pivot_pos - 2],
                    )
                    tosort[pivot_pos - 3], tosort[pivot_pos - (l_size // 4 + 2)] = (
                        tosort[pivot_pos - (l_size // 4 + 2)],
                        tosort[pivot_pos - 3],
                    )

            if r_size >= INSERTION_SORT_THRESHOLD:
                tosort[pivot_pos + 1], tosort[pivot_pos + (1 + r_size // 4)] = (
                    tosort[pivot_pos + (1 + r_size // 4)],
                    tosort[pivot_pos + 1],
                )
                tosort[end - 1], tosort[end - r_size // 4] = (
                    tosort[end - r_size // 4],
                    tosort[end - 1],
                )

                if r_size > NINTHER_THRESHOLD:
                    tosort[pivot_pos + 2], tosort[pivot_pos + (2 + r_size // 4)] = (
                        tosort[pivot_pos + (2 + r_size // 4)],
                        tosort[pivot_pos + 2],
                    )
                    tosort[pivot_pos + 3], tosort[pivot_pos + (3 + r_size // 4)] = (
                        tosort[pivot_pos + (3 + r_size // 4)],
                        tosort[pivot_pos + 3],
                    )
                    tosort[end - 2], tosort[end - (1 + r_size // 4)] = (
                        tosort[end - (1 + r_size // 4)],
                        tosort[end - 2],
                    )
                    tosort[end - 3], tosort[end - (2 + r_size // 4)] = (
                        tosort[end - (2 + r_size // 4)],
                        tosort[end - 3],
                    )
        else:
            # Balanced and swap-free: try bounded insertion sorts first.
            if (
                already_partitioned
                and _apartial_insertion_sort(arr, tosort, begin, pivot_pos)
                and _apartial_insertion_sort(arr, tosort, pivot_pos + 1, end)
            ):
                return

        _apdq_sort(arr, tosort, begin, pivot_pos, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False
|
||
|
|
||
|
def aquicksort(start: Ptr[T], tosort: Ptr[int], n: int, T: type):
    """Quicksort argsort: reorder ``tosort[0:n]`` so it indexes ``start`` in
    non-decreasing order. Always uses the pattern-defeating quicksort; there
    is no C fast path as in ``quicksort``."""
    depth_budget = _floor_log2(n)
    _apdq_sort(start, tosort, 0, n, depth_budget, True)
|
||
|
|
||
|
#########
|
||
|
# Radix #
|
||
|
#########
|
||
|
|
||
|
def key_of(x: UT, T: type, UT: type):
    """Turn a value's bits (viewed as unsigned ``UT``) into a radix key.

    Unsigned element types are used as is. For a signed ``T`` the top bit
    is flipped, so negative values order before non-negative ones when the
    keys are compared as unsigned integers.
    """
    if T is UT:
        return x
    else:
        top_bit = util.cast(util.sizeof(UT) * 8 - 1, UT)
        sign_mask = util.cast(1, UT) << top_bit
        return x ^ sign_mask
|
||
|
|
||
|
def nth_byte(key: T, l: int, T: type):
    """Return byte number ``l`` of ``key`` (``l == 0`` is the least
    significant byte) as an int in [0, 255]."""
    shift = util.cast(l << 3, T)
    return int(key >> shift) & 0xFF
|
||
|
|
||
|
def _radixsort0(start: Ptr[UT], aux: Ptr[UT], num: int, T: type, UT: type):
    """LSD radix sort of ``start[0:num]`` one byte-wide digit at a time.

    ``aux`` must hold ``num`` elements. Data ping-pongs between ``start``
    and ``aux`` on each pass. The returned pointer (one of the two) is the
    buffer that holds the sorted result. Reads ``start[0]``, so ``num``
    must be at least 1.
    """
    m = util.sizeof(UT)  # digits (bytes) per key
    n = (1 << 8)         # buckets per digit
    ncnt = m * n
    cnt = Ptr[int](ncnt) # note: compiler should put this on the stack
    str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int))
    key0 = key_of(start[0], T=T, UT=UT)

    # Histogram every digit position in a single pass over the data.
    for i in range(num):
        k = key_of(start[i], T=T, UT=UT)
        for l in range(m):
            cnt[l*n + nth_byte(k, l)] += 1

    # A digit position where all keys share key0's byte has one bucket
    # holding all num elements; skip it. Only the other "columns" need a pass.
    ncols = 0
    cols = Ptr[u8](m) # again, compiler should put on stack
    for l in range(m):
        if cnt[l*n + nth_byte(key0, l)] != num:
            cols[ncols] = u8(l)
            ncols += 1

    # Turn each needed histogram into exclusive prefix sums (bucket offsets).
    for l in range(ncols):
        a = 0
        for i in range(256):
            b = cnt[int(cols[l])*n + i]
            cnt[int(cols[l])*n + i] = a
            a += b

    # One stable scatter pass per needed digit, least significant first.
    for l in range(ncols):
        for i in range(num):
            k = key_of(start[i], T=T, UT=UT)
            q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l])))
            dst = q[0]
            q[0] += 1
            aux[dst] = start[i]

        # The freshly written buffer becomes the input of the next pass.
        temp = aux
        aux = start
        start = temp

    return start
|
||
|
|
||
|
def _radixsort(start: Ptr[UT], num: int, T: type, UT: type):
    """Radix-sort ``start[0:num]`` in place by ``key_of`` keys."""
    if num < 2:
        return

    # Early exit when the keys are already non-decreasing.
    prev = key_of(start[0], T=T, UT=UT)
    i = 1
    while i < num:
        cur = key_of(start[i], T=T, UT=UT)
        if prev > cur:
            break
        prev = cur
        i += 1
    if i == num:
        return

    aux = Ptr[UT](num)
    result = _radixsort0(start, aux, num, T=T)
    # An odd number of passes leaves the sorted data in aux; copy it back.
    if result != start:
        str.memcpy(start.as_byte(), result.as_byte(), num * util.sizeof(UT))

    util.free(aux)
|
||
|
|
||
|
def radixsort(start: Ptr[T], num: int, T: type):
    """Radix sort ``start[0:num]`` in place for integer-like ``T``.

    Raises ValueError when ``T`` has no radix key mapping.
    """
    def dispatch(data: Ptr, count: int, T: type, UT: type):
        # Reinterpret the buffer as unsigned keys of the same width.
        _radixsort(Ptr[UT](data.as_byte()), count, T=T, UT=UT)

    if T is int:
        dispatch(start, num, int, u64)
    elif isinstance(T, Int):
        dispatch(start, num, T, UInt[T.N])
    elif isinstance(T, UInt):
        dispatch(start, num, T, T)
    elif isinstance(T, byte):
        dispatch(start, num, i8, u8)
    else:
        raise ValueError("cannot radixsort type '" + T.__name__ + "'")
|
||
|
|
||
|
def _aradixsort0(start: Ptr[UT], aux: Ptr[int], tosort: Ptr[int], num: int, T: type, UT: type):
    """LSD radix argsort: reorder indices by ``key_of(start[index])``.

    ``aux`` must hold ``num`` indices. Indices ping-pong between ``tosort``
    and ``aux`` on each pass. The returned pointer (one of the two) holds
    the sorted index order.
    """
    m = util.sizeof(UT)  # digits (bytes) per key
    n = (1 << 8)         # buckets per digit
    ncnt = m * n
    cnt = Ptr[int](ncnt) # note: compiler should put this on the stack
    str.memset(cnt.as_byte(), byte(0), ncnt * util.sizeof(int))
    key0 = key_of(start[0], T=T, UT=UT)

    # NOTE(review): the histogram is taken over start[0:num] directly, which
    # assumes tosort is a permutation of range(num).
    for i in range(num):
        k = key_of(start[i], T=T, UT=UT)
        for l in range(m):
            cnt[l*n + nth_byte(k, l)] += 1

    # Skip digit positions on which every key agrees.
    ncols = 0
    cols = Ptr[u8](m) # again, compiler should put on stack
    for l in range(m):
        if cnt[l*n + nth_byte(key0, l)] != num:
            cols[ncols] = u8(l)
            ncols += 1

    # Exclusive prefix sums per needed digit.
    for l in range(ncols):
        a = 0
        for i in range(256):
            b = cnt[int(cols[l])*n + i]
            cnt[int(cols[l])*n + i] = a
            a += b

    # Stable scatter of indices, least significant needed digit first.
    for l in range(ncols):
        for i in range(num):
            k = key_of(start[tosort[i]], T=T, UT=UT)
            q = cnt + (int(cols[l])*n + nth_byte(k, int(cols[l])))
            dst = q[0]
            q[0] += 1
            aux[dst] = tosort[i]

        # Swap roles of the two index buffers for the next pass.
        temp = aux
        aux = tosort
        tosort = temp

    return tosort
|
||
|
|
||
|
def _aradixsort(start: Ptr[UT], tosort: Ptr[int], num: int, T: type, UT: type):
    """Radix argsort: reorder ``tosort[0:num]`` so it indexes ``start`` by
    non-decreasing ``key_of`` key."""
    if num < 2:
        return

    # Early exit when the indexed keys are already non-decreasing.
    prev = key_of(start[tosort[0]], T=T, UT=UT)
    i = 1
    while i < num:
        cur = key_of(start[tosort[i]], T=T, UT=UT)
        if prev > cur:
            break
        prev = cur
        i += 1
    if i == num:
        return

    aux = Ptr[int](num)
    result = _aradixsort0(start, aux, tosort, num, T=T)
    # An odd number of passes leaves the result in aux; copy it back.
    if result != tosort:
        str.memcpy(tosort.as_byte(), result.as_byte(), num * util.sizeof(int))

    util.free(aux)
|
||
|
|
||
|
def aradixsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    """Radix argsort of ``start[0:num]`` into ``tosort`` for integer-like
    ``T``. Raises ValueError when ``T`` has no radix key mapping."""
    def dispatch(data: Ptr, order: Ptr[int], count: int, T: type, UT: type):
        # Reinterpret the element buffer as unsigned keys of the same width.
        _aradixsort(Ptr[UT](data.as_byte()), order, count, T=T, UT=UT)

    if T is int:
        dispatch(start, tosort, num, int, u64)
    elif isinstance(T, Int):
        dispatch(start, tosort, num, T, UInt[T.N])
    elif isinstance(T, UInt):
        dispatch(start, tosort, num, T, T)
    elif isinstance(T, byte):
        dispatch(start, tosort, num, i8, u8)
    else:
        raise ValueError("cannot radixsort type '" + T.__name__ + "'")
|
||
|
|
||
|
#######
|
||
|
# Tim #
|
||
|
#######
|
||
|
|
||
|
# Capacity of timsort's fixed-size stack of pending (start, length) runs.
TIMSORT_STACK_SIZE: Static[int] = 128
|
||
|
|
||
|
def _compute_min_run(num: int):
|
||
|
r = 0
|
||
|
|
||
|
while 64 < num:
|
||
|
r |= num & 1
|
||
|
num >>= 1
|
||
|
|
||
|
return num + r
|
||
|
|
||
|
# run = 2-tuple of (start, length)
# buffer = Buffer[T] object holding (pw, size)
|
||
|
|
||
|
class Buffer[T]:
|
||
|
pw: Ptr[T]
|
||
|
size: int
|
||
|
|
||
|
def __init__(self):
|
||
|
self.pw = Ptr[T]()
|
||
|
self.size = 0
|
||
|
|
||
|
def resize(self, new_size: int):
|
||
|
buffer_pw = self.pw
|
||
|
buffer_size = self.size
|
||
|
|
||
|
if new_size <= buffer_size:
|
||
|
return
|
||
|
elif not buffer_pw:
|
||
|
self.pw = Ptr[T](new_size)
|
||
|
else:
|
||
|
self.pw = util.realloc(buffer_pw, new_size, buffer_size)
|
||
|
|
||
|
self.size = new_size
|
||
|
|
||
|
def free(self):
|
||
|
if self.pw:
|
||
|
util.free(self.pw)
|
||
|
self.pw = Ptr[T]()
|
||
|
|
||
|
def _count_run(arr: Ptr[T], l: int, num: int, minrun: int, T: type):
    """Find and normalize the natural run starting at ``arr[l]``.

    A non-descending run is kept as is. A strictly descending run is
    reversed in place. A run shorter than ``minrun`` is extended (up to the
    end of the array) by insertion sort. Returns the length ``sz`` of the
    sorted run now occupying ``arr[l:l + sz]``.
    """
    if num - l == 1:
        return 1

    pl = arr + l

    if not _less(pl[1], pl[0]):
        # Non-descending run: advance while the next element is not smaller.
        pi = pl + 1
        while pi < arr + (num - 1) and not _less(pi[1], pi[0]):
            pi += 1
    else:
        # Strictly descending run: advance, then reverse it in place.
        # Strictness means no equal elements get reordered.
        pi = pl + 1
        while pi < arr + (num - 1) and _less(pi[1], pi[0]):
            pi += 1

        pj = pl
        pr = pi
        while pj < pr:
            pj[0], pr[0] = pr[0], pj[0]
            pj += 1
            pr -= 1

    # pi pointed at the run's last element; make it one past the end.
    pi += 1
    sz = pi - pl

    if sz < minrun:
        if l + minrun < num:
            sz = minrun
        else:
            sz = num - l

        pr = pl + sz

        # Extend the run by insertion-sorting [pi, pr) into it.
        while pi < pr:
            vc = pi[0]
            pj = pi

            while pl < pj and _less(vc, pj[-1]):
                pj[0] = pj[-1]
                pj -= 1

            pj[0] = vc
            pi += 1

    return sz
|
||
|
|
||
|
def _merge_left(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type):
    """Merge adjacent sorted runs ``p1[0:l1]`` and ``p2[0:l2]`` front to back.

    ``p3`` is scratch space for the left run (capacity >= ``l1``). The
    caller (_merge_at) guarantees ``p2[0]`` sorts before ``p1[0]``, so it
    is moved first unconditionally.
    """
    end = p2 + l2
    # Stash the left run; its slots in the array become the output area.
    str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(T))
    p1[0] = p2[0]
    p1 += 1
    p2 += 1

    # p1: output cursor, p2: right-run cursor, p3: stashed left-run cursor.
    while p1 < p2 and p2 < end:
        if _less(p2[0], p3[0]):
            p1[0] = p2[0]
            p1 += 1
            p2 += 1
        else:
            p1[0] = p3[0]
            p1 += 1
            p3 += 1

    # Right run exhausted first: copy the rest of the stash back.
    if p1 != p2:
        str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(T))
|
||
|
|
||
|
def _merge_right(p1: Ptr[T], l1: int, p2: Ptr[T], l2: int, p3: Ptr[T], T: type):
    """Merge adjacent sorted runs ``p1[0:l1]`` and ``p2[0:l2]`` back to front.

    ``p3`` is scratch space for the right run (capacity >= ``l2``). The
    caller (_merge_at) guarantees the last element of the left run sorts
    after every element of the right run, so it is moved first.
    """
    start = p1 - 1
    # Stash the right run; its slots in the array become the output area.
    str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(T))
    # Point each cursor at the last element of its run.
    p1 += l1 - 1
    p2 += l2 - 1
    p3 += l2 - 1
    p2[0] = p1[0]
    p2 -= 1
    p1 -= 1

    # p2: output cursor (moving left), p1: left-run cursor, p3: stash cursor.
    while p1 < p2 and start < p1:
        if _less(p3[0], p1[0]):
            p2[0] = p1[0]
            p2 -= 1
            p1 -= 1
        else:
            p2[0] = p3[0]
            p2 -= 1
            p3 -= 1

    # Left run exhausted first: copy the remaining stashed elements into
    # the front of the output area.
    if p1 != p2:
        ofs = p2 - start
        str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(T))
|
||
|
|
||
|
def _gallop_right(arr: Ptr[T], size: int, key: T, T: type):
    """Return how many leading elements of sorted ``arr[0:size]`` are not
    greater than ``key``. This is the insertion point after any elements
    equal to ``key``.

    Uses exponential probing from the left (offsets 1, 3, 7, ...) followed
    by a binary search.
    """
    if _less(key, arr[0]):
        return 0

    last_ofs = 0
    ofs = 1

    while True:
        # ofs < 0 catches integer overflow of the doubling offset.
        if size <= ofs or ofs < 0:
            ofs = size
            break

        if _less(key, arr[ofs]):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1

    # Now arr[last_ofs] <= key < arr[ofs] (with arr[size] as +infinity).
    while last_ofs + 1 < ofs:
        m = last_ofs + ((ofs - last_ofs) >> 1)

        if _less(key, arr[m]):
            ofs = m
        else:
            last_ofs = m

    return ofs
|
||
|
|
||
|
def _gallop_left(arr: Ptr[T], size: int, key: T, T: type):
    """Return how many leading elements of sorted ``arr[0:size]`` are
    strictly less than ``key``. This is the insertion point before any
    elements equal to ``key``.

    Probes exponentially from the right end, then binary searches.
    """
    if _less(arr[size - 1], key):
        return size

    last_ofs = 0
    ofs = 1

    while True:
        # ofs < 0 catches integer overflow of the doubling offset.
        if size <= ofs or ofs < 0:
            ofs = size
            break

        if _less(arr[size - ofs - 1], key):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1

    # Now arr[l] < key <= arr[r] (with arr[-1] as -infinity).
    l = size - ofs - 1
    r = size - last_ofs - 1

    while l + 1 < r:
        m = l + ((r - l) >> 1)

        if _less(arr[m], key):
            l = m
        else:
            r = m

    return r
|
||
|
|
||
|
def _merge_at(arr: Ptr[T],
              stack: Ptr[Tuple[int,int]],
              at: int,
              buffer: Buffer[T],
              T: type):
    """Merge the adjacent runs ``stack[at]`` and ``stack[at + 1]`` in place.

    Stack entries are (start, length). Galloping trims the prefix of the
    first run and the suffix of the second run that are already in place.
    The remaining overlap is then merged, buffering whichever side is
    shorter. Stack entries are not updated here; callers do that.
    """
    s1, l1 = stack[at]
    s2, l2 = stack[at + 1]
    # arr[s2] belongs at arr[s1 + k]; arr[s1 : s1 + k] is already in place.
    k = _gallop_right(arr + s1, l1, arr[s2])

    if l1 == k:
        # The two runs are already in order.
        return

    p1 = arr + (s1 + k)
    l1 -= k
    p2 = arr + s2
    # arr[s2 - 1] (last of the first run) belongs at arr[s2 + l2]; the tail
    # of the second run beyond that is already in place.
    l2 = _gallop_left(arr + s2, l2, arr[s2 - 1])

    if l2 < l1:
        buffer.resize(l2)
        _merge_right(p1, l1, p2, l2, buffer.pw)
    else:
        buffer.resize(l1)
        _merge_left(p1, l1, p2, l2, buffer.pw)
|
||
|
|
||
|
def _try_collapse(arr: Ptr[T],
                  stack: Ptr[Tuple[int,int]],
                  stack_ptr: Ptr[int],
                  buffer: Buffer[T],
                  T: type):
    """Merge runs at the top of the stack until the size invariants hold.

    Consider the run lengths ..., A, B, C (C on top). Merging continues
    while A <= B + C, while the run below A is <= A + B, or while B <= C.
    ``stack_ptr[0]`` holds the stack height and is updated in place.
    """
    top = stack_ptr[0]

    while 1 < top:
        B = stack[top - 2][1]
        C = stack[top - 1][1]

        if ((2 < top and stack[top - 3][1] <= B + C) or
            (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)):
            A = stack[top - 3][1]

            # Merge B with the smaller of its neighbours.
            if A <= C:
                # Merge A and B; C slides down one slot.
                _merge_at(arr, stack, top - 3, buffer)
                s, l = stack[top - 3]
                stack[top - 3] = (s, l + B)
                stack[top - 2] = stack[top - 1]
                top -= 1
            else:
                # Merge B and C.
                _merge_at(arr, stack, top - 2, buffer)
                s, l = stack[top - 2]
                stack[top - 2] = (s, l + C)
                top -= 1
        elif 1 < top and B <= C:
            _merge_at(arr, stack, top - 2, buffer)
            s, l = stack[top - 2]
            stack[top - 2] = (s, l + C)
            top -= 1
        else:
            break

    stack_ptr[0] = top
|
||
|
|
||
|
def _force_collapse(arr: Ptr[T],
                    stack: Ptr[Tuple[int,int]],
                    stack_ptr: Ptr[int],
                    buffer: Buffer[T],
                    T: type):
    """Merge every run remaining on the stack into a single run.

    Uses the same pairing rule as _try_collapse: merge the second-from-top
    run with the smaller of its neighbours. ``stack_ptr[0]`` is read but
    not written back.
    """
    top = stack_ptr[0]

    while 2 < top:
        if stack[top - 3][1] <= stack[top - 1][1]:
            _merge_at(arr, stack, top - 3, buffer)
            _, l1 = stack[top - 2]
            s, l2 = stack[top - 3]
            stack[top - 3] = (s, l1 + l2)
            stack[top - 2] = stack[top - 1]
            top -= 1
        else:
            _merge_at(arr, stack, top - 2, buffer)
            _, l1 = stack[top - 1]
            s, l2 = stack[top - 2]
            stack[top - 2] = (s, l1 + l2)
            top -= 1

    # Final merge of the last two runs; their lengths need no update.
    if 1 < top:
        _merge_at(arr, stack, top - 2, buffer)
|
||
|
|
||
|
def timsort(start: Ptr[T], num: int, T: type):
    # Stable, in-place timsort of the `num` elements at `start`.
    #
    # Runs located by `_count_run` are pushed onto a fixed-size run stack.
    # After each push, `_try_collapse` merges runs as the stack invariants
    # require. Whatever remains is merged by `_force_collapse` at the end.
    run_stack = __array__[Tuple[int,int]](TIMSORT_STACK_SIZE)
    scratch = Buffer[T]()
    stack_ptr = 0
    minrun = _compute_min_run(num)

    pos = 0
    while pos < num:
        run_len = _count_run(start, pos, num, minrun)
        run_stack[stack_ptr] = (pos, run_len)
        stack_ptr += 1
        _try_collapse(start, run_stack.ptr, __ptr__(stack_ptr), scratch)
        pos += run_len

    _force_collapse(start, run_stack.ptr, __ptr__(stack_ptr), scratch)
    scratch.free()
|
||
|
|
||
|
def _acount_run(arr: Ptr[T], tosort: Ptr[int], l: int, num: int, minrun: int, T: type):
    # Find the run that starts at position `l` of the index array `tosort`
    # (ordering by the values arr[tosort[i]]), make it ascending, and return
    # its length.
    #
    # Only `tosort` is permuted; `arr` is read-only here.
    # - A non-descending run is left as-is.
    # - A strictly descending run is reversed in place.
    # - If the natural run is shorter than `minrun`, it is extended to
    #   `minrun` entries (or to the end of the array) with an insertion sort.
    if num - l == 1:
        return 1

    pl = tosort + l

    if not _less(arr[pl[1]], arr[pl[0]]):
        # Non-descending run: scan to its last element.
        pi = pl + 1
        while pi < tosort + (num - 1) and not _less(arr[pi[1]], arr[pi[0]]):
            pi += 1
    else:
        # Strictly descending run: scan to its last element, then reverse it.
        # Strictness keeps the reversal stable.
        pi = pl + 1
        while pi < tosort + (num - 1) and _less(arr[pi[1]], arr[pi[0]]):
            pi += 1

        pj = pl
        pr = pi
        while pj < pr:
            pj[0], pr[0] = pr[0], pj[0]
            pj += 1
            pr -= 1

    # pi pointed at the run's last element; make it one past the end.
    pi += 1
    sz = pi - pl

    if sz < minrun:
        if l + minrun < num:
            sz = minrun
        else:
            sz = num - l

        pr = pl + sz

        # Insertion-sort the indices in [pi, pr) into the ascending prefix.
        while pi < pr:
            vi = pi[0]
            vc = arr[vi]
            pj = pi

            while pl < pj and _less(vc, arr[pj[-1]]):
                pj[0] = pj[-1]
                pj -= 1

            pj[0] = vi
            pi += 1

    return sz
|
||
|
|
||
|
def _amerge_left(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type):
    # Merge two adjacent index runs, p1[0:l1] and p2[0:l2], ordering by the
    # values they index in `arr`. p2 is expected to directly follow the first
    # run.
    #
    # The left run is first copied into the scratch buffer `p3`, which must
    # hold at least l1 ints. The merge is then written forwards starting at
    # p1. `_amerge_at` uses this when l1 <= l2.
    end = p2 + l2
    str.memcpy(p3.as_byte(), p1.as_byte(), l1 * util.sizeof(int))
    # `_amerge_at` trimmed the left run so that p2[0] sorts strictly before
    # all remaining left elements. It can therefore be placed first without
    # a comparison.
    p1[0] = p2[0]
    p1 += 1
    p2 += 1

    # On ties take from the buffered left run (p3); this keeps the merge stable.
    while p1 < p2 and p2 < end:
        if _less(arr[p2[0]], arr[p3[0]]):
            p1[0] = p2[0]
            p1 += 1
            p2 += 1
        else:
            p1[0] = p3[0]
            p1 += 1
            p3 += 1

    # If the right run ran out first, copy the rest of the buffered left run.
    if p1 != p2:
        str.memcpy(p1.as_byte(), p3.as_byte(), (p2 - p1) * util.sizeof(int))
|
||
|
|
||
|
def _amerge_right(arr: Ptr[T], p1: Ptr[int], l1: int, p2: Ptr[int], l2: int, p3: Ptr[int], T: type):
    # Merge the adjacent index runs p1[0:l1] and p2[0:l2], ordering by the
    # values they index in `arr`.
    #
    # The right run is copied into the scratch buffer `p3`, which must hold
    # at least l2 ints. The merge is then written backwards from the end of
    # the right run. `_amerge_at` uses this when l2 < l1.
    start = p1 - 1
    str.memcpy(p3.as_byte(), p2.as_byte(), l2 * util.sizeof(int))
    # Point every cursor at the last element of its run.
    p1 += l1 - 1
    p2 += l2 - 1
    p3 += l2 - 1
    # `_amerge_at` trimmed the right run so that every remaining right
    # element sorts strictly before the last left element. That element can
    # therefore go last without a comparison.
    p2[0] = p1[0]
    p2 -= 1
    p1 -= 1

    # Take from the left run only when it is strictly greater. Equal elements
    # keep their original order, so the merge is stable.
    while p1 < p2 and start < p1:
        if _less(arr[p3[0]], arr[p1[0]]):
            p2[0] = p1[0]
            p2 -= 1
            p1 -= 1
        else:
            p2[0] = p3[0]
            p2 -= 1
            p3 -= 1

    # If the left run ran out first, copy the rest of the buffered right run
    # into the front of the merged range.
    if p1 != p2:
        ofs = p2 - start
        str.memcpy((start + 1).as_byte(), (p3 - ofs + 1).as_byte(), ofs * util.sizeof(int))
|
||
|
|
||
|
def _agallop_right(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type):
    # Return the number of leading entries of the index run tosort[0:size]
    # whose values are <= `key`. That is the insertion point for `key` after
    # any equal values. The run is assumed sorted under `_less` and non-empty.
    #
    # The search probes exponentially from the left ("galloping") and then
    # binary-searches the final gap.
    if _less(key, arr[tosort[0]]):
        return 0

    last_ofs = 0
    ofs = 1

    # Probe ofs = 1, 3, 7, 15, ... until arr[tosort[ofs]] > key or the run
    # ends. The `ofs < 0` check presumably guards against the doubling
    # overflowing.
    while True:
        if size <= ofs or ofs < 0:
            ofs = size
            break

        if _less(key, arr[tosort[ofs]]):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1

    # Now arr[tosort[last_ofs]] <= key < arr[tosort[ofs]], where ofs == size
    # means "past the end". Binary search the gap.
    while last_ofs + 1 < ofs:
        m = last_ofs + ((ofs - last_ofs) >> 1)

        if _less(key, arr[tosort[m]]):
            ofs = m
        else:
            last_ofs = m

    return ofs
|
||
|
|
||
|
def _agallop_left(arr: Ptr[T], tosort: Ptr[int], size: int, key: T, T: type):
    # Return the number of leading entries of the index run tosort[0:size]
    # whose values are strictly less than `key`. That is the insertion point
    # for `key` before any equal values. The run is assumed sorted under
    # `_less` and non-empty.
    #
    # The search gallops from the right end and then binary-searches the gap.
    if _less(arr[tosort[size - 1]], key):
        return size

    last_ofs = 0
    ofs = 1

    # Probe positions size-2, size-4, size-8, ... (ofs = 1, 3, 7, ...) until
    # a value < key is found or the start of the run is passed.
    while True:
        if size <= ofs or ofs < 0:
            ofs = size
            break

        if _less(arr[tosort[size - ofs - 1]], key):
            break
        else:
            last_ofs = ofs
            ofs = (ofs << 1) + 1

    # Now arr[tosort[l]] < key <= arr[tosort[r]], where l == -1 means
    # "before the start". Binary search the gap.
    l = size - ofs - 1
    r = size - last_ofs - 1

    while l + 1 < r:
        m = l + ((r - l) >> 1)

        if _less(arr[tosort[m]], key):
            l = m
        else:
            r = m

    return r
|
||
|
|
||
|
def _amerge_at(arr: Ptr[T],
               tosort: Ptr[int],
               stack: Ptr[Tuple[int,int]],
               at: int,
               buffer: Buffer[int],
               T: type):
    # Merge, in place, the adjacent index runs stack[at] and stack[at + 1].
    # Each entry is a (start, length) pair into `tosort`. The caller updates
    # the run stack afterwards.
    s1, l1 = stack[at]
    s2, l2 = stack[at + 1]
    # Left-run elements <= the first right element are already in their
    # final place, so skip them.
    k = _agallop_right(arr, tosort + s1, l1, arr[tosort[s2]])

    if l1 == k:
        # The whole left run already precedes the right run.
        return

    p1 = tosort + (s1 + k)
    l1 -= k
    p2 = tosort + s2
    # Likewise, right-run elements >= the last left element are already in
    # place. Only the first l2 of them take part in the merge.
    l2 = _agallop_left(arr, tosort + s2, l2, arr[tosort[s2 - 1]])

    # Buffer the shorter run and merge from the side that suits it.
    if l2 < l1:
        buffer.resize(l2)
        _amerge_right(arr, p1, l1, p2, l2, buffer.pw)
    else:
        buffer.resize(l1)
        _amerge_left(arr, p1, l1, p2, l2, buffer.pw)
|
||
|
|
||
|
def _atry_collapse(arr: Ptr[T],
                   tosort: Ptr[int],
                   stack: Ptr[Tuple[int,int]],
                   stack_ptr: Ptr[int],
                   buffer: Buffer[int],
                   T: type):
    # Index-array (argsort) counterpart of the run-stack collapse. After a
    # new run is pushed, merge adjacent runs of `tosort` while the timsort
    # invariants are violated.
    #
    # With run lengths A = stack[top-3], B = stack[top-2], C = stack[top-1]
    # and W = stack[top-4], runs are merged while any of these holds:
    #   A <= B + C,  W <= A + B,  B <= C
    # stack_ptr[0] holds the number of runs and is updated on exit.
    top = stack_ptr[0]

    while 1 < top:
        B = stack[top - 2][1]
        C = stack[top - 1][1]

        if ((2 < top and stack[top - 3][1] <= B + C) or
            (3 < top and stack[top - 4][1] <= stack[top - 3][1] + B)):
            A = stack[top - 3][1]

            if A <= C:
                # Merge B into the smaller neighbour A; C slides down.
                _amerge_at(arr, tosort, stack, top - 3, buffer)
                s, l = stack[top - 3]
                stack[top - 3] = (s, l + B)
                stack[top - 2] = stack[top - 1]
                top -= 1
            else:
                # Merge B with the smaller neighbour C.
                _amerge_at(arr, tosort, stack, top - 2, buffer)
                s, l = stack[top - 2]
                stack[top - 2] = (s, l + C)
                top -= 1
        elif 1 < top and B <= C:
            _amerge_at(arr, tosort, stack, top - 2, buffer)
            s, l = stack[top - 2]
            stack[top - 2] = (s, l + C)
            top -= 1
        else:
            # Invariants hold again.
            break

    stack_ptr[0] = top
|
||
|
|
||
|
def _aforce_collapse(arr: Ptr[T],
                     tosort: Ptr[int],
                     stack: Ptr[Tuple[int,int]],
                     stack_ptr: Ptr[int],
                     buffer: Buffer[int],
                     T: type):
    # Merge all remaining index runs on the stack into one. `atimsort` calls
    # this once, after every run has been pushed.
    # Note: the final run count is not written back to stack_ptr.
    top = stack_ptr[0]

    while 2 < top:
        # Merge the middle run (top-2) with its lower neighbour when that
        # neighbour is no longer than the top run; otherwise merge it with
        # the top run.
        if stack[top - 3][1] <= stack[top - 1][1]:
            _amerge_at(arr, tosort, stack, top - 3, buffer)
            _, l1 = stack[top - 2]
            s, l2 = stack[top - 3]
            stack[top - 3] = (s, l1 + l2)
            stack[top - 2] = stack[top - 1]
            top -= 1
        else:
            _amerge_at(arr, tosort, stack, top - 2, buffer)
            _, l1 = stack[top - 1]
            s, l2 = stack[top - 2]
            stack[top - 2] = (s, l1 + l2)
            top -= 1

    # At most two runs remain. Merge them if there are two.
    if 1 < top:
        _amerge_at(arr, tosort, stack, top - 2, buffer)
|
||
|
|
||
|
def atimsort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    # Timsort-based argsort. Permute the index array `tosort` (length `num`)
    # so that start[tosort[0]], start[tosort[1]], ... ascend under `_less`.
    # Only `tosort` is written; the values at `start` are left untouched.
    run_stack = __array__[Tuple[int,int]](TIMSORT_STACK_SIZE)
    scratch = Buffer[int]()
    stack_ptr = 0
    minrun = _compute_min_run(num)

    pos = 0
    while pos < num:
        run_len = _acount_run(start, tosort, pos, num, minrun)
        run_stack[stack_ptr] = (pos, run_len)
        stack_ptr += 1
        _atry_collapse(start, tosort, run_stack.ptr, __ptr__(stack_ptr), scratch)
        pos += run_len

    _aforce_collapse(start, tosort, run_stack.ptr, __ptr__(stack_ptr), scratch)
    scratch.free()
|
||
|
|
||
|
##########
|
||
|
# Stable #
|
||
|
##########
|
||
|
|
||
|
def stablesort(start: Ptr[T], num: int, T: type):
    # Stable, in-place sort of the `num` elements at `start`.
    # The choice of algorithm depends only on the element type T: integer
    # types use radix sort and every other type uses timsort.
    if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt):
        return radixsort(start, num)
    else:
        return timsort(start, num)
|
||
|
|
||
|
def astablesort(start: Ptr[T], tosort: Ptr[int], num: int, T: type):
    # Stable argsort. Permute the index array `tosort` (length `num`) to
    # order the values at `start`. As in `stablesort`, integer element types
    # use radix sort and all other types use timsort.
    if T is int or T is byte or isinstance(T, Int) or isinstance(T, UInt):
        return aradixsort(start, tosort, num)
    else:
        return atimsort(start, tosort, num)
|
||
|
|
||
|
#############
|
||
|
# Selection #
|
||
|
#############
|
||
|
|
||
|
# Capacity of the pivot cache that `_store_pivot` fills and `_introselect`
# consults. The partition entry points allocate one array of this size.
MAX_PIVOT_STACK: Static[int] = 50
|
||
|
|
||
|
@tuple
class Sortee:
    # Swappable view over a value array. The selection routines use it in
    # value mode (arg=False), where the values themselves are permuted.
    v: Ptr[T]
    T: type

    def __call__(self, i: int):
        # Value stored at position i.
        return self.v[i]

    def swap(self, i: int, j: int):
        # Exchange the values at positions i and j.
        vals = self.v
        vals[i], vals[j] = vals[j], vals[i]
|
||
|
|
||
|
@tuple
class Idx:
    # Identity index map for value mode (arg=False): position i is read
    # directly as v[i].
    def __call__(self, i: int):
        return i
|
||
|
|
||
|
@tuple
class ArgSortee:
    # Swappable view over an index array. The selection routines use it in
    # arg mode (arg=True), where the indices are permuted and the values are
    # left alone.
    tosort: Ptr[int]

    def __call__(self, i: int):
        # Index stored at position i.
        return self.tosort[i]

    def swap(self, i: int, j: int):
        # Exchange the indices stored at positions i and j.
        perm = self.tosort
        perm[i], perm[j] = perm[j], perm[i]
|
||
|
|
||
|
@tuple
class ArgIdx:
    # Index map for arg mode: position i refers to element v[tosort[i]].
    tosort: Ptr[int]

    def __call__(self, i: int):
        table = self.tosort
        return table[i]
|
||
|
|
||
|
def _store_pivot(pivot: int, kth: int, pivots: Ptr[int], npiv: Ptr[int]):
    # Record a final pivot position in the cache `pivots`. npiv[0] holds the
    # current entry count and the capacity is MAX_PIVOT_STACK. A null
    # `pivots` pointer disables caching.
    if not pivots:
        return

    count = npiv[0]
    if pivot == kth and count == MAX_PIVOT_STACK:
        # The cache is full, but the requested kth must still be recorded:
        # overwrite the newest entry.
        pivots[count - 1] = pivot
    elif pivot >= kth and count < MAX_PIVOT_STACK:
        # Only pivots at or beyond kth are kept.
        pivots[count] = pivot
        npiv[0] = count + 1
|
||
|
|
||
|
def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, arg: Static[int], T: type):
    # Median-of-three pivot selection for `_introselect`.
    #
    # Afterwards:
    # - position `low` holds the median of the three elements (the pivot);
    # - position `low + 1` holds the smallest;
    # - position `high` holds the largest.
    # The smallest and largest act as sentinels for `_unguarded_partition`.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    # Compare-exchange steps, applied in order: afterwards `high` is the
    # maximum, `low` the median and `mid` the minimum.
    for x, y in ((high, mid), (high, low), (low, mid)):
        if _less(v[idx(x)], v[idx(y)]):
            sortee.swap(x, y)

    # Park the minimum next to the pivot.
    sortee.swap(mid, low + 1)
|
||
|
|
||
|
def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[int], T: type):
    # Partially order the five elements at positions 0..4 and return the
    # position (1, 2 or 3) that holds their median.
    # In arg mode only `tosort` is permuted.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    # Fixed compare-exchange network, applied in order.
    for x, y in ((1, 0), (4, 3), (3, 0), (4, 1), (2, 1)):
        if _less(v[idx(x)], v[idx(y)]):
            sortee.swap(x, y)

    if not _less(v[idx(3)], v[idx(2)]):
        return 2
    return 1 if _less(v[idx(3)], v[idx(1)]) else 3
|
||
|
|
||
|
def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int, arg: Static[int], T: type):
    # Hoare-style partition around `pivot`. The scans start just after `ll`
    # and just before `hh`.
    #
    # There are no bounds checks. The caller must ensure that an element not
    # less than `pivot` exists to the right of `ll`, and one not greater to
    # the left of `hh` (see `_median3_swap`).
    #
    # Returns the final (ll, hh). On return hh < ll, and hh is where the
    # pivot belongs.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    while True:
        # Advance ll to the next element that is not less than the pivot.
        ll += 1
        while _less(v[idx(ll)], pivot):
            ll += 1

        # Retreat hh to the next element that is not greater than the pivot.
        hh -= 1
        while _less(pivot, v[idx(hh)]):
            hh -= 1

        if hh < ll:
            break

        sortee.swap(ll, hh)

    return ll, hh
|
||
|
|
||
|
def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], isel, T: type):
    # Median-of-medians pivot selection.
    #
    # The median of each complete group of five is moved to the front
    # (positions 0 .. num//5 - 1). The selector `isel` (`_introselect`) then
    # places the median of those medians. Returns its position,
    # (num // 5) // 2.
    if arg:
        sortee = ArgSortee(tosort)
    else:
        sortee = Sortee(v)

    nmed = num // 5

    for group in range(nmed):
        base = group * 5
        m = _median5(v + (0 if arg else base), tosort + (base if arg else 0), arg)
        sortee.swap(base + m, group)

    if nmed > 2:
        isel(v, tosort, nmed, nmed // 2, pivots, npiv, arg)

    return nmed // 2
|
||
|
|
||
|
def _dumb_select(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, arg: Static[int], T: type):
    # Partial selection sort. Positions 0..kth are filled with the smallest
    # elements in ascending order, which also places the kth element
    # correctly. Costs O(num * kth); `_introselect` only uses it when kth is
    # within 3 of the lower bound.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    for pos in range(kth + 1):
        best = pos
        best_val = v[idx(pos)]
        for cand in range(pos + 1, num):
            cand_val = v[idx(cand)]
            if _less(cand_val, best_val):
                best = cand
                best_val = cand_val
        sortee.swap(pos, best)
|
||
|
|
||
|
def _msb(unum: u64):
    # Position of the most significant set bit of `unum` (floor(log2(unum))).
    # Returns 0 for both 0 and 1.
    bits = 0
    rest = unum >> u64(1)
    while rest:
        bits += 1
        rest >>= u64(1)
    return bits
|
||
|
|
||
|
def _introselect(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], T: type):
    # Rearrange `num` elements so that position `kth` holds the element a
    # full sort would put there. Nothing before it compares greater and
    # nothing after it compares less (under `_less`).
    #
    # Modes:
    # - arg=True: only the index array `tosort` is permuted, and `v` is read
    #   through it.
    # - arg=False: `v` itself is permuted.
    #
    # `pivots`/`npiv` form an optional cache of partition points from earlier
    # calls on the same data; a null `npiv` disables it. Cached pivots narrow
    # the search range, and new ones are recorded via `_store_pivot`.
    if arg:
        idx = ArgIdx(tosort)
        sortee = ArgSortee(tosort)
    else:
        idx = Idx()
        sortee = Sortee(v)

    low = 0
    high = num - 1

    if not npiv:
        pivots = Ptr[int]()

    # Use cached pivots to shrink [low, high] around kth.
    while pivots and npiv[0] > 0:
        if pivots[npiv[0] - 1] > kth:
            # A pivot above kth bounds the search from above.
            high = pivots[npiv[0] - 1] - 1
            break
        elif pivots[npiv[0] - 1] == kth:
            # kth was already placed by an earlier call.
            return

        # A pivot below kth bounds the search from below; pop it.
        low = pivots[npiv[0] - 1] + 1
        npiv[0] -= 1

    if kth - low < 3:
        # kth is close to the lower bound, so a simple selection is cheaper.
        _dumb_select(v + (0 if arg else low), tosort + (low if arg else 0), high - low + 1, kth - low, arg)
        _store_pivot(kth, kth, pivots, npiv)
        return
    elif (T is float or T is float32) and kth == num - 1:
        # Selecting the last element of a float array: move the maximum into
        # place. `_less` orders NaN after everything, so a NaN wins here.
        maxidx = low
        maxval = v[idx(low)]
        for k in range(low + 1, num):
            if not _less(v[idx(k)], maxval):
                maxidx = k
                maxval = v[idx(k)]
        sortee.swap(kth, maxidx)
        return

    depth_limit = _msb(u64(num)) * 2

    # Partition repeatedly until at most two elements remain in [low, high].
    while low + 1 < high:
        ll = low + 1
        hh = high

        if depth_limit > 0 or hh - ll < 5:
            # Median-of-3 pivot. This also places the sentinels that
            # `_unguarded_partition` needs.
            mid = low + (high - low) // 2
            _median3_swap(v, tosort, low, mid, high, arg)
        else:
            # The round budget is exhausted: switch to a median-of-medians
            # pivot, intended to bound the worst case.
            mid = ll + _median_of_median5(v + (0 if arg else ll), tosort + (ll if arg else 0), hh - ll, Ptr[int](), Ptr[int](), arg, _introselect)
            sortee.swap(mid, low)
            # Start the scans one step wider, so that all of (low, high] is
            # examined.
            ll -= 1
            hh += 1

        depth_limit -= 1

        # Partition around the pivot held at `low`.
        ll, hh = _unguarded_partition(v, tosort, v[idx(low)], ll, hh, arg)

        # Move the pivot to its final position hh.
        sortee.swap(low, hh)

        # Remember useful pivots. kth itself is stored after the loop.
        if hh != kth:
            _store_pivot(hh, kth, pivots, npiv)

        # Continue in the part that contains kth. If hh == kth, both updates
        # apply and the loop ends.
        if hh >= kth:
            high = hh - 1

        if hh <= kth:
            low = ll

    # Two elements left: order them directly.
    if high == low + 1:
        if _less(v[idx(high)], v[idx(low)]):
            sortee.swap(high, low)

    _store_pivot(kth, kth, pivots, npiv)
|
||
|
|
||
|
def _partition(v: Ptr[T], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], T: type):
    # Value-mode selection: permute v[0:num] so that v[kth] is in its sorted
    # position. No index array is involved, so a null pointer stands in for it.
    no_index = Ptr[int]()
    _introselect(v, no_index, num, kth, pivots, npiv, arg=False)
|
||
|
|
||
|
def _argpartition(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], T: type):
    # Arg-mode selection: permute the index array tosort[0:num] so that
    # v[tosort[kth]] is in its sorted position. `v` is only read.
    _introselect(v, tosort, num, kth, pivots, npiv, arg=True)
|
||
|
|
||
|
def _partition_compact(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    # Partition every 1-D slice of `a` along `axis` in place, working
    # directly on the array memory. `_partition_helper` uses this when
    # elements along `axis` are contiguous, or when the caller forces it.
    length = a.shape[axis]
    pivot_cache = __array__[int](MAX_PIVOT_STACK)
    outer_shape = util.tuple_delete(a.shape, axis)

    for outer in util.multirange(outer_shape):
        row = a._ptr(util.tuple_insert(outer, axis, 0))
        # The pivot cache is only meaningful within a single slice.
        npiv = 0
        for j in range(nkth):
            _partition(row, length, kth[j], pivot_cache.ptr, __ptr__(npiv))
|
||
|
|
||
|
def _partition_buffered(a: ndarray, kth: Ptr[int], nkth: int, axis: int):
    # Partition every 1-D slice of `a` along a strided `axis`. Each slice is
    # gathered into a contiguous scratch buffer, partitioned there and
    # scattered back.
    length = a.shape[axis]
    stride = a.strides[axis]
    scratch = Ptr[a.dtype](length)
    pivot_cache = __array__[int](MAX_PIVOT_STACK)

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        npiv = 0
        base = a._ptr(util.tuple_insert(outer, axis, 0)).as_byte()

        # Gather.
        src = base
        for i in range(length):
            scratch[i] = Ptr[a.dtype](src)[0]
            src += stride

        for j in range(nkth):
            _partition(scratch, length, kth[j], pivot_cache.ptr, __ptr__(npiv))

        # Scatter.
        dst = base
        for i in range(length):
            Ptr[a.dtype](dst)[0] = scratch[i]
            dst += stride

    util.free(scratch)
|
||
|
|
||
|
def _fix_kth(kth, n: int):
    # Normalize the partition indices `kth` for an axis of length `n`.
    #
    # Returns (pointer, count) over a private, C-ordered int array. In that
    # array negative indices are wrapped by `n`, and the values are sorted
    # ascending when `kth` is 1-D.
    #
    # Raises ValueError if any index lies outside [-n, n).
    #
    # The values are copied before being rewritten. `asarray` may hand back
    # the caller's own kth array, and normalizing or sorting that in place
    # would silently modify the user's argument.
    kth = asarray(kth).copy(order='C')

    if kth.dtype is not int:
        compile_error("Partition index must be integer")

    if kth.ndim > 1:
        compile_error("kth array must have dimension <= 1")

    pkth = kth.data
    nkth = kth.size

    for i in range(nkth):
        if pkth[i] < 0:
            pkth[i] += n

        if pkth[i] < 0 or pkth[i] >= n:
            raise ValueError(f"kth(={pkth[i]}) out of bounds ({n})")

    # Partition in ascending kth order, so that later passes do not undo
    # earlier ones.
    if kth.ndim == 1:
        kth.sort()

    return pkth, nkth
|
||
|
|
||
|
def _partition_helper(a: ndarray, kth, axis, kind: str, force_compact: Static[int] = False):
    # Partition `a` in place along `axis`, once for each index in `kth`.
    # Slices are processed directly when contiguous along `axis` (or when
    # force_compact is set), and through a scratch buffer otherwise.
    #
    # NOTE(review): with axis=None this partitions `a.flatten()`, which is
    # presumably a copy, so `a` itself is left unchanged. Confirm that no
    # caller relies on this path.
    if axis is None:
        _partition_helper(a.flatten(), kth=kth, axis=-1, kind=kind, force_compact=force_compact)
        return

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    pkth, nkth = _fix_kth(kth, a.shape[axis])

    if force_compact:
        _partition_compact(a, pkth, nkth, axis)
    elif a.strides[axis] != a.itemsize:
        _partition_buffered(a, pkth, nkth, axis)
    else:
        _partition_compact(a, pkth, nkth, axis)
|
||
|
|
||
|
def partition(a, kth, axis = -1, kind: str = 'introselect'):
    # Return a partitioned copy of `a`. Along `axis`, each position k in
    # `kth` holds the element a full sort would put there; smaller elements
    # precede it and larger ones follow.
    #
    # With axis=None the array is flattened first and the flattened result
    # is returned (NumPy semantics). Previously the flattened temporary was
    # partitioned and discarded, leaving the result untouched.
    if axis is None:
        return partition(asarray(a).flatten(), kth=kth, axis=-1, kind=kind)

    if isinstance(a, ndarray):
        b = a.copy(order='C')
    else:
        b = asarray(a)

    _partition_helper(b, kth=kth, axis=axis, kind=kind, force_compact=(b.ndim == 1))
    return b
|
||
|
|
||
|
def argpartition(a, kth, axis = -1, kind: str = 'introselect'):
    # Return an int array of the indices that partition `a` along `axis`.
    # For each 1-D slice, the index at position k (for every k in `kth`)
    # refers to the element a full sort would put there. Indices before it
    # refer to elements that are not greater, and indices after it to
    # elements that are not less.
    #
    # With axis=None the array is flattened first. `a` may be any array-like
    # and is never modified.
    if axis is None:
        return argpartition(asarray(a).flatten(), kth=kth, axis=-1, kind=kind)

    a = asarray(a)

    if kind != 'introselect':
        raise ValueError(f"select kind must be 'introselect' (got {repr(kind)})")

    axis = util.normalize_axis_index(axis, a.ndim)
    n = a.shape[axis]
    pkth, nkth = _fix_kth(kth, n)
    b = empty(a.shape, int)
    a_stride = a.strides[axis]
    b_stride = b.strides[axis]
    a_compact = (a_stride == a.itemsize)
    b_compact = (b_stride == b.itemsize)
    # Scratch buffers are only needed for strided slices.
    a_buf = Ptr[a.dtype]() if a_compact else Ptr[a.dtype](n)
    b_buf = Ptr[b.dtype]() if b_compact else Ptr[b.dtype](n)
    pivots = __array__[int](MAX_PIVOT_STACK)

    for idx in util.multirange(util.tuple_delete(a.shape, axis)):
        npiv = 0
        idx1 = util.tuple_insert(idx, axis, 0)
        pa = a._ptr(idx1)
        pb = b._ptr(idx1)

        if not a_compact:
            qa = pa.as_byte()
            for i in range(n):
                a_buf[i] = (Ptr[a.dtype](qa))[0]
                qa += a_stride

        # Start each slice from the identity permutation.
        tosort = pb if b_compact else b_buf
        for i in range(n):
            tosort[i] = i

        start = pa if a_compact else a_buf

        for i in range(nkth):
            _argpartition(start, tosort, n, pkth[i], pivots.ptr, __ptr__(npiv))

        # Only the indices need writing back. Arg-mode selection permutes
        # `tosort` and never writes to the values.
        if not b_compact:
            qb = pb.as_byte()
            for i in range(n):
                Ptr[b.dtype](qb)[0] = b_buf[i]
                qb += b_stride

    if not a_compact:
        util.free(a_buf)

    if not b_compact:
        util.free(b_buf)

    return b
|
||
|
|
||
|
###########
|
||
|
# Lexsort #
|
||
|
###########
|
||
|
|
||
|
def _lexsort(sort_keys, n: int, axis: int):
    # Indirect stable sort on the `n` keys in `sort_keys`; the last key is
    # the primary one. Each key is stable-argsorted in turn, from first to
    # last, refining one shared permutation.
    #
    # Returns an int array of the keys' shape, or 0 for 0-d keys.
    sort_keys0 = asarray(sort_keys[0])
    nd: Static[int] = sort_keys0.ndim

    if nd == 0 and (axis == 0 or axis == -1):
        pass
    else:
        axis = util.normalize_axis_index(axis, nd)

    if nd == 0:
        return 0

    if sort_keys0.size <= 1:
        return zeros(sort_keys0.shape, int)

    ret = empty(sort_keys0.shape, int)
    rstride = ret.strides[axis]
    maxelsize = sort_keys0.itemsize
    needcopy = (rstride != ret.itemsize)
    N = sort_keys0.shape[axis]

    # If any key, or the result, is strided along `axis`, every slice is
    # staged through contiguous buffers.
    for j in range(n):
        key = sort_keys[j]
        needcopy |= (key.strides[axis] != key.itemsize)
        maxelsize = max(maxelsize, key.itemsize)

    valbuffer = cobj()
    indbuffer = Ptr[int]()

    if needcopy:
        # valbuffer is sized in bytes, since keys may have different dtypes.
        valbufsize = N * maxelsize
        if valbufsize == 0:
            valbufsize = 1
        valbuffer = cobj(valbufsize)
        # indbuffer holds N ints. It was previously allocated with
        # N * itemsize *elements*, i.e. 8x too large.
        indbuffer = Ptr[int](N)

        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1).as_byte()

            for i in range(N):
                indbuffer[i] = i

            for j in range(n):
                key = sort_keys[j]
                buf = Ptr[key.dtype](valbuffer.as_byte())
                p = key._ptr(idx1).as_byte()
                s = key.strides[axis]

                for i in range(N):
                    buf[i] = (Ptr[key.dtype](p))[0]
                    p += s

                astablesort(buf, indbuffer, N)

            for i in range(N):
                z = Ptr[int](q)
                z[0] = indbuffer[i]
                q += rstride
    else:
        for idx in util.multirange(util.tuple_delete(ret.shape, axis)):
            idx1 = util.tuple_insert(idx, axis, 0)
            q = ret._ptr(idx1)

            for i in range(N):
                q[i] = i

            for j in range(n):
                p = sort_keys[j]._ptr(idx1)
                astablesort(p, q, N)

    if valbuffer:
        util.free(valbuffer)

    if indbuffer:
        util.free(indbuffer)

    return ret
|
||
|
|
||
|
def lexsort(keys, axis: int = -1):
    # Indirect stable sort using a sequence of keys; the last key is the
    # primary sort key.
    #
    # `keys` may be a tuple of array-likes, a list of arrays or array-likes,
    # or a 2-D ndarray whose rows are the keys. All keys must share one
    # shape. Raises ValueError for an empty key sequence or mismatched shapes.
    n = len(keys)

    if n == 0:
        raise ValueError("need sequence of keys with len > 0 in lexsort")

    if isinstance(keys, Tuple):
        sort_keys = tuple(asarray(key) for key in keys)
        for i in staticrange(1, staticlen(sort_keys)):
            if sort_keys[i].ndim != sort_keys[0].ndim:
                compile_error("all keys need to be the same shape")
            if sort_keys[i].shape != sort_keys[0].shape:
                raise ValueError("all keys need to be the same shape")
        return _lexsort(sort_keys, n, axis)
    elif isinstance(keys, List):
        if isinstance(keys[0], ndarray):
            sort_keys = keys
        else:
            sort_keys = [asarray(key) for key in keys]
        for i in range(1, n):
            if sort_keys[i].shape != sort_keys[0].shape:
                raise ValueError("all keys need to be the same shape")
        return _lexsort(sort_keys, n, axis)
    elif isinstance(keys, ndarray):
        return _lexsort(keys, n, axis)
    else:
        # Lists are accepted too; the old message named only ndarray/tuple.
        compile_error("keys must be an ndarray, list or tuple")
|
||
|
|
||
|
#######
|
||
|
# API #
|
||
|
#######
|
||
|
|
||
|
def _sort_compact(a: ndarray, axis: int, sorter):
    # Apply `sorter(ptr, length)` in place to every 1-D slice of `a` along
    # `axis`, passing pointers straight into the array memory.
    length = a.shape[axis]
    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        sorter(a._ptr(util.tuple_insert(outer, axis, 0)), length)
|
||
|
|
||
|
def _sort_buffered(a: ndarray, axis: int, sorter):
    # Sort every 1-D slice of `a` along a strided `axis`. Each slice is
    # gathered into a contiguous scratch buffer, sorted there with
    # `sorter(ptr, length)`, and scattered back.
    length = a.shape[axis]
    stride = a.strides[axis]
    scratch = Ptr[a.dtype](length)

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        base = a._ptr(util.tuple_insert(outer, axis, 0)).as_byte()

        src = base
        for i in range(length):
            scratch[i] = Ptr[a.dtype](src)[0]
            src += stride

        sorter(scratch, length)

        dst = base
        for i in range(length):
            Ptr[a.dtype](dst)[0] = scratch[i]
            dst += stride

    util.free(scratch)
|
||
|
|
||
|
def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[int] = False):
    # Normalize `axis` and sort `a` in place along it with `sorter`.
    # Contiguous slices (or any slices, when force_compact is set) are sorted
    # directly; strided slices go through a scratch buffer.
    axis = util.normalize_axis_index(axis, a.ndim)

    if force_compact:
        _sort_compact(a, axis, sorter)
    elif a.strides[axis] != a.itemsize:
        _sort_buffered(a, axis, sorter)
    else:
        _sort_compact(a, axis, sorter)
|
||
|
|
||
|
def _sort(a: ndarray, axis: int, kind: Optional[str], force_compact: Static[int] = False):
    # Sort `a` in place along `axis` with the algorithm named by `kind`:
    # - None, 'quicksort' or 'quick' selects quicksort;
    # - 'mergesort', 'merge' or 'stable' selects stablesort;
    # - 'heapsort' or 'heap' selects heapsort.
    # Any other kind raises ValueError.
    if kind is None:
        _sort_dispatch(a, axis, quicksort, force_compact)
    elif kind == 'quicksort' or kind == 'quick':
        _sort_dispatch(a, axis, quicksort, force_compact)
    elif kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        _sort_dispatch(a, axis, stablesort, force_compact)
    elif kind == 'heapsort' or kind == 'heap':
        _sort_dispatch(a, axis, heapsort, force_compact)
    else:
        raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")
|
||
|
|
||
|
def sort(a, axis = -1, kind: Optional[str] = None):
    # Return a sorted copy of `a` along `axis`. With axis=None the array is
    # flattened first.
    if axis is None:
        return sort(asarray(a).flatten(), axis=-1, kind=kind)

    # Work on a fresh C-ordered copy so that `a` is never touched.
    if isinstance(a, ndarray):
        result = a.copy(order='C')
    else:
        result = asarray(a)

    _sort(result, axis=axis, kind=kind, force_compact=(result.ndim == 1))
    return result
|
||
|
|
||
|
def _asort(a: ndarray, axis: int, sorter):
    # Build an int array of a's shape. Each slice along `axis` starts as the
    # identity permutation and is then reordered by
    # `sorter(values_ptr, index_ptr, length)`.
    # Strided slices of `a` and of the result are staged through contiguous
    # scratch buffers, and copied back afterwards.
    axis = util.normalize_axis_index(axis, a.ndim)
    length = a.shape[axis]
    out = empty(a.shape, int)

    src_stride = a.strides[axis]
    dst_stride = out.strides[axis]
    src_contig = (src_stride == a.itemsize)
    dst_contig = (dst_stride == out.itemsize)
    src_scratch = Ptr[a.dtype]() if src_contig else Ptr[a.dtype](length)
    dst_scratch = Ptr[out.dtype]() if dst_contig else Ptr[out.dtype](length)

    for outer in util.multirange(util.tuple_delete(a.shape, axis)):
        pos = util.tuple_insert(outer, axis, 0)
        pa = a._ptr(pos)
        pb = out._ptr(pos)

        if not src_contig:
            cursor = pa.as_byte()
            for i in range(length):
                src_scratch[i] = Ptr[a.dtype](cursor)[0]
                cursor += src_stride

        perm = pb if dst_contig else dst_scratch
        for i in range(length):
            perm[i] = i

        values = pa if src_contig else src_scratch
        sorter(values, perm, length)

        if not src_contig:
            cursor = pa.as_byte()
            for i in range(length):
                Ptr[a.dtype](cursor)[0] = src_scratch[i]
                cursor += src_stride

        if not dst_contig:
            cursor = pb.as_byte()
            for i in range(length):
                Ptr[out.dtype](cursor)[0] = dst_scratch[i]
                cursor += dst_stride

    if not src_contig:
        util.free(src_scratch)

    if not dst_contig:
        util.free(dst_scratch)

    return out
|
||
|
|
||
|
def argsort(a, axis = -1, kind: Optional[str] = None):
    # Return the indices that would sort `a` along `axis`. With axis=None the
    # array is flattened first.
    #
    # `kind` selects the algorithm:
    # - None, 'quicksort' or 'quick' selects aquicksort;
    # - 'mergesort', 'merge' or 'stable' selects astablesort;
    # - 'heapsort' or 'heap' selects aheapsort.
    # Any other kind raises ValueError.
    if axis is None:
        return argsort(asarray(a).flatten(), axis=-1, kind=kind)

    a = asarray(a)
    if kind is None:
        return _asort(a, axis, aquicksort)
    elif kind == 'quicksort' or kind == 'quick':
        return _asort(a, axis, aquicksort)
    elif kind == 'mergesort' or kind == 'merge' or kind == 'stable':
        return _asort(a, axis, astablesort)
    elif kind == 'heapsort' or kind == 'heap':
        return _asort(a, axis, aheapsort)
    else:
        raise ValueError(f"sort kind must be one of 'quick', 'heap', or 'stable' (got {repr(kind)})")
|
||
|
|
||
|
def sort_complex(a):
    # Return a sorted copy of `a` (along the last axis) with a complex dtype.
    # - Complex input keeps its dtype.
    # - 8- and 16-bit integer input (byte, i8, u8, i16, u16) becomes
    #   complex64.
    # - Everything else becomes complex.
    ordered = array(a, copy=True)
    ordered.sort()
    if ordered.dtype is complex or ordered.dtype is complex64:
        return ordered
    else:
        if ordered.dtype is byte or ordered.dtype is i8 or ordered.dtype is u8 or ordered.dtype is i16 or ordered.dtype is u16:
            return ordered.astype(complex64)
        else:
            return ordered.astype(complex)
|
||
|
|
||
|
@extend
class ndarray:
    # In-place sort/partition methods and their index-returning variants.
    # Each one delegates to the module-level implementation.

    def sort(self, axis: int = -1, kind: Optional[str] = None):
        # Sort this array in place along `axis`.
        _sort(self, axis=axis, kind=kind, force_compact=False)

    def argsort(self, axis: int = -1, kind: Optional[str] = None):
        # Return a new int array with the indices that would sort this array
        # along `axis`.
        return argsort(self, axis=axis, kind=kind)

    def partition(self, kth, axis: int = -1, kind: str = 'introselect'):
        # Partition this array in place along `axis` around each index in `kth`.
        _partition_helper(self, kth=kth, axis=axis, kind=kind, force_compact=False)

    def argpartition(self, kth, axis: int = -1, kind: str = 'introselect'):
        # Return a new int array with the indices that would partition this
        # array along `axis`.
        return argpartition(self, kth=kth, axis=axis, kind=kind)
|
# TODO: support 'order' argument on sort, argsort, ndarray.sort, partition and argpartition
|