# Ranges with fewer elements than this are handed straight to
# `_insertion_sort`; the same value decides whether a partition side is large
# enough to have elements shuffled after a badly balanced split.
INSERTION_SORT_THRESHOLD = 24
# Ranges with more elements than this pick their pivot from nine sampled
# elements (three `_sort3` triples plus a `_sort3` of their medians) instead of
# a single median of three.
NINTHER_THRESHOLD = 128
# Total distance elements may be shifted by `_partial_insertion_sort` before it
# gives up and reports failure.
PARTIAL_INSERTION_SORT_LIMIT = 8

from algorithms.insertionsort import _insertion_sort
from algorithms.heapsort import _heap_sort

def _floor_log2(n: int) -> int:
    """Returns floor(log2(n))"""
    log = 0
    while True:
        n >>= 1
        if n == 0:
            break
        log += 1
    return log

def _partial_insertion_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S]) -> bool:
    if begin == end:
        return True

    limit = 0
    cur = begin + 1
    while cur != end:
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False

        sift = cur
        sift_1 = cur -1

        if keyf(arr[sift]) < keyf(arr[sift_1]):
            tmp = arr[sift]

            while True:
                arr[sift] = arr[sift_1]
                sift -= 1
                sift_1 -= 1
                if sift == begin or keyf(tmp) >= keyf(arr[sift_1]):
                    break

            arr[sift] = tmp
            limit += cur - sift

        cur += 1

    return True

def _partition_left[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S]) -> int:
    pivot = arr[begin]
    first = begin
    last = end

    while True:
        last -= 1
        if keyf(pivot) >= keyf(arr[last]):
            break

    if (last + 1 == end):
        while first < last:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    else:
        while True:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    while first < last:
        arr[first], arr[last] = arr[last], arr[first]
        while True:
            last -= 1
            if keyf(pivot) >= keyf(arr[last]):
                break
        while True:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    pivot_pos = last
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return pivot_pos

def _partition_right[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S]) -> Tuple[int,int]:
    pivot = arr[begin]
    first = begin
    last = end

    while True:
        first += 1
        if keyf(arr[first]) >= keyf(pivot):
            break

    if first - 1 == begin:
        while first < last:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    else:
        while True:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    already_partitioned = 0
    if first >= last:
        already_partitioned = 1

    while first < last:
        arr[first], arr[last] = arr[last], arr[first]

        while True:
            first += 1
            if keyf(arr[first]) >= keyf(pivot):
                break

        while True:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    pivot_pos = first - 1
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return (pivot_pos, already_partitioned)

def _sort2[S,T](arr: Array[T], i: int, j: int, keyf: Callable[[T], S]):
    if keyf(arr[j]) < keyf(arr[i]):
        arr[i], arr[j] = arr[j], arr[i]

def _sort3[S,T](arr: Array[T], i: int, j: int, k: int, keyf: Callable[[T], S]):
    _sort2(arr, i, j, keyf)
    _sort2(arr, j, k, keyf)
    _sort2(arr, i, j, keyf)

def _pdq_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S], bad_allowed: int, leftmost: bool):
    """
        Main loop of the pattern-defeating quicksort: sorts arr[begin:end)
        inplace, ordered by keyf.

        bad_allowed is the number of badly balanced partitions that may still
        happen before the current range is handed to `_heap_sort`.
        leftmost tells whether the range starts at the very beginning of the
        data being sorted; when it is False, arr[begin - 1] is read.

        The left side of every partition is sorted by a recursive call, the
        right side by the next iteration of the loop.
    """
    while True:
        size = end - begin
        # Small ranges are handed to insertion sort and are done.
        if size < INSERTION_SORT_THRESHOLD:
            _insertion_sort(arr, begin, end, keyf)
            return

        # Pivot selection; the chosen pivot always ends up at arr[begin].
        size_2 = size // 2
        if size > NINTHER_THRESHOLD:
            # Order three triples taken from the start, the middle and the end
            # of the range, which leaves their medians at the three middle
            # positions; then take the median of those and move it to `begin`.
            _sort3(arr, begin, begin + size_2, end - 1, keyf)
            _sort3(arr, begin + 1, begin + (size_2 - 1), end - 2, keyf)
            _sort3(arr, begin + 2, begin + (size_2 + 1), end - 3, keyf)
            _sort3(arr, begin + (size_2 - 1), begin + size_2, begin + (size_2 + 1), keyf)
            arr[begin], arr[begin + size_2] = arr[begin + size_2], arr[begin]
        else:
            # Median of first, middle and last element; `begin` is passed as
            # the middle index so the median lands there directly.
            _sort3(arr, begin + size_2, begin, end - 1, keyf)

        # The element just before the range is not less than the pivot. In
        # that case partition with `_partition_left`, which puts elements equal
        # to the pivot on the left, and continue with the right part only.
        # NOTE(review): skipping the left part assumes arr[begin - 1] is a
        # lower bound for the whole range (it is an earlier pivot) - confirm.
        if not leftmost and keyf(arr[begin - 1]) >= keyf(arr[begin]):
            begin = _partition_left(arr, begin, end, keyf) + 1
            continue

        part_result = _partition_right(arr, begin, end, keyf)
        pivot_pos = part_result[0]
        # The flag comes back as an int (0 or 1); turn it into a bool.
        already_partitioned = part_result[1] == 1

        # A split counts as highly unbalanced when either side holds less than
        # an eighth of the elements.
        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        highly_unbalanced = (l_size < (size // 8)) or (r_size < (size // 8))

        if highly_unbalanced:
            # Too many bad splits: sort the whole current range with
            # `_heap_sort` instead.
            bad_allowed -= 1
            if bad_allowed == 0:
                _heap_sort(arr, begin, end, keyf)
                return

            # Swap a few elements at fixed positions (near both ends and a
            # quarter of the way in) on each side that is large enough.
            # NOTE(review): presumably this is meant to disturb the input
            # pattern that caused the bad split - confirm against upstream.
            if l_size >= INSERTION_SORT_THRESHOLD:
                arr[begin], arr[begin + l_size // 4] = arr[begin + l_size // 4], arr[begin]
                arr[pivot_pos - 1], arr[pivot_pos - l_size // 4] = arr[pivot_pos - l_size // 4], arr[pivot_pos - 1]

                if l_size > NINTHER_THRESHOLD:
                    arr[begin + 1], arr[begin + (l_size // 4 + 1)] = arr[begin + (l_size // 4 + 1)], arr[begin + 1]
                    arr[begin + 2], arr[begin + (l_size // 4 + 2)] = arr[begin + (l_size // 4 + 2)], arr[begin + 2]
                    arr[pivot_pos - 2], arr[pivot_pos - (l_size // 4 + 1)] = arr[pivot_pos - (l_size // 4 + 1)], arr[pivot_pos - 2]
                    arr[pivot_pos - 3], arr[pivot_pos - (l_size // 4 + 2)] = arr[pivot_pos - (l_size // 4 + 2)], arr[pivot_pos - 3]

            if r_size >= INSERTION_SORT_THRESHOLD:
                arr[pivot_pos + 1], arr[pivot_pos + (1 + r_size // 4)] = arr[pivot_pos + (1 + r_size // 4)], arr[pivot_pos + 1]
                arr[end - 1], arr[end - r_size // 4] = arr[end - r_size // 4], arr[end - 1]

                if r_size > NINTHER_THRESHOLD:
                    arr[pivot_pos + 2], arr[pivot_pos + (2 + r_size // 4)] = arr[pivot_pos + (2 + r_size // 4)], arr[pivot_pos + 2]
                    arr[pivot_pos + 3], arr[pivot_pos + (3 + r_size // 4)] = arr[pivot_pos + (3 + r_size // 4)], arr[pivot_pos + 3]
                    arr[end - 2], arr[end - (1 + r_size // 4)] = arr[end - (1 + r_size // 4)], arr[end - 2]
                    arr[end - 3], arr[end - (2 + r_size // 4)] = arr[end - (2 + r_size // 4)], arr[end - 3]

        else:
            # Balanced split that needed no swaps: try to finish both sides
            # with the bounded insertion sort. If both succeed the range is
            # sorted; otherwise fall through to the normal recursion.
            if (already_partitioned and _partial_insertion_sort(arr, begin, pivot_pos, keyf) and _partial_insertion_sort(arr, pivot_pos + 1, end, keyf)):
                return

        # Recurse into the left side, then loop on the right side. The right
        # side is preceded by the pivot, so it is no longer leftmost.
        _pdq_sort(arr, begin, pivot_pos, keyf, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False

def pdq_sort_array[S,T](collection: Array[T], size: int, keyf: Callable[[T], S]):
    """
        Pattern-defeating Quicksort
        By Orson Peters, published at https://github.com/orlp/pdqsort

        Sorts the array inplace.
    """
    _pdq_sort(collection, 0, size, keyf, _floor_log2(size), True)

def pdq_sort_inplace[S,T](collection: List[T], keyf: Callable[[T], S]):
    """
        Pattern-defeating Quicksort
        By Orson Peters, published at https://github.com/orlp/pdqsort

        Sorts the list inplace.
    """
    pdq_sort_array(collection.arr, collection.len, keyf)

def pdq_sort[S,T](collection: List[T], keyf: Callable[[T], S]) -> List[T]:
    """
        Pattern-defeating Quicksort
        By Orson Peters, published at https://github.com/orlp/pdqsort

        Returns a sorted list.
    """
    newlst = copy(collection)
    pdq_sort_inplace(newlst, keyf)
    return newlst