2024-03-03 05:30:03 +08:00
|
|
|
# Copyright (C) 2022-2024 Exaloop Inc. <https://exaloop.io>
|
2022-01-24 14:47:43 +08:00
|
|
|
|
2021-09-28 02:02:44 +08:00
|
|
|
class object:
    """Base reference type providing a default constructor and repr."""

    def __init__(self):
        # No state to set up for the base object.
        pass

    def __repr__(self) -> str:
        type_name = self.__class__.__name__
        address = self.__raw__()
        return f"<{type_name} object at {address}>"
|
2022-01-24 14:47:43 +08:00
|
|
|
|
|
|
|
def id(x) -> int:
    """
    Return an integer identity for x: the raw address for reference
    (ByRef) types, and 0 for everything else.
    """
    if isinstance(x, ByRef):
        address = x.__raw__()
        return int(address)
    else:
        # Non-reference values have no address to report.
        return 0
|
|
|
|
|
|
|
|
# C-level stdout handle obtained once when this module is initialised; it is
# the default `file` argument of print().
_stdout = _C.seq_stdout()
|
2022-01-24 14:47:43 +08:00
|
|
|
|
2022-02-16 23:51:16 +08:00
|
|
|
def print(*args, sep: str = " ", end: str = "\n", file=_stdout, flush: bool = False):
    """
    Write str() of each argument to `file`, separated by `sep` (when it is
    non-empty) and followed by `end`; flush the stream if `flush` is set.
    """
    stream = cobj()
    if isinstance(file, cobj):
        # A raw C stream handle was passed directly.
        stream = file
    else:
        # Otherwise use the handle exposed by the object's `fp` attribute.
        stream = file.fp

    written = 0
    for item in args:
        if written > 0 and sep:
            _C.seq_print_full(sep, stream)
        _C.seq_print_full(str(item), stream)
        written += 1

    _C.seq_print_full(end, stream)
    if flush:
        _C.fflush(stream)
|
|
|
|
|
2024-08-06 05:31:45 +08:00
|
|
|
def input(prompt: str = ""):
    """
    Write `prompt` to stdout (no trailing newline), then read one line from
    stdin and return it without its trailing '\n' (and a '\r' before it, if
    present).

    Raises EOFError when no line could be read.
    """
    stdout = _C.seq_stdout()
    stderr = _C.seq_stderr()
    stdin = _C.seq_stdin()
    # Emit anything already buffered on the standard streams first.
    _C.fflush(stderr)
    _C.fflush(stdout)
    # The prompt has no newline, so a buffered stdout would keep it hidden
    # while we block on stdin below; flush it explicitly, like CPython does.
    print(prompt, end="", flush=True)
    buf = cobj()
    n = 0
    # Assumed to wrap POSIX getline(3): allocates/grows `buf`, stores its
    # capacity in `n`, and returns the byte count (negative on EOF/error).
    s = _C.getline(__ptr__(buf), __ptr__(n), stdin)
    if s > 0:
        if buf[s - 1] == byte(10):
            s -= 1  # skip trailing '\n'
        if s != 0 and buf[s - 1] == byte(13):
            s -= 1  # skip trailing '\r'
        # Copy the line out before releasing the getline buffer.
        ans = str(buf, s).__ptrcopy__()
        _C.free(buf)
        return ans
    else:
        _C.free(buf)
        raise EOFError("EOF when reading a line")
|
|
|
|
|
2022-12-05 08:45:21 +08:00
|
|
|
@extend
class __internal__:
    def print(*args):
        # Print to the process stdout and flush right away.
        print(*args, file=_C.seq_stdout(), flush=True)
|
|
|
|
|
2024-01-20 00:22:20 +08:00
|
|
|
def min(*args, key=None, default=None):
    """
    Return the smallest of the arguments.

    With one argument, iterate over it and return its smallest item; if it
    is empty, return `default` when given, otherwise raise ValueError. With
    two or more arguments, return the smallest of them (`default` is then
    rejected at compile time).

    As in CPython, a candidate replaces the current best only when it
    compares strictly less: the first minimal item wins ties, and the
    positional and iterable forms agree on unordered values such as NaN
    (min(nan, 1.0) and min([nan, 1.0]) both return nan). `key`, when given,
    is called exactly once per item.
    """
    if staticlen(args) == 0:
        compile_error("min() expected at least 1 argument, got 0")
    elif staticlen(args) > 1 and default is not None:
        compile_error("min() 'default' argument only allowed for iterables")
    elif staticlen(args) == 1:
        x = args[0].__iter__()
        if not x.done():
            s = x.next()
            if key is None:
                while not x.done():
                    i = x.next()
                    if i < s:
                        s = i
            else:
                # Remember key(best) so key() runs once per item.
                ks = key(s)
                while not x.done():
                    i = x.next()
                    ki = key(i)
                    if ki < ks:
                        s = i
                        ks = ki
            x.destroy()
            return s
        else:
            x.destroy()
            if default is None:
                raise ValueError("min() arg is an empty sequence")
            else:
                return default
    elif staticlen(args) == 2:
        a, b = args
        # Keep `a` unless it compares greater than `b` -- the same rule as
        # the loops, so tie-breaking and NaN handling agree.
        if key is None:
            return b if a > b else a
        else:
            return b if key(a) > key(b) else a
    else:
        m = args[0]
        if key is None:
            for i in args[1:]:
                if i < m:
                    m = i
        else:
            km = key(m)
            for i in args[1:]:
                ki = key(i)
                if ki < km:
                    m = i
                    km = ki
        return m
|
|
|
|
|
2024-01-20 00:22:20 +08:00
|
|
|
def max(*args, key=None, default=None):
    """
    Return the largest of the arguments.

    With one argument, iterate over it and return its largest item; if it
    is empty, return `default` when given, otherwise raise ValueError. With
    two or more arguments, return the largest of them (`default` is then
    rejected at compile time).

    As in CPython, a candidate replaces the current best only when it
    compares strictly greater: the first maximal item wins ties, and the
    positional and iterable forms agree on unordered values such as NaN
    (max(nan, 1.0) and max([nan, 1.0]) both return nan). `key`, when given,
    is called exactly once per item.
    """
    if staticlen(args) == 0:
        compile_error("max() expected at least 1 argument, got 0")
    elif staticlen(args) > 1 and default is not None:
        compile_error("max() 'default' argument only allowed for iterables")
    elif staticlen(args) == 1:
        x = args[0].__iter__()
        if not x.done():
            s = x.next()
            if key is None:
                while not x.done():
                    i = x.next()
                    if i > s:
                        s = i
            else:
                # Remember key(best) so key() runs once per item.
                ks = key(s)
                while not x.done():
                    i = x.next()
                    ki = key(i)
                    if ki > ks:
                        s = i
                        ks = ki
            x.destroy()
            return s
        else:
            x.destroy()
            if default is None:
                raise ValueError("max() arg is an empty sequence")
            else:
                return default
    elif staticlen(args) == 2:
        a, b = args
        # Keep `a` unless it compares less than `b` -- the same rule as the
        # loops, so tie-breaking and NaN handling agree.
        if key is None:
            return b if a < b else a
        else:
            return b if key(a) < key(b) else a
    else:
        m = args[0]
        if key is None:
            for i in args[1:]:
                if i > m:
                    m = i
        else:
            km = key(m)
            for i in args[1:]:
                ki = key(i)
                if ki > km:
                    m = i
                    km = ki
        return m
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def len(x) -> int:
    """Return the number of items in x, as reported by x.__len__()."""
    count = x.__len__()
    return count
|
|
|
|
|
|
|
|
def iter(x):
    """Return the iterator produced by x.__iter__()."""
    it = x.__iter__()
    return it
|
|
|
|
|
|
|
|
def abs(x):
    """Return the magnitude of x by delegating to x.__abs__()."""
    magnitude = x.__abs__()
    return magnitude
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def hash(x) -> int:
    """Return the hash of x by delegating to x.__hash__()."""
    h = x.__hash__()
    return h
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def ord(s: str) -> int:
    """
    Return the integer value of the single character (byte) in s; raise
    TypeError if s does not have length 1.
    """
    if len(s) == 1:
        return int(s.ptr[0])
    raise TypeError(
        f"ord() expected a character, but string of length {len(s)} found"
    )
|
|
|
|
|
2021-10-10 05:07:41 +08:00
|
|
|
def divmod(a, b):
    """Return (a // b, a % b), using a.__divmod__(b) when a defines it."""
    if hasattr(a, "__divmod__"):
        return a.__divmod__(b)
    else:
        quotient = a // b
        remainder = a % b
        return (quotient, remainder)
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def chr(i: int) -> str:
    """
    Return a one-character string whose single byte has value i.

    The result is built from a 1-byte buffer, so only 0 <= i <= 255 can be
    represented; other values raise ValueError instead of being silently
    turned into a different byte by the byte(i) conversion.
    """
    if i < 0 or i > 255:
        raise ValueError("chr() arg not in range(256)")
    p = cobj(1)
    p[0] = byte(i)
    return str(p, 1)
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def next(g: Generator[T], default: Optional[T] = None, T: type) -> T:
    """
    Return the next value of generator g. When g is exhausted, return
    `default` if one was supplied, otherwise raise StopIteration.
    """
    if not g.done():
        return g.next()
    if default is None:
        raise StopIteration()
    return default.__val__()
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def any(x: Generator[T], T: type) -> bool:
    """Return True as soon as some item of x is truthy; False otherwise."""
    for item in x:
        if not item:
            continue
        return True
    return False
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def all(x: Generator[T], T: type) -> bool:
    """Return False as soon as some item of x is falsy; True otherwise."""
    for item in x:
        if item:
            continue
        return False
    return True
|
|
|
|
|
|
|
|
def zip(*args):
    """
    Yield tuples holding one item from each argument, in lockstep, and stop
    as soon as any argument is exhausted. With no arguments, nothing is
    yielded.
    """
    if staticlen(args) == 0:
        yield from List[int]()
    else:
        iters = tuple(iter(i) for i in args)
        done = False
        while not done:
            for i in iters:
                # Stop probing once one iterator is exhausted. done() appears
                # to be the step that advances a generator (every loop in this
                # module calls done() before next()), so probing the rest
                # would make them produce an element that is then discarded.
                if not done and i.done():
                    done = True
            if not done:
                yield tuple(i.next() for i in iters)
        for i in iters:
            i.destroy()
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def filter(f: Callable[[T], bool], x: Generator[T], T: type) -> Generator[T]:
    """Lazily yield the items of x for which the predicate f is true."""
    for item in x:
        if not f(item):
            continue
        yield item
|
|
|
|
|
|
|
|
def map(f, *args):
    """
    Lazily yield f(item) for each item of a single iterable, or
    f(*items) for tuples of items taken in lockstep from several iterables
    (via zip, so iteration stops at the shortest).
    """
    if staticlen(args) == 0:
        compile_error("map() expects at least one iterator")
    elif staticlen(args) == 1:
        for item in args[0]:
            yield f(item)
    else:
        for group in zip(*args):
            yield f(*group)
|
|
|
|
|
|
|
|
def enumerate(x, start: int = 0):
    """Yield (index, item) pairs for x, with indices counting up from start."""
    idx = start
    for item in x:
        yield (idx, item)
        idx += 1
|
|
|
|
|
2023-04-13 06:13:54 +08:00
|
|
|
def staticenumerate(tup):
    """
    Return a tuple of (index, element) pairs for the elements of tup, with
    indices counting from 0.
    """
    i = -1  # bumped to 0 for the first element by the walrus below
    return tuple(((i := i + 1), t) for t in tup)
    # NOTE(review): this bare `i` after the return is unreachable; presumably
    # it is kept on purpose (e.g. so the walrus target stays a variable of
    # this function) -- confirm before removing.
    i
|
|
|
|
|
2021-09-28 02:02:44 +08:00
|
|
|
def echo(x):
    """Print x on its own line and return it unchanged."""
    print(x)
    return x
|
|
|
|
|
|
|
|
def reversed(x):
    """
    Return an iterator over x from last element to first, using
    x.__reversed__() when available and indexing from len - 1 down to 0
    otherwise.
    """
    if hasattr(x, "__reversed__"):
        return x.__reversed__()
    else:
        for idx in range(x.__len__() - 1, -1, -1):
            yield x[idx]
|
2021-09-28 02:02:44 +08:00
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def round(x, n=0):
    """Round x to n digits after the decimal point (n defaults to 0)."""
    scale = float.__pow__(10.0, n)
    return float.__round__(x * scale) / scale
|
|
|
|
|
2023-01-17 23:21:59 +08:00
|
|
|
def _sum_start(x, start):
    """
    Return the initial accumulator for sum(x, start): an int `start` is
    converted to float when iterating x yields a Generator[float];
    otherwise `start` is returned unchanged. Presumably this exists so the
    running total in sum() has float type from the outset -- confirm.
    """
    if isinstance(x.__iter__(), Generator[float]) and isinstance(start, int):
        return float(start)
    else:
        return start
|
|
|
|
|
|
|
|
def sum(x, start=0):
    """
    Add up the items of x, beginning from start; bool items count as 1 or 0.
    """
    total = _sum_start(x, start)
    for item in x:
        # Plain `+` rather than `+=`, so an in-place __iadd__ is never used.
        if isinstance(item, bool):
            total = total + (1 if item else 0)
        else:
            total = total + item
    return total
|
2021-09-28 02:02:44 +08:00
|
|
|
|
|
|
|
def repr(x):
    """Return x.__repr__(), the developer-facing string form of x."""
    text = x.__repr__()
    return text
|
|
|
|
|
2022-01-24 14:47:43 +08:00
|
|
|
def _int_format(a: int, base: int, prefix: str = ""):
    """
    Format integer a in base 2, 8, 10 or 16 as a new string laid out as
    [ '-' ] prefix digits, using lowercase hex digits. Zero is rendered as
    prefix + "0". Used by bin(), oct() and hex().
    """
    assert base == 2 or base == 8 or base == 10 or base == 16
    chars = "0123456789abcdef-"  # digit alphabet; the last char is the sign

    # Count the digits. NOTE(review): this loop and the digit loop below rely
    # on `//` truncating toward zero for negative values (C-style division);
    # with flooring division b would stick at -1 -- confirm the numerics mode
    # this module is compiled with.
    b = a
    digits = 0
    while b != 0:
        digits += 1
        b //= base

    # One extra byte holds either the '-' (a < 0) or the lone '0' (a == 0).
    sz = digits + (1 if a <= 0 else 0) + len(prefix)
    p = Ptr[byte](sz)
    q = p

    if a < 0:
        q[0] = chars[-1].ptr[0]
        q += 1

    if prefix:
        str.memcpy(q, prefix.ptr, len(prefix))
        q += len(prefix)

    if digits != 0:
        # Fill the digits right-to-left, starting from the last byte.
        b = a
        q += digits - 1
        while b != 0:
            # abs() covers the negative remainders produced when a < 0.
            q[0] = chars.ptr[abs(b % base)]
            q += -1
            b //= base
    else:
        q[0] = chars.ptr[0]

    return str(p, sz)
|
|
|
|
|
|
|
|
def bin(n):
    """Return the binary form of integer-like n, prefixed with '0b'."""
    value = n.__index__()
    return _int_format(value, 2, "0b")
|
|
|
|
|
2021-10-09 02:25:11 +08:00
|
|
|
def oct(n):
    """Return the octal form of integer-like n, prefixed with '0o'."""
    value = n.__index__()
    return _int_format(value, 8, "0o")
|
|
|
|
|
2021-10-09 02:25:11 +08:00
|
|
|
def hex(n):
    """Return the lowercase hexadecimal form of integer-like n, prefixed with '0x'."""
    value = n.__index__()
    return _int_format(value, 16, "0x")
|
|
|
|
|
2023-04-13 06:13:54 +08:00
|
|
|
def pow(base: float, exp: float):
    """Return base raised to the floating-point power exp."""
    result = base ** exp
    return result
|
|
|
|
|
|
|
|
@overload
def pow(base: int, exp: int, mod: Optional[int] = None):
    """
    Return base ** exp for a non-negative int exponent using
    square-and-multiply. When mod is given, the base, every intermediate
    product and the result are reduced modulo mod.

    Raises ValueError for a negative exponent or mod == 0.
    """
    if exp < 0:
        raise ValueError("pow() negative int exponent not supported")

    if mod is not None:
        if mod == 0:
            raise ValueError("pow() 3rd argument cannot be 0")
        base %= mod

    acc = 1
    sq = base
    e = exp
    while e > 0:
        if e & 1:
            prod = acc * sq
            acc = prod % mod if mod is not None else prod
        sq_next = sq * sq
        sq = sq_next % mod if mod is not None else sq_next
        e >>= 1

    return acc % mod if mod is not None else acc
|
|
|
|
|
2021-09-28 02:02:44 +08:00
|
|
|
@extend
class int:
    def _from_str(s: str, base: int):
        """
        Parse s as an integer in `base`, following Python's int(s, base).

        `base` must be 0 or in 2..36, and surrounding whitespace is stripped.
        With base 0 the radix comes from an optional 0b/0o/0x prefix after an
        optional sign; otherwise it is decimal, and a leading '0' is allowed
        only when every digit is '0'. For bases 2, 8 and 16 a matching
        prefix, optionally preceded by a sign, is accepted and skipped.

        Raises ValueError for an invalid base or malformed input.
        """
        def parse_error(s: str, base: int):
            raise ValueError(
                f"invalid literal for int() with base {base}: {s.__repr__()}"
            )

        if base < 0 or base > 36 or base == 1:
            raise ValueError("int() base must be >= 2 and <= 36, or 0")

        s0 = s  # original text and base, reported in error messages
        base0 = base
        s = s.strip()
        n = len(s)
        negate = False

        if base == 0:
            # skip leading sign
            o = 0
            if n >= 1 and (s.ptr[0] == byte(43) or s.ptr[0] == byte(45)):
                o = 1

            # detect base from prefix
            base = 10
            if n >= o + 1 and s.ptr[o] == byte(48):  # '0'
                c = s.ptr[o + 1] if n >= o + 2 else byte(0)
                if c == byte(98) or c == byte(66):  # 'b'/'B'
                    base = 2
                elif c == byte(111) or c == byte(79):  # 'o'/'O'
                    base = 8
                elif c == byte(120) or c == byte(88):  # 'x'/'X'
                    base = 16
                else:
                    # No radix prefix: like Python, accept "0", "-00", ...
                    # but reject other leading zeros such as "012".
                    j = o + 1
                    while j < n and s.ptr[j] == byte(48):
                        j += 1
                    if j != n:
                        parse_error(s0, base0)

        if base == 2 or base == 8 or base == 16:
            if base == 2:
                C_LOWER = byte(98)  # 'b'
                C_UPPER = byte(66)  # 'B'
            elif base == 8:
                C_LOWER = byte(111)  # 'o'
                C_UPPER = byte(79)  # 'O'
            else:
                C_LOWER = byte(120)  # 'x'
                C_UPPER = byte(88)  # 'X'

            def check_digit(d: byte, base: int):
                if base == 2:
                    return d == byte(48) or d == byte(49)
                elif base == 8:
                    return byte(48) <= d <= byte(55)
                elif base == 16:
                    return ((byte(48) <= d <= byte(57)) or
                            (byte(97) <= d <= byte(102)) or
                            (byte(65) <= d <= byte(70)))
                return False

            # Offset of the first digit after an (optionally signed) radix
            # prefix, or 0 when there is no prefix to skip.
            start = 0
            if (n >= 4 and
                    (s.ptr[0] == byte(43) or s.ptr[0] == byte(45)) and
                    s.ptr[1] == byte(48) and
                    (s.ptr[2] == C_LOWER or s.ptr[2] == C_UPPER)):  # '+0b' etc.
                negate = (s.ptr[0] == byte(45))
                start = 3
            elif (n >= 3 and
                    s.ptr[0] == byte(48) and
                    (s.ptr[1] == C_LOWER or s.ptr[1] == C_UPPER)):  # '0b' etc.
                start = 2

            if start != 0:
                # The numeric parser below may accept whitespace, a sign, or
                # (strtol-style) another radix prefix at this point; Python
                # requires a digit right after the prefix, so check it here.
                if not check_digit(s.ptr[start], base):
                    parse_error(s0, base0)
                if (n - start >= 2 and s.ptr[start] == byte(48) and
                        (s.ptr[start + 1] == C_LOWER or s.ptr[start + 1] == C_UPPER)):
                    parse_error(s0, base0)  # doubled prefix, e.g. "0x0x5"
                s = str(s.ptr + start, n - start)

        end = cobj()
        result = _C.seq_int_from_str(s, __ptr__(end), i32(base))
        n = len(s)

        # Reject empty input and anything the parser did not consume.
        if n == 0 or end != s.ptr + n:
            parse_error(s0, base0)

        if negate:
            result = -result

        return result
|
2021-11-22 21:32:49 +08:00
|
|
|
|
2022-12-13 09:54:01 +08:00
|
|
|
@extend
class float:
    def _from_str(s: str) -> float:
        """
        Parse s as a float after stripping trailing whitespace. Raise
        ValueError if nothing is left or the parser stops before the end.
        """
        original = s
        trimmed = s.rstrip()
        length = len(trimmed)
        end = cobj()
        value = _C.seq_float_from_str(trimmed, __ptr__(end))
        consumed_all = end == trimmed.ptr + length
        if length != 0 and consumed_all:
            return value
        raise ValueError(f"could not convert string to float: {original.__repr__()}")
|
|
|
|
|
2023-12-06 04:48:34 +08:00
|
|
|
@extend
class complex:
    def _from_str(v: str) -> complex:
        """
        Parse v as a complex number. Accepted shapes are "x", "yj", "x+yj",
        "x-yj", "x+j", "x-j", "j", "+j" and "-j" (j or J), optionally wrapped
        in parentheses, with whitespace allowed around the value and inside
        the parentheses. Raises ValueError on malformed input.
        """
        def parse_error():
            raise ValueError("complex() arg is a malformed string")

        n = len(v)
        s = v.ptr  # raw bytes of v; `end` is compared against offsets from it
        x = 0.0  # real part
        y = 0.0  # imaginary part
        z = 0.0  # first number parsed; whether it is real or imaginary is decided below
        got_bracket = False
        end = cobj()  # receives the position where float parsing stopped
        i = 0  # current byte offset into s

        # skip leading whitespace
        while i < n and str._isspace(s[i]):
            i += 1

        # optional '(' followed by optional whitespace
        if i < n and s[i] == byte(40):  # '('
            got_bracket = True
            i += 1
            while i < n and str._isspace(s[i]):
                i += 1

        z = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))

        if end != s + i:
            # A leading number was parsed; move past it.
            i = end - s

            if i < n and (s[i] == byte(43) or s[i] == byte(45)):  # '+' '-'
                # "x+yj" / "x-yj" / "x+j" / "x-j": z is the real part.
                x = z
                y = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))

                if end != s + i:
                    i = end - s
                else:
                    # A bare sign before 'j' means an imaginary part of +/-1.
                    y = 1.0 if s[i] == byte(43) else -1.0
                    i += 1

                if not (i < n and (s[i] == byte(106) or s[i] == byte(74))):  # 'j' 'J'
                    parse_error()

                i += 1
            elif i < n and (s[i] == byte(106) or s[i] == byte(74)):  # 'j' 'J'
                # "yj": z is the imaginary part.
                i += 1
                y = z
            else:
                # "x": z is the real part.
                x = z
        else:
            # No leading number: only "j", "+j" or "-j" can follow.
            if i < n and (s[i] == byte(43) or s[i] == byte(45)):  # '+' '-'
                y = 1.0 if s[i] == byte(43) else -1.0
                i += 1
            else:
                y = 1.0

            if not (i < n and (s[i] == byte(106) or s[i] == byte(74))):  # 'j' 'J'
                parse_error()

            i += 1

        while i < n and str._isspace(s[i]):
            i += 1

        if got_bracket:
            # When the input ends here (i == n) this test does not fire; i then
            # becomes n + 1 and the final `i != n` check rejects the missing ')'.
            if i < n and s[i] != byte(41):  # ')'
                parse_error()
            i += 1
            while i < n and str._isspace(s[i]):
                i += 1

        # Any leftover characters make the string malformed.
        if i != n:
            parse_error()

        return complex(x, y)
|
|
|
|
|
2024-01-26 02:41:36 +08:00
|
|
|
@extend
class float32:
    def _from_str(s: str) -> float32:
        # Parse at double precision, then narrow to float32.
        wide = float._from_str(s)
        return float32(wide)
|
|
|
|
|
|
|
|
@extend
class float16:
    def _from_str(s: str) -> float16:
        # Parse at double precision, then narrow to float16.
        wide = float._from_str(s)
        return float16(wide)
|
|
|
|
|
|
|
|
@extend
class bfloat16:
    def _from_str(s: str) -> bfloat16:
        # Parse at double precision, then narrow to bfloat16.
        wide = float._from_str(s)
        return bfloat16(wide)
|
|
|
|
|
|
|
|
@extend
class complex64:
    def _from_str(s: str) -> complex64:
        # Parse as a double-precision complex, then narrow to complex64.
        wide = complex._from_str(s)
        return complex64(wide)
|
|
|
|
|
2022-01-19 13:19:52 +08:00
|
|
|
def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()):
    # Display a JIT result without a trailing newline:
    # - None prints nothing;
    # - under "jupyter", an object with _repr_mimebundle_ emits a MIME-tagged
    #   payload for the first MIME type of its bundle;
    # - otherwise its __repr__() is printed, or __str__() as a last resort.
    if isinstance(x, None):
        return
    if hasattr(x, "_repr_mimebundle_") and s == "jupyter":
        mimes = x._repr_mimebundle_(bundle)
        # TODO: pick appropriate mime
        chosen = next(mimes.keys())  # just pick first
        print(f"\x00\x00__codon/mime__\x00{chosen}\x00{mimes[chosen]}", end='')
    elif hasattr(x, "__repr__"):
        text = x.__repr__()
        print(text, end='')
    elif hasattr(x, "__str__"):
        text = x.__str__()
        print(text, end='')
|