mirror of https://github.com/exaloop/codon.git
184 lines
6.5 KiB
Python
184 lines
6.5 KiB
Python
import numpy as np
|
|
import numpy.random as rnd
|
|
|
|
# Module-level random generator created with the fixed seed 1234.  The
# drivers test_sorts, test_partition and test_apartition all pass this one
# instance down to the individual tests, so its state advances across calls.
g = rnd.default_rng(1234)
|
|
|
|
def less(a: T, b: T, T: type):
    """Return True when `a` has to be ordered strictly before `b`.

    The comparison used depends on the element type `T`:

    * `float` / `float32`: plain `<`, plus a non-NaN `a` is considered less
      than a NaN `b`.
    * `complex` / `complex64`: real parts are compared first; when they tie
      (equal, or both NaN) the imaginary parts decide.  NaN imaginary parts
      also influence the result when the real parts differ.
    * any other type: plain `a < b`.
    """
    if T is float or T is float32:
        # `x != x` is true only for NaN, `x == x` only for non-NaN values.
        b_is_nan = b != b
        a_is_number = a == a
        return a < b or (b_is_nan and a_is_number)
    elif T is complex or T is complex64:
        a_re, a_im = a.real, a.imag
        b_re, b_im = b.real, b.imag
        a_im_is_number = a_im == a_im
        b_im_is_nan = b_im != b_im

        if a_re < b_re:
            return a_im_is_number or b_im_is_nan
        if a_re > b_re:
            return b_im_is_nan and a_im_is_number
        if a_re == b_re or (a_re != a_re and b_re != b_re):
            # Real parts tie: fall back to the imaginary parts, using the
            # same "non-NaN before NaN" rule as the float branch.
            return a_im < b_im or (b_im_is_nan and a_im_is_number)
        # Only reachable when exactly one real part is NaN: `a` comes first
        # precisely when the NaN is on `b`'s side.
        return b_re != b_re
    else:
        return a < b
|
|
|
|
def less_equal(a, b):
    """Return True when `a` may appear at or before `b` in sorted order.

    Expressed as "`b` is not strictly less than `a`" so that the result
    follows the NaN-aware ordering implemented by `less`.  For values without
    NaN this is the same as `a == b or less(a, b)`; for two NaNs (which are
    never `==`) it returns True, so runs of NaNs are accepted as sorted.
    """
    return not less(b, a)
|
|
|
|
def check_sorted(vec):
    """Return True if every adjacent pair of `vec` satisfies `less_equal`.

    A sequence with fewer than two elements has no pairs and is reported as
    sorted.
    """
    return all(
        less_equal(vec[i], vec[i + 1]) for i in range(len(vec) - 1))
|
|
|
|
def check_argsorted(vec, tosort):
    """Return True if `vec`, read through the index sequence `tosort`,
    is in `less_equal` order.

    Consecutive entries of `tosort` are used as positions into `vec`; the
    number of pairs examined is derived from `len(vec)`.
    """
    return all(
        less_equal(vec[tosort[i]], vec[tosort[i + 1]])
        for i in range(len(vec) - 1))
|
|
|
|
def gen_data(generator, length: int, dtype: type):
    """Draw `length` random values from `generator` and cast them to `dtype`.

    * `float` / `float32`: normal samples with mean 100 and std-dev 10.
    * `complex` / `complex64`: real part from normal(1000, 100), imaginary
      part from normal(100, 10); both parts are first cast to the type of
      `dtype().real`.
    * `UInt` types: integers drawn from [0, 100000).
    * everything else: integers drawn from [-100000, 100000).
    """
    if dtype is float or dtype is float32:
        return generator.normal(100, 10, size=length).astype(dtype)
    elif dtype is complex or dtype is complex64:
        # Type of one component of the complex dtype.
        F = type(dtype().real)
        # The real part is drawn before the imaginary part.
        real_part = generator.normal(1000, 100, size=length).astype(F)
        imag_part = generator.normal(100, 10, size=length).astype(F)
        return real_part + imag_part * 1j
    elif isinstance(dtype, UInt):
        return generator.integers(0, 100000, size=length).astype(dtype)
    else:
        return generator.integers(-100000, 100000, size=length).astype(dtype)
|
|
|
|
@test
def test_sort(generator, length: int, kind: str, dtype: type = int):
    """Check `np.sort`, `np.argsort` and in-place `ndarray.sort` on random data.

    `length` values of type `dtype` are produced by `gen_data`, then sorted
    with the algorithm named by `kind`.
    """
    print(f'TEST kind={kind} length={length} dtype={dtype.__name__}')
    data = gen_data(generator, length, dtype)
    snapshot = data.copy()

    # Out-of-place variants: the results must be ordered ...
    sorted_copy = np.sort(data, kind=kind)
    assert check_sorted(sorted_copy)
    order = np.argsort(data, kind=kind)
    assert check_argsorted(data, order)
    # ... and the input must not have been modified by either call.
    assert (data == snapshot).all()

    # In-place variant.
    data.sort(kind=kind)
    assert check_sorted(data)
|
|
|
|
def test_sorts(dtype: type):
    """Run `test_sort` for `dtype` with every sort kind and array length.

    Data is drawn from the module-level generator `g`.
    """
    kinds = ('quick', 'merge', 'heap')
    lengths = (0, 1, 10, 101, 1111, 12345, 1000000)
    for kind in kinds:
        for length in lengths:
            test_sort(g, length, kind, dtype)
|
|
|
|
def check_partitioned(vec, kth):
    """Return True if `vec` is correctly partitioned around every `kth`.

    `kth` is either a single int or a sequence of ints (negative values
    index from the end).  For each position `k` two things are verified:

    1. `vec[k]` equals the element at position `k` of `np.sort(vec)`;
    2. no element before `k` is greater than `vec[k]` and no element after
       `k` is smaller than it.

    Checking only (1) would accept an array whose k-th value is right but
    whose other elements sit on the wrong side of it.
    """
    if isinstance(kth, int):
        return check_partitioned(vec, (kth, ))

    n = len(vec)
    s = np.sort(vec)
    for k in kth:
        pivot = vec[k]
        if pivot != s[k]:
            return False
        # Turn a negative position into its non-negative equivalent so the
        # slices below split the array exactly at the pivot.
        pos = k + n if k < 0 else k
        if not (vec[:pos] <= pivot).all():
            return False
        if not (vec[pos + 1:] >= pivot).all():
            return False
    return True
|
|
|
|
def check_apartitioned(vec, kth, apart):
    """Return True if the index array `apart` partitions `vec` around `kth`.

    `kth` is either a single int or a sequence of ints (negative values
    index from the end).  With `part = vec[apart]`, for each position `k`:

    1. `part[k]` equals the element at position `k` of `np.sort(vec)`;
    2. no element of `part` before `k` is greater than `part[k]` and no
       element after `k` is smaller than it.

    Checking only (1) would accept an index array that places the right
    value at `k` but leaves other elements on the wrong side of it.
    """
    if isinstance(kth, int):
        return check_apartitioned(vec, (kth, ), apart)

    n = len(vec)
    s = np.sort(vec)
    # Values of `vec` in the order described by the index array.
    part = vec[apart]
    for k in kth:
        pivot = part[k]
        if pivot != s[k]:
            return False
        # Turn a negative position into its non-negative equivalent so the
        # slices below split the array exactly at the pivot.
        pos = k + n if k < 0 else k
        if not (part[:pos] <= pivot).all():
            return False
        if not (part[pos + 1:] >= pivot).all():
            return False
    return True
|
|
|
|
@test
def test_partition_specific(generator, length, kth):
    """Check `np.partition` and in-place `ndarray.partition` for one case.

    Operates on `length` float values produced by `gen_data`; `kth` is passed
    through unchanged to both partition calls and to `check_partitioned`.
    """
    data = gen_data(generator, length, float)

    # Out-of-place variant.
    partitioned_copy = np.partition(data, kth)
    assert check_partitioned(partitioned_copy, kth)

    # In-place variant.
    data.partition(kth)
    assert check_partitioned(data, kth)
|
|
|
|
def test_partition():
    """Run `test_partition_specific` over several lengths and `kth` choices.

    The `kth` cases cover single positions (including a negative one), pairs
    and a triple whose last two entries are both 20.  Data is drawn from the
    module-level generator `g`.
    """
    lengths = (101, 1111, 12345, 1000000)
    for length in lengths:
        for kth in (0, 42, -1, (0, -1), (99, -99), (10, 20, 30 - 10)):
            test_partition_specific(g, length, kth)
|
|
|
|
@test
def test_apartition_specific(generator, length, kth):
    """Check `np.argpartition` and `ndarray.argpartition` for one case.

    Operates on `length` float values produced by `gen_data`; each returned
    index array is validated with `check_apartitioned`.
    """
    data = gen_data(generator, length, float)

    # Module-level function.
    via_function = np.argpartition(data, kth)
    assert check_apartitioned(data, kth, via_function)

    # Array method.
    via_method = data.argpartition(kth)
    assert check_apartitioned(data, kth, via_method)
|
|
|
|
def test_apartition():
    """Run `test_apartition_specific` over several lengths and `kth` choices.

    The `kth` cases cover single positions (including a negative one), pairs
    and a triple whose last two entries are both 20.  Data is drawn from the
    module-level generator `g`.
    """
    lengths = (101, 1111, 12345, 1000000)
    for length in lengths:
        for kth in (0, 42, -1, (0, -1), (99, -99), (10, 20, 30 - 10)):
            test_apartition_specific(g, length, kth)
|
|
|
|
@test
def test_sorts_along_axis():
    """Check sort / partition / argsort / argpartition along axis 0 of a
    fixed 3x3 integer array, using both the `np.*` functions and the
    `ndarray` methods."""
    rows = [[2, 22, 222], [1, 11, 111], [3, 33, 333]]
    rows_sorted = np.array([[1, 11, 111], [2, 22, 222], [3, 33, 333]])
    row_order = np.array([[1, 1, 1], [0, 0, 0], [2, 2, 2]])
    middle_row = np.array([2, 22, 222])
    middle_source = np.array([0, 0, 0])

    a = np.array(rows)

    # Module-level functions.
    assert (np.sort(a, axis=0) == rows_sorted).all()
    assert (np.partition(a, kth=1, axis=0)[1] == middle_row).all()
    assert (np.argsort(a, axis=0) == row_order).all()
    assert (np.argpartition(a, kth=1, axis=0)[1] == middle_source).all()

    # Index-returning array methods.
    assert (a.argsort(axis=0) == row_order).all()
    assert (a.argpartition(kth=1, axis=0)[1] == middle_source).all()

    # In-place sort modifies `a` itself.
    a.sort(axis=0)
    assert (a == rows_sorted).all()

    # Start again from the unsorted rows for the in-place partition.
    a = np.array(rows)
    a.partition(kth=1, axis=0)
    assert (a[1] == middle_row).all()
|
|
|
|
@test
def test_sort_complex():
    """Check `np.sort_complex` on an integer list and on a complex list."""
    # Integer input: the expected result is written as complex values.
    from_ints = np.sort_complex([5, 3, 6, 2, 1])
    assert (from_ints == np.array(
        [1. + 0.j, 2. + 0.j, 3. + 0.j, 5. + 0.j, 6. + 0.j])).all()

    # Complex input with three entries sharing the real part 3.
    from_complex = np.sort_complex([1 + 2j, 2 - 1j, 3 - 2j, 3 - 3j, 3 + 5j])
    assert (from_complex == np.array(
        [1. + 2.j, 2. - 1.j, 3. - 3.j, 3. - 2.j, 3. + 5.j])).all()
|
|
|
|
@test
def test_lexsort():
    """Check `np.lexsort` with keys given as lists, tuples of lists, and
    1-D, 2-D and 3-D arrays, including the `axis=0` and transposed cases."""
    a = [1, 5, 1, 4, 3, 4, 4]
    b = [9, 4, 0, 4, 0, 2, 1]

    # Expected index orders shared by the sequence and array variants.
    order_a = np.array([0, 2, 4, 3, 5, 6, 1])        # keys: (a,)
    order_a_b = np.array([2, 4, 6, 5, 3, 1, 0])      # keys: (a, b)
    order_b_a = np.array([2, 0, 4, 6, 5, 3, 1])      # keys: (b, a)

    # Keys passed as Python lists / tuples of lists.
    assert np.lexsort(a) == 0
    assert (np.lexsort((a, )) == order_a).all()
    assert (np.lexsort((a, b)) == order_a_b).all()
    assert (np.lexsort((b, a)) == order_b_a).all()

    # The same keys passed as arrays.
    assert np.lexsort(np.array(a)) == 0
    assert (np.lexsort(np.array([a])) == order_a).all()
    assert (np.lexsort(np.array([a, b])) == order_a_b).all()
    assert (np.lexsort(np.array([b, a])) == order_b_a).all()

    # Keys stacked into a 2x2x7 array, default axis and axis=0.
    stacked = np.array([[a, b], [a, b]])
    assert (np.lexsort(stacked) == np.array([[0, 2, 4, 3, 5, 6, 1],
                                             [2, 4, 6, 5, 1, 3, 0]])).all()
    assert (np.lexsort(stacked, axis=0) == np.array(
        [[0, 1, 1, 0, 1, 1, 1], [1, 0, 0, 1, 0, 0, 0]])).all()

    # Transposed key array, default axis and axis=0.
    assert (np.lexsort(stacked.T) == np.array([[0, 1], [0, 1]])).all()
    assert (np.lexsort(stacked.T, axis=0) == np.array([[1, 1],
                                                        [0, 0]])).all()
|
|
|
|
# --- Sorting, one call per element type -------------------------------------
# Integer types.
test_sorts(int)
test_sorts(u64)
test_sorts(u32)
test_sorts(i32)
test_sorts(u8)
test_sorts(i8)
test_sorts(u128)
test_sorts(i128)
# Floating-point types.
test_sorts(float)
test_sorts(float32)
# Complex types.
test_sorts(complex)
test_sorts(complex64)

# --- Partitioning ------------------------------------------------------------
test_partition()
test_apartition()

# --- Fixed-input checks ------------------------------------------------------
test_sorts_along_axis()
test_sort_complex()
test_lexsort()
|