mirror of https://github.com/exaloop/codon.git
759 lines
18 KiB
Python
759 lines
18 KiB
Python
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from internal.types.optional import unwrap
|
|
|
|
# Infinite iterators
|
|
|
|
@inline
def count(start: T = 0, step: T = 1, T: type) -> Generator[T]:
    """
    Yield ``start`` and then keep yielding the previous value plus ``step``.
    The generator never finishes on its own; the consumer must stop pulling.
    """
    current = start
    while True:
        yield current
        current += step
|
|
|
|
@inline
def cycle(iterable: Generator[T], T: type) -> Generator[T]:
    """
    Yield each element of ``iterable`` once while recording it, then replay
    the recorded elements over and over. If nothing was recorded the
    generator simply ends.
    """
    cache = []
    for item in iterable:
        yield item
        cache.append(item)
    # Replay phase: loops forever as long as at least one element was seen.
    while cache:
        for item in cache:
            yield item
|
|
|
|
@inline
def repeat(object: T, times: Optional[int] = None, T: type) -> Generator[T]:
    """
    Yield ``object`` repeatedly: without end when ``times`` is None,
    otherwise once for every value produced by ``range(times)``.
    """
    if times is not None:
        for _ in range(times):
            yield object
    else:
        while True:
            yield object
|
|
|
|
# Iterators terminating on the shortest input sequence
|
|
|
|
@inline
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, initial=0, T: type):
    """
    Yield ``initial`` first, then the running result of applying ``func`` to
    the previous result and each successive element of ``iterable``.
    ``func`` defaults to addition.
    """
    running = initial
    yield running
    for item in iterable:
        running = func(running, item)
        yield running
|
|
|
|
@inline
@overload
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type):
    """
    Yield the running result of folding ``func`` over ``iterable`` without a
    seed value: the first element is yielded unchanged, and every later
    result is ``func(previous_result, element)``. ``func`` defaults to
    addition.
    """
    # None marks "no element consumed yet"; it is replaced by the first element.
    running = None
    for item in iterable:
        if running is None:
            running = item
        else:
            running = func(unwrap(running), item)
        yield unwrap(running)
|
|
|
|
@tuple
class chain:
    """
    Lazily concatenate iterables: every element of the first iterable is
    produced, then every element of the second, and so on until the last
    iterable has been drained.
    """

    @inline
    def __new__(*iterables):
        # Iterables are given as separate positional arguments.
        for source in iterables:
            yield from source

    @inline
    def from_iterable(iterables):
        # Iterables are given as the elements of a single iterable.
        for source in iterables:
            yield from source
|
|
|
|
@inline
def compress(
    data: Generator[T], selectors: Generator[B], T: type, B: type
) -> Generator[T]:
    """
    Pair ``data`` with ``selectors`` via ``zip`` and yield only those data
    elements whose paired selector is truthy. Pairing follows ``zip``, so
    elements without a partner are never considered.
    """
    for item, keep in zip(data, selectors):
        if keep:
            yield item
|
|
|
|
@inline
def dropwhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Skip leading elements of ``iterable`` for which ``predicate`` is true.
    Starting with the first element that fails the predicate, yield every
    remaining element; the predicate is not consulted again after that.
    """
    passing = False
    for item in iterable:
        # Once passing is set the predicate is short-circuited away.
        if not (passing or predicate(item)):
            passing = True
        if passing:
            yield item
|
|
|
|
@inline
def filterfalse(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Yield exactly those elements of ``iterable`` for which
    ``predicate(element)`` is false.
    """
    for item in iterable:
        if predicate(item):
            continue
        yield item
|
|
|
|
# TODO: fix this once Optional[Callable] lands
|
|
@inline
def groupby(iterable, key=Optional[int]()):
    """
    Yield ``(key, group)`` pairs for runs of consecutive elements that share
    the same key, where ``group`` is a list of the elements in that run.
    When ``key`` is left at its ``Optional`` placeholder default, each
    element serves as its own key; otherwise ``key(element)`` is used.
    """
    active_key = None
    members = []

    for item in iterable:
        item_key = item if isinstance(key, Optional) else key(item)
        # The very first element opens the first run.
        if active_key is None:
            active_key = item_key
        # A key change closes the current run and opens a fresh one.
        if item_key != unwrap(active_key):
            yield unwrap(active_key), members
            active_key = item_key
            members = []
        members.append(item)
    # Emit the final run, unless the input was empty.
    if active_key is not None:
        yield unwrap(active_key), members
|
|
|
|
def islice(iterable: Generator[T], stop: Optional[int], T: type) -> Generator[T]:
    """
    Yield elements of ``iterable`` until ``stop`` of them have been produced,
    or all of them when ``stop`` is None.

    Raises ValueError if ``stop`` is a negative integer.
    """
    if stop is not None and stop.__val__() < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    produced = 0
    for item in iterable:
        # The limit is checked after the next element has already been pulled.
        if stop is not None and produced >= stop.__val__():
            break
        yield item
        produced += 1
|
|
|
|
@overload
def islice(
    iterable: Generator[T],
    start: Optional[int],
    stop: Optional[int],
    step: Optional[int] = None,
    T: type,
) -> Generator[T]:
    """
    Yield the elements of ``iterable`` whose positions are given by
    ``range(start, stop, step)``.

    ``start`` defaults to 0, ``stop`` to ``sys.maxsize`` and ``step`` to 1
    when passed as None.

    Raises ValueError if ``start`` or ``stop`` is negative, or if ``step``
    is not strictly positive.
    """
    from sys import maxsize

    start: int = 0 if start is None else start
    stop: int = maxsize if stop is None else stop
    step: int = 1 if step is None else step

    if start < 0 or stop < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    elif step <= 0:
        # A zero step is rejected here too, matching the error message,
        # instead of being handed on to range().
        raise ValueError("Step for islice() must be a positive integer or None.")

    it = range(start, stop, step)
    N = len(it)
    idx = 0
    b = -1  # position of the last selected element, -1 while unset

    if N == 0:
        # Nothing is selected; still pull elements from the input, paired
        # with range(start), before finishing.
        for i, element in zip(range(start), iterable):
            pass
        return

    nexti = it[0]
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            idx += 1
            if idx >= N:
                b = i
                break
            nexti = it[idx]

    if b >= 0:
        # All selected positions were produced; keep pulling from the input
        # for the positions between the last selected one and ``stop``.
        for i, element in zip(range(b + 1, stop), iterable):
            pass
|
|
|
|
@inline
def starmap(function, iterable):
    """
    For every element of ``iterable``, unpack it into positional arguments,
    call ``function`` with them and yield the result.
    """
    for packed in iterable:
        yield function(*packed)
|
|
|
|
@inline
def takewhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Yield elements of ``iterable`` for as long as ``predicate`` holds, and
    stop at the first element for which it does not.
    """
    for item in iterable:
        if not predicate(item):
            break
        yield item
|
|
|
|
def tee(iterable: Generator[T], n: int = 2, T: type) -> List[Generator[T]]:
    """
    Split one iterable into a list of ``n`` generators that each produce the
    full sequence. Every generator owns a deque; whenever a generator finds
    its deque empty it advances the shared source and appends the new value
    to all deques.
    """
    from collections import deque

    source = iter(iterable)
    buffers = [deque[T]() for _ in range(n)]

    def gen(mydeque: deque[T], T: type) -> Generator[T]:
        while True:
            if not mydeque:  # nothing buffered for this consumer
                if source.__done__():
                    return
                source.__resume__()
                if source.__done__():
                    return
                fresh = source.next()
                for buf in buffers:  # hand the new value to every consumer
                    buf.append(fresh)
            yield mydeque.popleft()

    return [gen(buf) for buf in buffers]
|
|
|
|
@inline
def zip_longest(*iterables, fillvalue):
    """
    Aggregate elements from each of the iterables, substituting
    ``fillvalue`` for iterables that have run out, until the longest one is
    exhausted.

    With exactly two iterables each result is a 2-tuple; for any other
    number of iterables each result is a list.
    """
    if staticlen(iterables) == 2:
        first = iter(iterables[0])
        second = iter(iterables[1])
        second_done = False

        # Phase 1: follow the first iterable, padding the second if needed.
        while not first.done():
            left = first.next()
            right = fillvalue
            if not second_done:
                second_done = second.done()
            if not second_done:
                right = second.next()
            yield left, right

        # Phase 2: the first iterable ran out; drain what is left of the second.
        if not second_done:
            while not second.done():
                yield fillvalue, second.next()

        first.destroy()
        second.destroy()
    else:
        sources = tuple(iter(it) for it in iterables)
        remaining = len(sources)
        if not remaining:
            return
        while True:
            row = []
            for src in sources:
                if src.__done__():  # finished in an earlier round
                    row.append(fillvalue)
                elif src.done():  # advance it; it finished just now
                    remaining -= 1
                    if not remaining:
                        return
                    row.append(fillvalue)
                else:
                    row.append(src.next())
            yield row
|
|
|
|
@inline
@overload
def zip_longest(*args):
    """
    Aggregate elements from each of the iterables into tuples, using None
    for iterables that have already run out. Iteration stops at the first
    round in which every slot of the tuple is None.
    """

    def get_next(it):
        # None signals that this iterator has no further value.
        if it.__done__() or it.done():
            return None
        return it.next()

    iters = tuple(iter(arg) for arg in args)
    while True:
        result = tuple(get_next(it) for it in iters)
        all_none = True
        for a in result:
            if a is not None:
                all_none = False
        if all_none:
            # Leave the loop (rather than returning) so that the cleanup
            # below actually runs.
            break
        yield result
    for it in iters:
        it.destroy()
|
|
|
|
# Combinatoric iterators
|
|
|
|
def _as_list(x):
    """
    Return ``x`` itself when it already is a list; otherwise build and
    return a new list from its elements.
    """
    if isinstance(x, list):
        # Already a list: hand it back without copying.
        return x
    else:
        return list(x)
|
|
|
|
def product(*iterables, repeat: int):
    """
    Cartesian product of the input iterables, each taken ``repeat`` times,
    with ``repeat`` only known at run time. Each result is yielded as a
    list, with the rightmost position varying fastest.

    Yields nothing if any input is empty.
    Raises ValueError if ``repeat`` is negative.
    """
    if repeat < 0:
        raise ValueError("repeat must be non-negative")

    nargs = 0 if repeat == 0 else len(iterables)
    npools = nargs * repeat
    indices = Ptr[int](npools)
    pools = list(capacity=npools)

    # Materialize each input once; an empty input means an empty product.
    slot = 0
    while slot < nargs:
        materialized = _as_list(iterables[slot])
        if len(materialized) == 0:
            return
        pools.append(materialized)
        indices[slot] = 0
        slot += 1

    # Extra repetitions reuse the pools that were already materialized.
    while slot < npools:
        pools.append(pools[slot - nargs])
        indices[slot] = 0
        slot += 1

    result = list(capacity=npools)
    for pos in range(npools):
        result.append(pools[pos][0])

    while True:
        yield result

        # Work on a copy so the list just yielded is not modified.
        result = result.copy()

        # Advance like an odometer, starting from the rightmost position.
        pos = npools - 1
        while pos >= 0:
            pool = pools[pos]
            indices[pos] += 1

            if indices[pos] == len(pool):
                # This position wrapped around; carry to the left.
                indices[pos] = 0
                result[pos] = pool[0]
            else:
                result[pos] = pool[indices[pos]]
                break

            pos -= 1

        # Every position wrapped around: all combinations were produced.
        if pos < 0:
            break
|
|
|
|
@overload
def product(*iterables, repeat: Static[int] = 1):
    """
    Cartesian product of the input iterables, each taken ``repeat`` times,
    with ``repeat`` known at compile time. Each result is yielded as a
    tuple, with the rightmost position varying fastest.

    A negative ``repeat`` is reported through ``compile_error``.
    """
    if repeat < 0:
        compile_error("repeat must be non-negative")

    # Dedicated paths for the common small cases: plain nested loops.
    if repeat == 0:
        yield ()
    elif repeat == 1 and staticlen(iterables) == 1:
        first = iterables[0]
        for a in first:
            yield (a,)
    elif repeat == 1 and staticlen(iterables) == 2:
        first = iterables[0]
        second = iterables[1]
        for a in first:
            for b in second:
                yield (a, b)
    elif repeat == 1 and staticlen(iterables) == 3:
        first = iterables[0]
        second = iterables[1]
        third = iterables[2]
        for a in first:
            for b in second:
                for c in third:
                    yield (a, b, c)
    else:
        nargs: Static[int] = staticlen(iterables)
        npools: Static[int] = nargs * repeat
        # The index array is backed by a tuple and accessed through a pointer.
        indices_tuple = (0,) * npools
        indices = Ptr[int](__ptr__(indices_tuple).as_byte())
        pools = tuple(_as_list(it) for it in iterables) * repeat

        # An empty input means an empty product.
        for k in staticrange(nargs):
            if len(pools[k]) == 0:
                return

        result = tuple(pool[0] for pool in pools)

        while True:
            yield result

            # Advance like an odometer, starting from the rightmost position.
            pos = npools - 1
            while pos >= 0:
                pool = pools[pos]
                indices[pos] += 1

                if indices[pos] == len(pool):
                    indices[pos] = 0
                else:
                    break

                pos -= 1

            # Every position wrapped around: all combinations were produced.
            if pos < 0:
                break

            result = tuple(pools[k][indices[k]] for k in staticrange(npools))
|
|
|
|
def combinations(pool, r: int):
    """
    Yield every length-``r`` combination of the elements of ``pool`` as a
    list, with ``r`` only known at run time. Elements are chosen by
    position, and positions within each result are strictly increasing.

    Yields nothing if ``r`` exceeds the number of elements.
    Raises ValueError if ``r`` is negative.
    """
    if r < 0:
        raise ValueError("r must be non-negative")

    pool_list = _as_list(pool)
    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if r > n:
        return

    pool = pool_list.arr.ptr
    indices = Ptr[int](r)
    result = list(capacity=r)

    for i in range(r):
        indices[i] = i
        result.append(pool[i])

    while True:
        yield result

        # Find the rightmost index that has not reached its maximum value.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1

        # All indices are at their maximum: every combination was produced.
        if i < 0:
            break

        indices[i] += 1

        # Reset everything to the right to the smallest increasing run.
        for j in range(i + 1, r):
            indices[j] = indices[j-1] + 1

        # Work on a copy so the list just yielded is not modified.
        result = result.copy()
        while i < r:
            result[i] = pool[indices[i]]
            i += 1
|
|
|
|
@overload
def combinations(pool, r: Static[int]):
    """
    Yield every length-``r`` combination of the elements of ``pool`` as an
    ``r``-tuple, with ``r`` known at compile time. Elements are chosen by
    position, and positions within each result are strictly increasing.

    Yields nothing if ``r`` exceeds the number of elements.
    A negative ``r`` is reported through ``compile_error``.
    """

    def empty(T: type) -> T:
        # Placeholder value of type T used only to give the result tuple its type.
        pass

    if r < 0:
        compile_error("r must be non-negative")

    if isinstance(pool, list):
        pool_list = pool
    else:
        pool_list = list(pool)

    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if r > n:
        return

    pool = pool_list.arr.ptr
    # Both the index array and the result are backed by tuples and are
    # written through pointers into those tuples.
    indices_tuple = (0,) * r
    indices = Ptr[int](__ptr__(indices_tuple).as_byte())
    result_tuple = (empty(pool.T),) * r
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())

    for i in range(r):
        indices[i] = i
        result[i] = pool[i]

    while True:
        yield result_tuple

        # Find the rightmost index that has not reached its maximum value.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1

        # All indices are at their maximum: every combination was produced.
        if i < 0:
            break

        indices[i] += 1

        # Reset everything to the right to the smallest increasing run.
        for j in range(i + 1, r):
            indices[j] = indices[j-1] + 1

        while i < r:
            result[i] = pool[indices[i]]
            i += 1
|
|
|
|
def combinations_with_replacement(pool, r: int):
    """
    Yield every length-``r`` combination of the elements of ``pool`` as a
    list, allowing the same element to be chosen more than once, with ``r``
    only known at run time. Positions within each result never decrease.

    For an empty pool a single empty list is yielded when ``r`` is 0, and
    nothing otherwise.
    Raises ValueError if ``r`` is negative.
    """
    if r < 0:
        raise ValueError("r must be non-negative")

    pool_list = _as_list(pool)
    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if n == 0:
        if r == 0:
            yield List[pool_list.T](capacity=0)
        return

    pool = pool_list.arr.ptr
    indices = Ptr[int](r)
    result = list(capacity=r)

    for i in range(r):
        indices[i] = 0
        result.append(pool[0])

    while True:
        yield result

        # Find the rightmost index that is not yet at the last element.
        i = r - 1
        while i >= 0 and indices[i] == n - 1:
            i -= 1

        # All indices point at the last element: nothing left to produce.
        if i < 0:
            break

        # Work on a copy so the list just yielded is not modified.
        result = result.copy()
        index = indices[i] + 1
        elem = pool[index]

        # Set this position and everything to its right to the new element.
        while i < r:
            indices[i] = index
            result[i] = elem
            i += 1
|
|
|
|
@overload
def combinations_with_replacement(pool, r: Static[int]):
    """
    Yield every length-``r`` combination of the elements of ``pool`` as an
    ``r``-tuple, allowing the same element to be chosen more than once,
    with ``r`` known at compile time. Positions within each result never
    decrease.

    When ``r`` is 0 a single empty tuple is yielded; an empty pool with a
    positive ``r`` yields nothing.
    A negative ``r`` is reported through ``compile_error``.
    """

    def empty(T: type) -> T:
        # Placeholder value of type T used only to give the result tuple its type.
        pass

    if r < 0:
        compile_error("r must be non-negative")

    if r == 0:
        yield ()
        return

    if isinstance(pool, list):
        pool_list = pool
    else:
        pool_list = list(pool)

    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if n == 0:
        return

    pool = pool_list.arr.ptr
    # Both the index array and the result are backed by tuples and are
    # written through pointers into those tuples.
    indices_tuple = (0,) * r
    indices = Ptr[int](__ptr__(indices_tuple).as_byte())
    result_tuple = (empty(pool.T),) * r
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())

    for i in range(r):
        result[i] = pool[0]

    while True:
        yield result_tuple

        # Find the rightmost index that is not yet at the last element.
        i = r - 1
        while i >= 0 and indices[i] == n - 1:
            i -= 1

        # All indices point at the last element: nothing left to produce.
        if i < 0:
            break

        index = indices[i] + 1
        elem = pool[index]

        # Set this position and everything to its right to the new element.
        while i < r:
            indices[i] = index
            result[i] = elem
            i += 1
|
|
|
|
def _permutations_non_static(pool, r = None):
    """
    Yield every length-``r`` permutation of the elements of ``pool`` as a
    list, with ``r`` only known at run time. When ``r`` is None the full
    length of the pool is used.

    Yields nothing if ``r`` exceeds the number of elements.
    Raises ValueError if ``r`` is negative; an ``r`` that is neither None
    nor an int is reported through ``compile_error``.
    """
    pool_list = _as_list(pool)
    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if r is None:
        return _permutations_non_static(pool_list, n)
    elif not isinstance(r, int):
        compile_error("Expected int as r")

    if r < 0:
        raise ValueError("r must be non-negative")

    if r > n:
        return

    indices = Ptr[int](n)
    cycles = Ptr[int](r)

    for i in range(n):
        indices[i] = i

    for i in range(r):
        cycles[i] = n - i

    pool = pool_list.arr.ptr
    result = list(capacity=r)

    for i in range(r):
        result.append(pool[i])

    while True:
        yield result

        if n == 0:
            break

        # Work on a copy so the list just yielded is not modified.
        result = result.copy()
        i = r - 1
        while i >= 0:
            cycles[i] -= 1
            if cycles[i] == 0:
                # Counter exhausted: rotate indices[i:] left by one, reset
                # the counter and carry on with the position to the left.
                index = indices[i]
                for j in range(i, n - 1):
                    indices[j] = indices[j+1]
                indices[n-1] = index
                cycles[i] = n - i
            else:
                # Swap position i with the one selected by its counter.
                j = cycles[i]
                index = indices[i]
                indices[i] = indices[n - j]
                indices[n - j] = index

                # Refresh the result from position i to the end.
                for k in range(i, r):
                    index = indices[k]
                    result[k] = pool[index]

                break
            i -= 1

        # Every counter was exhausted: all permutations were produced.
        if i < 0:
            break
|
|
|
|
def _permutations_static(pool, r: Static[int]):
    """
    Yield every length-``r`` permutation of the elements of ``pool`` as an
    ``r``-tuple, with ``r`` known at compile time.

    Yields nothing if ``r`` exceeds the number of elements.
    A negative ``r`` is reported through ``compile_error``.
    """

    def empty(T: type) -> T:
        # Placeholder value of type T used only to give the result tuple its type.
        pass

    pool_list = _as_list(pool)
    # Measure the materialized list: the original iterable may not support
    # len() at all (e.g. a generator).
    n = len(pool_list)

    if r < 0:
        compile_error("r must be non-negative")

    if r > n:
        return

    indices = Ptr[int](n)
    # The cycle counters and the result are backed by tuples and are
    # written through pointers into those tuples.
    cycles_tuple = (0,) * r
    cycles = Ptr[int](__ptr__(cycles_tuple).as_byte())

    for i in range(n):
        indices[i] = i

    for i in range(r):
        cycles[i] = n - i

    pool = pool_list.arr.ptr
    result_tuple = (empty(pool.T),) * r
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())

    for i in range(r):
        result[i] = pool[i]

    while True:
        yield result_tuple

        if n == 0:
            break

        i = r - 1
        while i >= 0:
            cycles[i] -= 1
            if cycles[i] == 0:
                # Counter exhausted: rotate indices[i:] left by one, reset
                # the counter and carry on with the position to the left.
                index = indices[i]
                for j in range(i, n - 1):
                    indices[j] = indices[j+1]
                indices[n-1] = index
                cycles[i] = n - i
            else:
                # Swap position i with the one selected by its counter.
                j = cycles[i]
                index = indices[i]
                indices[i] = indices[n - j]
                indices[n - j] = index

                # Refresh the result from position i to the end.
                for k in range(i, r):
                    index = indices[k]
                    result[k] = pool[index]

                break
            i -= 1

        # Every counter was exhausted: all permutations were produced.
        if i < 0:
            break
|
|
|
|
def permutations(pool, r = None):
    """
    Yield permutations of the elements of ``pool``.

    A tuple pool with no explicit ``r`` is forwarded to the static
    implementation using the tuple's compile-time length; every other call
    is forwarded to the run-time implementation together with ``r``.
    """
    if isinstance(pool, Tuple) and r is None:
        # The length of a tuple is known statically.
        return _permutations_static(pool, staticlen(pool))
    else:
        return _permutations_non_static(pool, r)
|
|
|
|
@overload
def permutations(pool, r: Static[int]):
    """
    Yield length-``r`` permutations of the elements of ``pool`` for an ``r``
    known at compile time, by forwarding to the static implementation.
    """
    return _permutations_static(pool, r)
|