codon/stdlib/heapq.codon

346 lines
11 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
# is the index of a leaf with a possibly out-of-order value. Restore the
# heap invariant.
def _siftdown(heap: List[T], startpos: int, pos: int, T: type):
    """Move the item at index pos toward the root of a min-heap.

    The item is saved, every parent that compares greater than it is shifted
    one level down, and the item is finally stored in the slot that opened
    up. The walk stops as soon as pos is no longer greater than startpos or a
    parent that is not greater than the item is met.
    """
    moving = heap[pos]
    while pos > startpos:
        above = (pos - 1) >> 1
        ancestor = heap[above]
        if not (moving < ancestor):
            # The parent is already small enough; the item belongs here.
            break
        heap[pos] = ancestor
        pos = above
    heap[pos] = moving
def _siftup(heap: List[T], pos: int, T: type):
    """Re-seat the item at index pos inside a min-heap.

    The item is lifted out, leaving a hole that is pushed down to a leaf by
    promoting the smaller child at every level. The item is then dropped into
    that leaf and _siftdown moves it back up, no higher than its original
    index, to its final position.
    """
    size = len(heap)
    hole = pos
    moving = heap[hole]
    child = 2 * hole + 1  # left child of the hole
    while child < size:
        # Prefer the right child unless the left one is strictly smaller.
        sibling = child + 1
        if sibling < size and not (heap[child] < heap[sibling]):
            child = sibling
        # Promote the chosen child into the hole and descend.
        heap[hole] = heap[child]
        hole = child
        child = 2 * hole + 1
    # The hole is now a leaf: park the item there and let it rise again.
    heap[hole] = moving
    _siftdown(heap, pos, hole)
def _siftdown_max(heap: List[T], startpos: int, pos: int, T: type):
    """Maxheap variant of _siftdown.

    Moves the item at index pos toward the root, shifting down every parent
    that compares less than it, until pos is no longer greater than startpos
    or a parent that is not less than the item is met.
    """
    moving = heap[pos]
    while pos > startpos:
        above = (pos - 1) >> 1
        ancestor = heap[above]
        if not (ancestor < moving):
            # The parent is already large enough; the item belongs here.
            break
        heap[pos] = ancestor
        pos = above
    heap[pos] = moving
def _siftup_max(heap: List[T], pos: int, T: type):
    """Maxheap variant of _siftup.

    The item at index pos is lifted out, the resulting hole is pushed down to
    a leaf by promoting the larger child at every level, and the item is then
    placed in that leaf and moved back up with _siftdown_max.
    """
    size = len(heap)
    hole = pos
    moving = heap[hole]
    child = 2 * hole + 1  # left child of the hole
    while child < size:
        # Prefer the right child unless it is strictly smaller than the left.
        sibling = child + 1
        if sibling < size and not (heap[sibling] < heap[child]):
            child = sibling
        # Promote the chosen child into the hole and descend.
        heap[hole] = heap[child]
        hole = child
        child = 2 * hole + 1
    # The hole is now a leaf: park the item there and let it rise again.
    heap[hole] = moving
    _siftdown_max(heap, pos, hole)
def heappush(heap: List[T], item: T, T: type):
    """Add item to the min-heap, keeping the heap invariant intact."""
    heap.append(item)
    # The new item sits in the last slot; let it rise toward the root.
    tail = len(heap) - 1
    _siftdown(heap, 0, tail)
def heappop(heap: List[T], T: type) -> T:
    """Remove and return the smallest item, keeping the heap invariant intact."""
    tail = heap.pop()  # raises appropriate IndexError if heap is empty
    if not heap:
        # The heap held a single item, which is therefore the smallest.
        return tail
    # Move the former last item to the root and let it settle.
    smallest = heap[0]
    heap[0] = tail
    _siftup(heap, 0)
    return smallest
def heapreplace(heap: List[T], item: T, T: type) -> T:
    """
    Return the current smallest value and put item on the heap in its place.

    The heap size does not change, which makes this a better fit than a
    heappop() followed by a heappush() for fixed-size heaps. The value that
    comes back is not compared with item first, so it may be larger than
    item. Guard the call when that matters:
    ``if item > heap[0]: item = heapreplace(heap, item)``.
    """
    smallest = heap[0]  # raises appropriate IndexError if heap is empty
    # Overwrite the root with the new item, then restore the invariant.
    heap[0] = item
    _siftup(heap, 0)
    return smallest
def heappushpop(heap: List[T], item: T, T: type) -> T:
    """Push item, then pop and return the smallest item, in a single step."""
    if not heap or not (heap[0] < item):
        # Nothing on the heap is smaller than item, so item itself would be
        # popped straight away; the heap stays untouched.
        return item
    # The root is smaller: hand it out and let item take its place.
    smallest = heap[0]
    heap[0] = item
    _siftup(heap, 0)
    return smallest
def heapify(x: List[T], T: type):
    """Rearrange the list into a min-heap, in place."""
    n = len(x)
    # Work bottom-up over every index that has at least one child. A child
    # exists when 2*i + 1 < n, so the last such index is n // 2 - 1 for both
    # even and odd n; leaves need no sifting.
    for i in range(n // 2 - 1, -1, -1):
        _siftup(x, i)
def _heappop_max(heap: List[T], T: type) -> T:
    """Maxheap version of heappop: remove and return the largest item."""
    tail = heap.pop()  # raises appropriate IndexError if heap is empty
    if not heap:
        # The heap held a single item, which is therefore the largest.
        return tail
    # Move the former last item to the root and let it settle.
    largest = heap[0]
    heap[0] = tail
    _siftup_max(heap, 0)
    return largest
def _heapreplace_max(heap: List[T], item: T, T: type) -> T:
    """Maxheap version of heapreplace: return the largest item and store item in its place."""
    largest = heap[0]  # raises appropriate IndexError if heap is empty
    # Overwrite the root with the new item, then restore the invariant.
    heap[0] = item
    _siftup_max(heap, 0)
    return largest
def _heapify_max(x: List[T], T: type):
    """Rearrange the list into a max-heap, in place."""
    n = len(x)
    # Only indices up to n // 2 - 1 have children; sift them bottom-up.
    for i in range(n // 2 - 1, -1, -1):
        _siftup_max(x, i)
def nsmallest(n: int, iterable: Generator[T], key=Optional[int](), T: type) -> List[T]:
    """Find the n smallest elements in a dataset, smallest first.

    Equivalent to: sorted(iterable, key=key)[:n]

    When key is left at its default the elements are compared directly;
    otherwise they are ranked by key(element).
    """
    if n == 1:
        # Single pass keeping one running minimum. The strict '<' means the
        # earliest of several equal minima is the one that is returned.
        best = List(1)
        for cand in iterable:
            if not best:
                best.append(cand)
                continue
            if isinstance(key, Optional):
                if cand < best[0]:
                    best[0] = cand
            else:
                if key(cand) < key(best[0]):
                    best[0] = cand
        return best

    if isinstance(key, Optional):
        # No key function: decorate each element as (element, arrival index).
        source = iter(iterable)
        pool = List(n)
        exhausted = False
        # Seed the pool with up to n elements, checking for exhaustion before
        # every read so that no more than n elements are pulled here.
        for idx in range(n):
            if source.done():
                exhausted = True
                break
            pool.append((source.next(), idx))
        if not pool:
            source.destroy()
            return []
        # Max-heap: the root is the largest candidate kept so far, i.e. the
        # one to evict when something smaller shows up.
        _heapify_max(pool)
        bound = pool[0][0]
        stamp = n
        if exhausted:
            # The loop below will not run, so release the iterator here.
            source.destroy()
        else:
            for item in source:
                if item < bound:
                    _heapreplace_max(pool, (item, stamp))
                    bound = pool[0][0]
                    stamp += 1
        pool.sort()
        return [pair[0] for pair in pool]
    else:
        # General case, slowest method: decorate each element as
        # (key(element), arrival index, element).
        source = iter(iterable)
        pool = List(n)
        exhausted = False
        for idx in range(n):
            if source.done():
                exhausted = True
                break
            item = source.next()
            pool.append((key(item), idx, item))
        if not pool:
            source.destroy()
            return []
        _heapify_max(pool)
        bound = pool[0][0]
        stamp = n
        if exhausted:
            # The loop below will not run, so release the iterator here.
            source.destroy()
        else:
            for item in source:
                rank = key(item)
                if rank < bound:
                    _heapreplace_max(pool, (rank, stamp, item))
                    bound = pool[0][0]
                    stamp += 1
        pool.sort()
        return [triple[2] for triple in pool]
def nlargest(n: int, iterable: Generator[T], key=Optional[int](), T: type) -> List[T]:
    """Find the n largest elements in a dataset.

    Equivalent to: sorted(iterable, key=key, reverse=True)[:n]

    Args:
        n: Maximum number of elements to return.
        iterable: Source of the elements.
        key: One-argument function used to rank the elements. When left at
            its default, the elements themselves are compared.

    Returns:
        A list of at most n elements, largest first. The list is empty when
        no element could be collected.
    """
    if n == 1:
        # Single pass keeping one running maximum. The strict '>' means the
        # earliest of several equal maxima is the one that is returned.
        v = List(1)
        for a in iterable:
            if not v:
                v.append(a)
            else:
                if not isinstance(key, Optional):
                    if key(a) > key(v[0]):
                        v[0] = a
                elif a > v[0]:
                    v[0] = a
        return v

    # When key is none, use simpler decoration
    if isinstance(key, Optional):
        it = iter(iterable)
        result = List(n)
        done = False
        # Seed the heap with up to n entries of the form (element, order).
        # The order tags count downwards (0, -1, -2, ...), so among equal
        # elements the earlier one carries the larger tag: it is evicted from
        # the min-heap last and comes first in the reversed result.
        for i in range(0, -n, -1):
            if it.done():
                done = True
                break
            result.append((it.next(), i))
        if not result:
            it.destroy()
            return []
        # Min-heap: the root is the smallest candidate kept so far, i.e. the
        # one to evict when something larger shows up.
        heapify(result)
        top = result[0][0]
        order = -n
        if not done:
            for elem in it:
                if top < elem:
                    heapreplace(result, (elem, order))
                    top = result[0][0]
                    order -= 1
        else:
            # The loop above did not run, so release the iterator here.
            it.destroy()
        result.sort()
        return [elem for elem, order in reversed(result)]
    else:
        # General case, slowest method: entries are
        # (key(element), order, element).
        it = iter(iterable)
        result = List(n)
        done = False
        for i in range(0, -n, -1):
            if it.done():
                done = True
                break
            elem = it.next()
            result.append((key(elem), i, elem))
        if not result:
            # Destroy the iterator before bailing out, exactly as the
            # un-keyed branch above and both branches of nsmallest do; this
            # path used to return without doing so.
            it.destroy()
            return []
        heapify(result)
        top = result[0][0]
        order = -n
        if not done:
            for elem in it:
                k = key(elem)
                if top < k:
                    heapreplace(result, (k, order, elem))
                    top = result[0][0]
                    order -= 1
        else:
            # The loop above did not run, so release the iterator here.
            it.destroy()
        result.sort()
        return [elem for k, order, elem in reversed(result)]
@tuple
class _MergeItem:
    # Heap entry used by merge(): the value most recently drawn from one
    # input, the tie-breaking order tag assigned to that input, the generator
    # the value came from, and the key (a callable, or an Optional when no
    # key function is in use).
    value: T
    order: int
    gen: Generator[T]
    key: S
    T: type
    S: type

    def __lt__(self, other):
        # Without a key function the entries are ranked by (value, order);
        # with one they are ranked by (key(value), order, value). The else
        # branch must stay an else so that the key is only ever called when
        # it is not an Optional.
        if isinstance(self.key, Optional):
            lhs = (self.value, self.order)
            rhs = (other.value, other.order)
            return lhs < rhs
        else:
            lhs_keyed = (self.key(self.value), self.order, self.value)
            rhs_keyed = (other.key(other.value), other.order, other.value)
            return lhs_keyed < rhs_keyed
def merge(*iterables, key=Optional[int](), reverse: bool = False):
    """Lazily merge several iterables into a single stream of values.

    A heap holds one pending value per non-empty input; the value at the
    root is yielded and then replaced by the next value from the same input.
    As with Python's heapq.merge, each input is expected to be sorted
    already (largest first when reverse is true). When key is given, values
    are ranked by key(value).
    """
    pending = []

    # TODO: unify types of different compatible functions
    # TODO: lambdas with void?
    def _heapify(x):
        if reverse:
            _heapify_max(x)
        else:
            heapify(x)

    _heappop = lambda x: _heappop_max(x) if reverse else heappop(x)
    _heapreplace = lambda x, s: _heapreplace_max(x, s) if reverse else heapreplace(x, s)

    # Tag every input with its position, negated when merging in reverse, so
    # that ties between equal values are broken by input position.
    sign = -1 if reverse else 1
    slot = 0
    for source in iterables:
        stream = iter(source)
        if not stream.done():
            pending.append(_MergeItem(stream.next(), slot * sign, stream, key))
        slot += 1
    _heapify(pending)

    # TODO: @tuple unpacking does not work, so fields are read by attribute.
    while len(pending) > 1:
        head = pending[0]
        yield head.value
        if head.gen.done():
            # This input is finished: drop its entry from the heap.
            _heappop(pending)
        else:
            # Swap in the next value from the same input.
            _heapreplace(pending, _MergeItem(head.gen.next(), head.order, head.gen, key))

    if pending:
        # fast case when only a single iterator remains
        last = pending[0]
        yield last.value
        yield from last.gen