import numpy as np
import numpy.random as rnd

# Module-level random generator with a fixed seed (1234). The driver
# functions in this file (test_sorts, test_partition, test_apartition) pass
# this one generator to every test they run.
g = rnd.default_rng(1234)

def less(a: T, b: T, T: type):
    """Return whether ``a`` orders strictly before ``b`` for element type ``T``.

    ``T`` selects the comparison rule:

    * ``float`` / ``float32``: plain ``<``, and additionally a non-NaN ``a``
      is considered less than a NaN ``b``.
    * ``complex`` / ``complex64``: real parts are compared first, then
      imaginary parts, with NaN components handled by the branches below.
    * anything else: plain ``<``.

    ``x != x`` is used throughout as the "x is NaN" test and ``x == x`` as
    the "x is not NaN" test.

    NOTE(review): the NaN handling looks like it is meant to match the order
    produced by NumPy's sort (NaNs last) -- confirm against the NumPy docs.
    """
    if T is float or T is float32:
        # a < b, or b is NaN while a is a regular number.
        return a < b or (b != b and a == a)
    elif T is complex or T is complex64:
        if a.real < b.real:
            # Real parts already order a first; this holds unless a's
            # imaginary part is NaN while b's is not.
            return a.imag == a.imag or b.imag != b.imag
        elif a.real > b.real:
            # Real parts order b first; a is still "less" only when b's
            # imaginary part is NaN and a's is not.
            return b.imag != b.imag and a.imag == a.imag
        elif a.real == b.real or (a.real != a.real and b.real != b.real):
            # Real parts tie (equal, or both NaN): decide on the imaginary
            # parts using the same rule as the float branch.
            return a.imag < b.imag or (b.imag != b.imag and a.imag == a.imag)
        else:
            # Reached when none of <, >, == held and the real parts are not
            # both NaN, i.e. exactly one real part is NaN. a is less when
            # the NaN one is b's.
            return b.real != b.real
    else:
        return a < b

def less_equal(a, b):
    """Return True when ``a == b`` or when ``less(a, b)`` holds."""
    if a == b:
        return True
    return less(a, b)

def check_sorted(vec):
    """Return True if every adjacent pair in ``vec`` satisfies ``less_equal``.

    Sequences with fewer than two elements are reported as sorted.
    """
    # Walk the left index of each adjacent pair (idx, idx + 1).
    for idx in range(len(vec) - 1):
        current = vec[idx]
        following = vec[idx + 1]
        if not less_equal(current, following):
            return False
    return True

def check_argsorted(vec, tosort):
    """Return True if visiting ``vec`` in the order ``tosort`` is sorted.

    ``tosort`` holds indices into ``vec``; each consecutive pair of indexed
    elements must satisfy ``less_equal``.
    """
    # Walk the left position of each adjacent pair of indices.
    for pos in range(len(vec) - 1):
        current = vec[tosort[pos]]
        following = vec[tosort[pos + 1]]
        if not less_equal(current, following):
            return False
    return True

def gen_data(generator, length: int, dtype: type):
    """Draw ``length`` random values from ``generator`` as an array of ``dtype``.

    * float types: normal(100, 10) samples.
    * complex types: real part from normal(1000, 100), imaginary part from
      normal(100, 10), each cast to the component type of ``dtype``.
    * ``UInt`` types: integers drawn from [0, 100000).
    * everything else: integers drawn from [-100000, 100000).
    """
    if dtype is float or dtype is float32:
        samples = generator.normal(100, 10, size=length)
        return samples.astype(dtype)
    elif dtype is complex or dtype is complex64:
        # Component (real/imag) type of the requested complex type.
        F = type(dtype().real)
        # The real part is drawn before the imaginary part.
        real_part = generator.normal(1000, 100, size=length).astype(F)
        imag_part = generator.normal(100, 10, size=length).astype(F)
        return real_part + imag_part * 1j
    elif isinstance(dtype, UInt):
        non_negative = generator.integers(0, 100000, size=length)
        return non_negative.astype(dtype)
    else:
        signed = generator.integers(-100000, 100000, size=length)
        return signed.astype(dtype)

@test
def test_sort(generator, length: int, kind: str, dtype: type = int):
    """Check np.sort, np.argsort and in-place sort on random ``dtype`` data.

    The non-mutating calls are checked first, together with an assertion
    that the input array is unchanged; the in-place ``sort`` runs last.
    """
    print(f'TEST kind={kind} length={length} dtype={dtype.__name__}')
    data = gen_data(generator, length, dtype)
    snapshot = data.copy()

    # Out-of-place sort returns a sorted array.
    sorted_copy = np.sort(data, kind=kind)
    assert check_sorted(sorted_copy)

    # argsort returns an index order that visits the data in sorted order.
    order = np.argsort(data, kind=kind)
    assert check_argsorted(data, order)

    # Neither call above may have modified the input.
    untouched = (data == snapshot).all()
    assert untouched

    # In-place sort.
    data.sort(kind=kind)
    assert check_sorted(data)

def test_sorts(dtype: type):
    """Run ``test_sort`` for ``dtype`` over every sort kind and array length.

    Uses the module-level generator ``g``; kinds form the outer loop and
    lengths the inner loop.
    """
    kinds = ('quick', 'merge', 'heap')
    lengths = (0, 1, 10, 101, 1111, 12345, 1000000)
    for kind in kinds:
        for length in lengths:
            test_sort(g, length, kind, dtype)

def check_partitioned(vec, kth):
    """Return True if 1-D ``vec`` is correctly partitioned around ``kth``.

    ``kth`` is a single int or a tuple of ints; negative indices count from
    the end. For every ``k`` the function requires that

    * ``vec[k]`` equals the element at position ``k`` of the sorted array, and
    * every element before position ``k`` is ``<=`` that pivot and every
      element after it is ``>=`` the pivot.

    The second condition is what makes this a partition check: comparing
    only ``vec[k]`` with the sorted value would accept arrays such as
    ``[5, 2, 1]`` for ``k=1``.

    The elements must support ``<=`` / ``>=`` comparisons; NaN values are
    not handled.
    """
    if isinstance(kth, int):
        return check_partitioned(vec, (kth, ))

    s = np.sort(vec)
    n = len(vec)
    for k in kth:
        pivot = vec[k]
        if pivot != s[k]:
            return False
        # Normalize a negative index so the slices below split at the pivot.
        pos = k + n if k < 0 else k
        if not (vec[:pos] <= pivot).all():
            return False
        if not (vec[pos + 1:] >= pivot).all():
            return False
    return True

def check_apartitioned(vec, kth, apart):
    """Return True if index array ``apart`` partitions 1-D ``vec`` around ``kth``.

    ``kth`` is a single int or a tuple of ints; negative indices count from
    the end. With ``part = vec[apart]``, for every ``k`` the function
    requires that

    * ``part[k]`` equals the element at position ``k`` of the sorted array, and
    * every element of ``part`` before position ``k`` is ``<=`` that pivot
      and every element after it is ``>=`` the pivot.

    The second condition is what makes this a partition check: comparing
    only ``vec[apart[k]]`` with the sorted value would accept an index
    order that leaves larger elements in front of the pivot.

    The elements must support ``<=`` / ``>=`` comparisons; NaN values are
    not handled.
    """
    if isinstance(kth, int):
        return check_apartitioned(vec, (kth, ), apart)

    s = np.sort(vec)
    # Materialize the arrangement described by the index array once.
    part = vec[apart]
    n = len(vec)
    for k in kth:
        pivot = part[k]
        if pivot != s[k]:
            return False
        # Normalize a negative index so the slices below split at the pivot.
        pos = k + n if k < 0 else k
        if not (part[:pos] <= pivot).all():
            return False
        if not (part[pos + 1:] >= pivot).all():
            return False
    return True

@test
def test_partition_specific(generator, length, kth):
    """Check np.partition and in-place ndarray.partition on random floats."""
    values = gen_data(generator, length, float)

    # Out-of-place: the returned array is checked.
    partitioned_copy = np.partition(values, kth)
    assert check_partitioned(partitioned_copy, kth)

    # In-place: the original array is checked after the call.
    values.partition(kth)
    assert check_partitioned(values, kth)

def test_partition():
    """Run ``test_partition_specific`` over several lengths and kth choices.

    Uses the module-level generator ``g``. Every length is at least 101, so
    all kth indices below (including 99 and -99) are in range.
    """
    for length in (101, 1111, 12345, 1000000):
        # The original last entry was written `(10, 20, 30 - 10)`, which
        # evaluates to (10, 20, 20). That duplicate-kth case is kept
        # explicitly, and the mixed positive/negative selection the
        # expression appears to have been meant as is added.
        for kth in (0, 42, -1, (0, -1), (99, -99), (10, 20, 20),
                    (10, 20, 30, -10)):
            test_partition_specific(g, length, kth)

@test
def test_apartition_specific(generator, length, kth):
    """Check np.argpartition and ndarray.argpartition on random floats."""
    values = gen_data(generator, length, float)

    # Function form.
    indices = np.argpartition(values, kth)
    assert check_apartitioned(values, kth, indices)

    # Method form on the same, unmodified data.
    indices = values.argpartition(kth)
    assert check_apartitioned(values, kth, indices)

def test_apartition():
    """Run ``test_apartition_specific`` over several lengths and kth choices.

    Uses the module-level generator ``g``. Every length is at least 101, so
    all kth indices below (including 99 and -99) are in range.
    """
    for length in (101, 1111, 12345, 1000000):
        # The original last entry was written `(10, 20, 30 - 10)`, which
        # evaluates to (10, 20, 20). That duplicate-kth case is kept
        # explicitly, and the mixed positive/negative selection the
        # expression appears to have been meant as is added.
        for kth in (0, 42, -1, (0, -1), (99, -99), (10, 20, 20),
                    (10, 20, 30, -10)):
            test_apartition_specific(g, length, kth)

@test
def test_sorts_along_axis():
    """Check sort/argsort/partition/argpartition along axis 0 of a 3x3 array."""
    rows = [[2, 22, 222], [1, 11, 111], [3, 33, 333]]

    # Expected results when the three rows are ordered along axis 0.
    rows_sorted = np.array([[1, 11, 111], [2, 22, 222], [3, 33, 333]])
    row_order = np.array([[1, 1, 1], [0, 0, 0], [2, 2, 2]])
    middle_row = np.array([2, 22, 222])
    middle_source = np.array([0, 0, 0])

    a = np.array(rows)

    # Non-mutating function and method forms.
    assert (np.sort(a, axis=0) == rows_sorted).all()
    assert (np.partition(a, kth=1, axis=0)[1] == middle_row).all()
    assert (np.argsort(a, axis=0) == row_order).all()
    assert (np.argpartition(a, kth=1, axis=0)[1] == middle_source).all()
    assert (a.argsort(axis=0) == row_order).all()
    assert (a.argpartition(kth=1, axis=0)[1] == middle_source).all()

    # In-place sort.
    a.sort(axis=0)
    assert (a == rows_sorted).all()

    # In-place partition on a fresh, unsorted copy.
    a = np.array(rows)
    a.partition(kth=1, axis=0)
    assert (a[1] == middle_row).all()

@test
def test_sort_complex():
    """Check np.sort_complex on an integer list and on a complex list."""
    # Integer input: the expected result is a sorted complex array.
    int_input = [5, 3, 6, 2, 1]
    int_expected = np.array(
        [1. + 0.j, 2. + 0.j, 3. + 0.j, 5. + 0.j, 6. + 0.j])
    assert (np.sort_complex(int_input) == int_expected).all()

    # Complex input containing three values that share real part 3.
    complex_input = [1 + 2j, 2 - 1j, 3 - 2j, 3 - 3j, 3 + 5j]
    complex_expected = np.array(
        [1. + 2.j, 2. - 1.j, 3. - 3.j, 3. - 2.j, 3. + 5.j])
    assert (np.sort_complex(complex_input) == complex_expected).all()

@test
def test_lexsort():
    """Check np.lexsort on lists, tuples of lists, and 1-D/2-D/3-D arrays."""
    a = [1, 5, 1, 4, 3, 4, 4]
    b = [9, 4, 0, 4, 0, 2, 1]

    # Expected index orders for the key sets (a,), (a, b) and (b, a).
    order_a = np.array([0, 2, 4, 3, 5, 6, 1])
    order_ab = np.array([2, 4, 6, 5, 3, 1, 0])
    order_ba = np.array([2, 0, 4, 6, 5, 3, 1])

    # Keys given as a plain list / tuples of lists.
    assert np.lexsort(a) == 0
    assert (np.lexsort((a, )) == order_a).all()
    assert (np.lexsort((a, b)) == order_ab).all()
    assert (np.lexsort((b, a)) == order_ba).all()

    # The same keys given as arrays.
    assert np.lexsort(np.array(a)) == 0
    assert (np.lexsort(np.array([a])) == order_a).all()
    assert (np.lexsort(np.array([a, b])) == order_ab).all()
    assert (np.lexsort(np.array([b, a])) == order_ba).all()

    # 3-D keys, default axis and axis=0.
    stacked = np.array([[a, b], [a, b]])
    expected_last_axis = np.array([[0, 2, 4, 3, 5, 6, 1],
                                   [2, 4, 6, 5, 1, 3, 0]])
    expected_axis0 = np.array([[0, 1, 1, 0, 1, 1, 1],
                               [1, 0, 0, 1, 0, 0, 0]])
    assert (np.lexsort(stacked) == expected_last_axis).all()
    assert (np.lexsort(stacked, axis=0) == expected_axis0).all()

    # Transposed 3-D keys, default axis and axis=0.
    transposed = np.array([[a, b], [a, b]]).T
    expected_t_last_axis = np.array([[0, 1], [0, 1]])
    expected_t_axis0 = np.array([[1, 1], [0, 0]])
    assert (np.lexsort(transposed) == expected_t_last_axis).all()
    assert (np.lexsort(transposed, axis=0) == expected_t_axis0).all()

# Driver: run the sort tests once per element type (built-in int/float/
# complex plus the fixed-width u*/i*/float32/complex64 types), then the
# partition, argpartition, axis, sort_complex and lexsort tests. All
# randomized tests draw from the module-level generator `g`, so the data
# each one sees depends on this call order.
test_sorts(int)
test_sorts(u64)
test_sorts(u32)
test_sorts(i32)
test_sorts(u8)
test_sorts(i8)
test_sorts(u128)
test_sorts(i128)
test_sorts(float)
test_sorts(float32)
test_sorts(complex)
test_sorts(complex64)
test_partition()
test_apartition()
test_sorts_along_axis()
test_sort_complex()
test_lexsort()