# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
# Parts of this file: https://github.com/orlp/pdqsort
# License:
#    Copyright (c) 2021 Orson Peters <orsonpeters@gmail.com>
#
#    This software is provided 'as-is', without any express or implied warranty. In no event will the
#    authors be held liable for any damages arising from the use of this software.
#
#    Permission is granted to anyone to use this software for any purpose, including commercial
#    applications, and to alter it and redistribute it freely, subject to the following restrictions:
#
#    1. The origin of this software must not be misrepresented; you must not claim that you wrote the
#    original software. If you use this software in a product, an acknowledgment in the product
#    documentation would be appreciated but is not required.
#
#    2. Altered source versions must be plainly marked as such, and must not be misrepresented as
#    being the original software.
#
#    3. This notice may not be removed or altered from any source distribution.

# Ranges with fewer elements than this are handed to insertion sort instead of
# being partitioned; it is also the minimum partition size that gets shuffled
# after a highly unbalanced split.
INSERTION_SORT_THRESHOLD = 24
# Ranges with more elements than this choose their pivot from three sorted
# triples (begin / middle / end of the range) instead of a single triple.
NINTHER_THRESHOLD = 128
# Total distance elements may be shifted by _partial_insertion_sort before it
# gives up and reports the range as not sorted.
PARTIAL_INSERTION_SORT_LIMIT = 8

from algorithms.insertionsort import _insertion_sort
from algorithms.heapsort import _heap_sort

def _floor_log2(n: int) -> int:
    """Returns floor(log2(n))"""
    log = 0
    while True:
        n >>= 1
        if n == 0:
            break
        log += 1
    return log

def _partial_insertion_sort(
    arr: Array[T], begin: int, end: int, keyf: Callable[[T], S], T: type, S: type
) -> bool:
    """
    Attempts to insertion-sort arr[begin:end] (end exclusive) by key, giving up
    once too many element shifts have been needed.

    Returns True if the range was fully sorted, and False if the accumulated
    shift distance exceeded PARTIAL_INSERTION_SORT_LIMIT before the end of the
    range was reached. On a False return the range has been partially
    rearranged but still holds the same elements.
    """
    if begin == end:
        return True

    limit = 0
    cur = begin + 1
    while cur != end:
        # Checked before touching the next element, so the work done by the
        # final insertion is never a reason to report failure.
        if limit > PARTIAL_INSERTION_SORT_LIMIT:
            return False

        sift = cur
        sift_1 = cur - 1
        tmp = arr[sift]
        # The key of the element being inserted is computed once and reused
        # for every comparison while it is shifted left.
        tmp_key = keyf(tmp)

        if tmp_key < keyf(arr[sift_1]):
            while True:
                arr[sift] = arr[sift_1]
                sift -= 1
                sift_1 -= 1
                # Testing sift == begin first keeps arr[begin - 1] from being read.
                if sift == begin or not tmp_key < keyf(arr[sift_1]):
                    break

            arr[sift] = tmp
            limit += cur - sift

        cur += 1

    return True

def _partition_left(
    arr: Array[T], begin: int, end: int, keyf: Callable[[T], S], T: type, S: type
) -> int:
    pivot = arr[begin]
    first = begin
    last = end

    while True:
        last -= 1
        if not keyf(pivot) < keyf(arr[last]):
            break

    if last + 1 == end:
        while first < last:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    else:
        while True:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    while first < last:
        arr[first], arr[last] = arr[last], arr[first]
        while True:
            last -= 1
            if not keyf(pivot) < keyf(arr[last]):
                break
        while True:
            first += 1
            if keyf(pivot) < keyf(arr[first]):
                break

    pivot_pos = last
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return pivot_pos

def _partition_right(
    arr: Array[T], begin: int, end: int, keyf: Callable[[T], S], T: type, S: type
) -> Tuple[int, int]:
    pivot = arr[begin]
    first = begin
    last = end

    while True:
        first += 1
        if not keyf(arr[first]) < keyf(pivot):
            break

    if first - 1 == begin:
        while first < last:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    else:
        while True:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    already_partitioned = 0
    if first >= last:
        already_partitioned = 1

    while first < last:
        arr[first], arr[last] = arr[last], arr[first]

        while True:
            first += 1
            if not keyf(arr[first]) < keyf(pivot):
                break

        while True:
            last -= 1
            if keyf(arr[last]) < keyf(pivot):
                break

    pivot_pos = first - 1
    arr[begin] = arr[pivot_pos]
    arr[pivot_pos] = pivot

    return (pivot_pos, already_partitioned)

def _sort2(
    arr: Array[T], i: int, j: int, keyf: Callable[[T], S], T: type, S: type
):
    if keyf(arr[j]) < keyf(arr[i]):
        arr[i], arr[j] = arr[j], arr[i]

def _sort3(
    arr: Array[T], i: int, j: int, k: int, keyf: Callable[[T], S], T: type, S: type
):
    """
    Puts arr[i], arr[j] and arr[k] in key order using three compare-exchange
    steps: (i, j), then (j, k), then (i, j) again.
    """
    for lo, hi in ((i, j), (j, k), (i, j)):
        _sort2(arr, lo, hi, keyf)

def _pdq_sort(
    arr: Array[T],
    begin: int,
    end: int,
    keyf: Callable[[T], S],
    bad_allowed: int,
    leftmost: bool,
    T: type,
    S: type,
):
    """
    Sorts arr[begin:end] (end exclusive) in place by key.

    The left partition is handled by a recursive call and the right partition
    by the next iteration of the loop.

    bad_allowed is decremented each time a partition turns out highly
    unbalanced; when it reaches zero the current range is handed to heap sort.

    leftmost must be True when nothing precedes `begin`; when it is False the
    element at arr[begin - 1] is read and compared with the pivot.
    """
    while True:
        size = end - begin
        if size < INSERTION_SORT_THRESHOLD:
            _insertion_sort(arr, begin, end, keyf)
            return

        # Pivot selection: afterwards the chosen pivot is at arr[begin].
        mid = begin + size // 2
        if size > NINTHER_THRESHOLD:
            _sort3(arr, begin, mid, end - 1, keyf)
            _sort3(arr, begin + 1, mid - 1, end - 2, keyf)
            _sort3(arr, begin + 2, mid + 1, end - 3, keyf)
            _sort3(arr, mid - 1, mid, mid + 1, keyf)
            arr[begin], arr[mid] = arr[mid], arr[begin]
        else:
            _sort3(arr, mid, begin, end - 1, keyf)

        # If the element just before this range is not less than the pivot,
        # split off everything not greater than the pivot and continue with
        # the remainder only.
        if not leftmost:
            if not keyf(arr[begin - 1]) < keyf(arr[begin]):
                begin = _partition_left(arr, begin, end, keyf) + 1
                continue

        pivot_pos, partitioned_flag = _partition_right(arr, begin, end, keyf)

        l_size = pivot_pos - begin
        r_size = end - (pivot_pos + 1)
        eighth = size // 8

        if l_size < eighth or r_size < eighth:
            # Highly unbalanced split: spend one unit of the budget and fall
            # back to heap sort once it is used up.
            bad_allowed -= 1
            if bad_allowed == 0:
                _heap_sort(arr, begin, end, keyf)
                return

            # Swap a few elements in each sufficiently large partition to
            # break up the pattern that produced the bad split.
            if l_size >= INSERTION_SORT_THRESHOLD:
                l_step = l_size // 4
                arr[begin], arr[begin + l_step] = arr[begin + l_step], arr[begin]
                arr[pivot_pos - 1], arr[pivot_pos - l_step] = (
                    arr[pivot_pos - l_step],
                    arr[pivot_pos - 1],
                )

                if l_size > NINTHER_THRESHOLD:
                    arr[begin + 1], arr[begin + (l_step + 1)] = (
                        arr[begin + (l_step + 1)],
                        arr[begin + 1],
                    )
                    arr[begin + 2], arr[begin + (l_step + 2)] = (
                        arr[begin + (l_step + 2)],
                        arr[begin + 2],
                    )
                    arr[pivot_pos - 2], arr[pivot_pos - (l_step + 1)] = (
                        arr[pivot_pos - (l_step + 1)],
                        arr[pivot_pos - 2],
                    )
                    arr[pivot_pos - 3], arr[pivot_pos - (l_step + 2)] = (
                        arr[pivot_pos - (l_step + 2)],
                        arr[pivot_pos - 3],
                    )

            if r_size >= INSERTION_SORT_THRESHOLD:
                r_step = r_size // 4
                arr[pivot_pos + 1], arr[pivot_pos + (1 + r_step)] = (
                    arr[pivot_pos + (1 + r_step)],
                    arr[pivot_pos + 1],
                )
                arr[end - 1], arr[end - r_step] = arr[end - r_step], arr[end - 1]

                if r_size > NINTHER_THRESHOLD:
                    arr[pivot_pos + 2], arr[pivot_pos + (2 + r_step)] = (
                        arr[pivot_pos + (2 + r_step)],
                        arr[pivot_pos + 2],
                    )
                    arr[pivot_pos + 3], arr[pivot_pos + (3 + r_step)] = (
                        arr[pivot_pos + (3 + r_step)],
                        arr[pivot_pos + 3],
                    )
                    arr[end - 2], arr[end - (1 + r_step)] = (
                        arr[end - (1 + r_step)],
                        arr[end - 2],
                    )
                    arr[end - 3], arr[end - (2 + r_step)] = (
                        arr[end - (2 + r_step)],
                        arr[end - 3],
                    )

        elif (
            partitioned_flag == 1
            and _partial_insertion_sort(arr, begin, pivot_pos, keyf)
            and _partial_insertion_sort(arr, pivot_pos + 1, end, keyf)
        ):
            # Balanced and already partitioned: both halves were finished off
            # by the bounded insertion sort, so there is nothing left to do.
            return

        # Recurse into the left part, then loop on the right part.
        _pdq_sort(arr, begin, pivot_pos, keyf, bad_allowed, leftmost)
        begin = pivot_pos + 1
        leftmost = False

def pdq_sort_array(
    collection: Array[T], size: int, keyf: Callable[[T], S], T: type, S: type
):
    """
    Pattern-defeating Quicksort
    By Orson Peters, published at https://github.com/orlp/pdqsort

    Sorts the first ``size`` elements of the array inplace, ordered by the
    value ``keyf`` returns for each element.
    """
    # Number of highly unbalanced partitions tolerated before the sort falls
    # back to heap sort.
    bad_partition_budget = _floor_log2(size)
    _pdq_sort(collection, 0, size, keyf, bad_partition_budget, True)

def pdq_sort_inplace(
    collection: List[T], keyf: Callable[[T], S], T: type, S: type
):
    """
    Pattern-defeating Quicksort
    By Orson Peters, published at https://github.com/orlp/pdqsort

    Sorts the list inplace, ordered by the value ``keyf`` returns for each
    element.
    """
    backing_array = collection.arr
    element_count = collection.len
    pdq_sort_array(backing_array, element_count, keyf)

def pdq_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) -> List[T]:
    """
    Pattern-defeating Quicksort
    By Orson Peters, published at https://github.com/orlp/pdqsort

    Returns a sorted copy of the list, ordered by the value ``keyf`` returns
    for each element; the sort is applied to the copy, not to ``collection``.
    """
    result = collection.__copy__()
    pdq_sort_inplace(result, keyf)
    return result