# (c) 2022 Exaloop Inc. All rights reserved.

# Move the entry at 'pos' toward the root of a min-heap.  Every index
# >= startpos already satisfies the heap invariant, with the possible
# exception of the entry at 'pos' itself; on return the invariant holds
# again for that whole range.
def _siftdown(heap: List[T], startpos: int, pos: int, T: type):
    moving = heap[pos]
    # Climb toward startpos, pulling each larger ancestor down one level,
    # and stop at the first ancestor that 'moving' is not smaller than.
    while pos > startpos:
        above = (pos - 1) >> 1
        ancestor = heap[above]
        if not (moving < ancestor):
            break
        heap[pos] = ancestor
        pos = above
    heap[pos] = moving


def _siftup(heap: List[T], pos: int, T: type):
    """Re-seat the entry at pos within the min-heap subtree rooted at pos."""
    size = len(heap)
    root = pos
    moving = heap[pos]
    # Walk down to a leaf, promoting the smaller child into each hole.
    child = 2 * pos + 1  # left child of pos
    while child < size:
        right = child + 1
        # Take the right child unless the left one is strictly smaller.
        if right < size:
            if not (heap[child] < heap[right]):
                child = right
        heap[pos] = heap[child]
        pos = child
        child = 2 * pos + 1
    # pos is now a vacant leaf: park the saved entry there and let
    # _siftdown carry it back up to the slot where it belongs.
    heap[pos] = moving
    _siftdown(heap, root, pos)


def _siftdown_max(heap: List[T], startpos: int, pos: int, T: type):
    "Max-heap counterpart of _siftdown: move heap[pos] up past smaller ancestors"
    moving = heap[pos]
    # Climb toward startpos, pulling each smaller ancestor down one level,
    # and stop at the first ancestor that is not smaller than 'moving'.
    while pos > startpos:
        above = (pos - 1) >> 1
        ancestor = heap[above]
        if not (ancestor < moving):
            break
        heap[pos] = ancestor
        pos = above
    heap[pos] = moving


def _siftup_max(heap: List[T], pos: int, T: type):
    "Max-heap counterpart of _siftup: re-seat heap[pos] in the subtree rooted at pos"
    size = len(heap)
    root = pos
    moving = heap[pos]
    # Walk down to a leaf, promoting the larger child into each hole.
    child = 2 * pos + 1  # left child of pos
    while child < size:
        right = child + 1
        # Take the right child unless it is strictly smaller than the left.
        if right < size:
            if not (heap[right] < heap[child]):
                child = right
        heap[pos] = heap[child]
        pos = child
        child = 2 * pos + 1
    # pos is now a vacant leaf: park the saved entry there and let
    # _siftdown_max carry it back up to the slot where it belongs.
    heap[pos] = moving
    _siftdown_max(heap, root, pos)


def heappush(heap: List[T], item: T, T: type):
    """Add item to heap, keeping the heap invariant intact."""
    heap.append(item)
    # The new entry sits in the last slot; let it rise from there.
    last = len(heap) - 1
    _siftdown(heap, 0, last)


def heappop(heap: List[T], T: type) -> T:
    """Remove and return the smallest item, keeping the heap invariant intact."""
    tail = heap.pop()  # an empty heap raises here (IndexError)
    if not heap:
        # The heap held a single item, which was therefore the smallest.
        return tail
    # Hand the root to the caller and re-seat the former tail from the top.
    smallest = heap[0]
    heap[0] = tail
    _siftup(heap, 0)
    return smallest


def heapreplace(heap: List[T], item: T, T: type) -> T:
    """
    Return the current smallest value and put item on the heap in its place.
    The heap keeps its size, which suits fixed-size heaps and saves work
    compared with a heappop() followed by a heappush().  The value handed
    back is taken before item is considered, so it can be larger than item;
    callers that care should guard the call:
    ``if item > heap[0]: item = heapreplace(heap, item)``.
    """
    smallest = heap[0]  # an empty heap raises here (IndexError)
    heap[0] = item
    _siftup(heap, 0)
    return smallest


def heappushpop(heap: List[T], item: T, T: type) -> T:
    """Push item, then pop and return the smallest value, in a single step."""
    # If the heap is empty, or its root is not smaller than item, item would
    # be popped straight back out: return it and leave the heap untouched.
    if not heap or not (heap[0] < item):
        return item
    smallest = heap[0]
    heap[0] = item
    _siftup(heap, 0)
    return smallest


def heapify(x: List[T], T: type):
    """Rearrange x in place into a min-heap, in $O(len(x))$ time."""
    # Work bottom-up over the indices that have at least one child.  Index i
    # has a child only when 2*i + 1 < len(x), and the largest such i is
    # len(x)//2 - 1 for even and odd lengths alike.
    i = len(x) // 2 - 1
    while i >= 0:
        _siftup(x, i)
        i -= 1


def _heappop_max(heap: List[T], T: type) -> T:
    """Max-heap heappop: remove and return the largest item."""
    tail = heap.pop()  # an empty heap raises here (IndexError)
    if not heap:
        # The heap held a single item, which was therefore the largest.
        return tail
    # Hand the root to the caller and re-seat the former tail from the top.
    largest = heap[0]
    heap[0] = tail
    _siftup_max(heap, 0)
    return largest


def _heapreplace_max(heap: List[T], item: T, T: type) -> T:
    """Max-heap heapreplace: return the largest item and store item in its place."""
    largest = heap[0]  # an empty heap raises here (IndexError)
    heap[0] = item
    _siftup_max(heap, 0)
    return largest


def _heapify_max(x: List[T], T: type):
    """Rearrange x in place into a max-heap, in O(len(x)) time."""
    # Only indices below len(x) // 2 have children; fix them last-to-first.
    i = len(x) // 2 - 1
    while i >= 0:
        _siftup_max(x, i)
        i -= 1


def nsmallest(n: int, iterable: Generator[T], key=Optional[int](), T: type) -> List[T]:
    """Find the n smallest elements in a dataset.
    Equivalent to:  sorted(iterable, key=key)[:n]

    n is the maximum number of elements to return and iterable supplies the
    elements.  key, when given, is called on an element to obtain the value
    it is compared by; the default is an empty Optional, which the
    isinstance(key, Optional) checks below treat as "no key function".
    The result is ordered smallest first, and is shorter than n when the
    iterable yields fewer than n elements.
    """
    # NOTE(review): apart from n == 1, n reaches List(n) unchecked -- confirm
    # how that constructor behaves for a negative n.
    if n == 1:
        # Single pass keeping the best element seen so far.  The strict '<'
        # means the first of several equal minima is the one kept.
        v = List(1)
        for a in iterable:
            if not v:
                v.append(a)
            else:
                if not isinstance(key, Optional):
                    # Note: key(v[0]) is recomputed for every element.
                    if key(a) < key(v[0]):
                        v[0] = a
                elif a < v[0]:
                    v[0] = a
        return v

    # When key is none, use simpler decoration
    if isinstance(key, Optional):
        it = iter(iterable)
        # Seed the heap with up to n elements, each paired with its arrival
        # index so that equal elements are ordered by arrival.
        result = List(n)
        done = False
        for i in range(n):
            if it.done():
                done = True
                break
            result.append((it.next(), i))
        if not result:
            # Nothing to return; release the iterator before leaving.
            it.destroy()
            return []
        # Max-heap of the n smallest so far: its root is the largest of them,
        # i.e. the threshold a new element has to beat.
        _heapify_max(result)
        top = result[0][0]
        order = n
        if not done:
            for elem in it:
                if elem < top:
                    _heapreplace_max(result, (elem, order))
                    top, _order = result[0]
                    order += 1
        else:
            # The iterator ran dry while seeding, so the loop above is
            # skipped and the iterator is released by hand.  (Presumably the
            # for loop takes care of that on the other path -- confirm.)
            it.destroy()
        result.sort()
        return [elem for elem, order in result]
    else:
        # General case, slowest method
        it = iter(iterable)
        # Same scheme as above, but entries are (key, arrival index, element)
        # so that ordering is decided by the key first.
        result = List(n)
        done = False
        for i in range(n):
            if it.done():
                done = True
                break
            elem = it.next()
            result.append((key(elem), i, elem))
        if not result:
            # Nothing to return; release the iterator before leaving.
            it.destroy()
            return []
        _heapify_max(result)
        top = result[0][0]
        order = n
        if not done:
            for elem in it:
                k = key(elem)
                if k < top:
                    _heapreplace_max(result, (k, order, elem))
                    top, _order, _elem = result[0]
                    order += 1
        else:
            # Iterator exhausted during seeding: release it by hand.
            it.destroy()
        result.sort()
        return [elem for k, order, elem in result]


def nlargest(n: int, iterable: Generator[T], key=Optional[int](), T: type) -> List[T]:
    """Find the n largest elements in a dataset.
    Equivalent to:  sorted(iterable, key=key, reverse=True)[:n]
    """
    if n == 1:
        v = List(1)
        for a in iterable:
            if not v:
                v.append(a)
            else:
                if not isinstance(key, Optional):
                    if key(a) > key(v[0]):
                        v[0] = a
                elif a > v[0]:
                    v[0] = a
        return v

    # When key is none, use simpler decoration
    if isinstance(key, Optional):
        it = iter(iterable)
        result = List(n)
        done = False
        for i in range(0, -n, -1):
            if it.done():
                done = True
                break
            result.append((it.next(), i))
        if not result:
            it.destroy()
            return []
        heapify(result)
        top = result[0][0]
        order = -n
        if not done:
            for elem in it:
                if top < elem:
                    heapreplace(result, (elem, order))
                    top, _order = result[0]
                    order -= 1
        else:
            it.destroy()
        result.sort()
        return [elem for elem, order in reversed(result)]
    else:
        # General case, slowest method
        it = iter(iterable)
        result = List(n)
        done = False
        for i in range(0, -n, -1):
            if it.done():
                done = True
                break
            elem = it.next()
            result.append((key(elem), i, elem))
        if not result:
            return []
        heapify(result)
        top = result[0][0]
        order = -n
        if not done:
            for elem in it:
                k = key(elem)
                if top < k:
                    heapreplace(result, (k, order, elem))
                    top, _order, _elem = result[0]
                    order -= 1
        else:
            it.destroy()
        result.sort()
        return [elem for k, order, elem in reversed(result)]


@tuple
class _MergeItem:
    # Heap entry used by merge(): the current head value of one input,
    # together with what is needed to order it against other entries and to
    # fetch the next value from the same input.
    value: T  # most recent value taken from gen
    order: int  # tie-breaker, compared right after the value (or its key)
    gen: Generator[T]  # the input that produced value
    key: S  # key function, or an Optional when values are compared directly
    T: type
    S: type

    def __lt__(self, other):
        """Compare by (value, order) when key is an Optional, and by
        (key(value), order, value) otherwise."""
        if isinstance(self.key, Optional):
            return (self.value, self.order) < (other.value, other.order)
        else:
            return (self.key(self.value), self.order, self.value) < (
                other.key(other.value),
                other.order,
                other.value,
            )


def merge(*iterables, key=Optional[int](), reverse: bool = False):
    """Lazily merge several inputs into a single stream of values.

    A heap holds one _MergeItem per input that still has values: its current
    head value plus the input's position among the arguments.  The heap's
    top value is yielded and replaced by the next value from the same input,
    so the output is in order provided every input is already in order
    (ascending, or descending when reverse is true).

    key, when given, is applied to values to obtain what they are compared
    by; the default empty Optional means values are compared directly.
    reverse selects the max-heap helpers instead of the min-heap ones.
    Values that compare equal are yielded in the order of their inputs.
    """
    items = []

    # TODO: unify types of different compatible functions
    # TODO: lambdas with void?
    def _heapify(x):
        if reverse:
            _heapify_max(x)
        else:
            heapify(x)

    # Pick the min-heap or max-heap primitive according to reverse.
    _heappop = lambda x: _heappop_max(x) if reverse else heappop(x)
    _heapreplace = lambda x, s: _heapreplace_max(x, s) if reverse else heapreplace(x, s)
    # In reverse mode the heap surfaces the greatest item, so positions are
    # negated to keep earlier inputs ahead of later ones on ties.
    direction = -1 if reverse else 1

    # Prime the heap with the first value of every non-empty input.
    # NOTE(review): a generator found exhausted (here or in the loop below)
    # is dropped without gen.destroy(), whereas nsmallest/nlargest destroy
    # their exhausted iterators -- confirm whether that is needed here.
    order = 0
    for it in iterables:
        gen = iter(it)
        if not gen.done():
            items.append(_MergeItem(gen.next(), order * direction, gen, key))
        order += 1
    _heapify(items)
    while len(items) > 1:
        # Keep emitting the top value until some input runs dry, then drop
        # that input's entry and re-check how many inputs remain.
        while True:
            # TODO: @tuple unpacking does not work
            # (order is reused here for the top item's tie-breaker)
            value, order, gen = items[0].value, items[0].order, items[0].gen
            yield value
            if gen.done():
                _heappop(items)
                break
            _heapreplace(items, _MergeItem(gen.next(), order, gen, key))
    if items:
        # fast case when only a single iterator remains
        value, order, gen = items[0].value, items[0].order, items[0].gen
        yield value
        yield from gen