# Copyright (C) 2022-2024 Exaloop Inc.

class object:
    """Base class providing a default constructor and a default repr."""

    def __init__(self):
        pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {self.__raw__()}>"

def id(x) -> int:
    """
    Return the integer value of x.__raw__() when x is a ByRef instance,
    and 0 for everything else.
    """
    if isinstance(x, ByRef):
        return int(x.__raw__())
    else:
        return 0

# Default output stream used by print() when no file is given.
_stdout = _C.seq_stdout()

def print(*args, sep: str = " ", end: str = "\n", file=_stdout, flush: bool = False):
    """
    Print args to the text stream file.

    Each argument is converted with str() and written in order, with sep
    written between consecutive arguments (skipped when sep is empty) and
    end written after the last one. file is either a raw cobj stream or an
    object exposing one as its `fp` attribute. The stream is flushed only
    when flush is true.
    """
    fp = cobj()
    if isinstance(file, cobj):
        fp = file
    else:
        fp = file.fp
    i = 0
    for a in args:
        # No separator before the first argument.
        if i and sep:
            _C.seq_print_full(sep, fp)
        _C.seq_print_full(str(a), fp)
        i += 1
    _C.seq_print_full(end, fp)
    if flush:
        _C.fflush(fp)

def input(prompt: str = ""):
    """
    Write prompt (without a newline) and read one line from standard input.

    A trailing '\\n' and then a trailing '\\r' are stripped from the result.
    Raises EOFError when getline does not return a positive length.
    """
    stdout = _C.seq_stdout()
    stderr = _C.seq_stderr()
    stdin = _C.seq_stdin()
    _C.fflush(stderr)
    _C.fflush(stdout)
    print(prompt, end="")
    buf = cobj()
    n = 0
    s = _C.getline(__ptr__(buf), __ptr__(n), stdin)
    if s > 0:
        if buf[s - 1] == byte(10):
            s -= 1  # skip trailing '\n'
        if s != 0 and buf[s - 1] == byte(13):
            s -= 1  # skip trailing '\r'
        # Copy the text out before releasing the buffer filled in by getline.
        ans = str(buf, s).__ptrcopy__()
        _C.free(buf)
        return ans
    else:
        _C.free(buf)
        raise EOFError("EOF when reading a line")

@extend
class __internal__:
    def print(*args):
        """Print args to the standard output stream and flush it."""
        print(*args, flush=True, file=_C.seq_stdout())

def min(*args, key=None, default=None):
    """
    Return the smallest item.

    With one argument, it is iterated and its smallest element is returned;
    if it is empty, default is returned when given, otherwise ValueError is
    raised. With two or more arguments, the smallest argument is returned.
    key, when given, is applied to items before comparing. On ties the
    earliest item wins. Calling with no arguments, or with default plus
    several arguments, is a compile-time error.
    """
    if staticlen(args) == 0:
        compile_error("min() expected at least 1 argument, got 0")
    elif staticlen(args) > 1 and default is not None:
        compile_error("min() 'default' argument only allowed for iterables")
    elif staticlen(args) == 1:
        x = args[0].__iter__()
        if not x.done():
            s = x.next()
            while not x.done():
                i = x.next()
                if key is None:
                    if i < s:
                        s = i
                else:
                    if key(i) < key(s):
                        s = i
            x.destroy()
            return s
        else:
            x.destroy()
        if default is None:
            raise ValueError("min() arg is an empty sequence")
        else:
            return default
    elif staticlen(args) == 2:
        a, b = args
        if key is None:
            return a if a <= b else b
        else:
            return a if key(a) <= key(b) else b
    else:
        m = args[0]
        for i in args[1:]:
            if key is None:
                if i < m:
                    m = i
            else:
                if key(i) < key(m):
                    m = i
        return m

def max(*args, key=None, default=None):
    """
    Return the largest item.

    With one argument, it is iterated and its largest element is returned;
    if it is empty, default is returned when given, otherwise ValueError is
    raised. With two or more arguments, the largest argument is returned.
    key, when given, is applied to items before comparing. On ties the
    earliest item wins. Calling with no arguments, or with default plus
    several arguments, is a compile-time error.
    """
    if staticlen(args) == 0:
        compile_error("max() expected at least 1 argument, got 0")
    elif staticlen(args) > 1 and default is not None:
        compile_error("max() 'default' argument only allowed for iterables")
    elif staticlen(args) == 1:
        x = args[0].__iter__()
        if not x.done():
            s = x.next()
            while not x.done():
                i = x.next()
                if key is None:
                    if i > s:
                        s = i
                else:
                    if key(i) > key(s):
                        s = i
            x.destroy()
            return s
        else:
            x.destroy()
        if default is None:
            raise ValueError("max() arg is an empty sequence")
        else:
            return default
    elif staticlen(args) == 2:
        a, b = args
        if key is None:
            return a if a >= b else b
        else:
            return a if key(a) >= key(b) else b
    else:
        m = args[0]
        for i in args[1:]:
            if key is None:
                if i > m:
                    m = i
            else:
                if key(i) > key(m):
                    m = i
        return m

def len(x) -> int:
    """
    Return the length of x
    """
    return x.__len__()

def iter(x):
    """
    Return an iterator for the given object
    """
    return x.__iter__()

def abs(x):
    """
    Return the absolute value of x
    """
    return x.__abs__()

def hash(x) -> int:
    """
    Return the hash of x, as computed by x.__hash__()
    """
    return x.__hash__()

def ord(s: str) -> int:
    """
    Return the integer value of the single character in s.

    Raises TypeError unless len(s) == 1. The value is read from the first
    byte of the string's buffer.
    """
    if len(s) != 1:
        raise TypeError(
            f"ord() expected a character, but string of length {len(s)} found"
        )
    return int(s.ptr[0])

def divmod(a, b):
    """
    Return a.__divmod__(b) when a defines it, otherwise (a // b, a % b)
    """
    if hasattr(a, "__divmod__"):
        return a.__divmod__(b)
    else:
        return (a // b, a % b)

def chr(i: int) -> str:
    """
    Return a one-character string whose single byte is byte(i)
    """
    p = cobj(1)
    p[0] = byte(i)
    return str(p, 1)

def next(g: Generator[T], default: Optional[T] = None, T: type) -> T:
    """
    Return the next item from g.

    When g is exhausted, return default if one was given, otherwise raise
    StopIteration.
    """
    if g.done():
        if default is not None:
            return default.__val__()
        else:
            raise StopIteration()
    return g.next()

def any(x: Generator[T], T: type) -> bool:
    """
    Returns True if any item in x is true,
    False otherwise
    """
    for a in x:
        if a:
            return True
    return False

def all(x: Generator[T], T: type) -> bool:
    """
    Returns True when all elements in x are true,
    False otherwise
    """
    for a in x:
        if not a:
            return False
    return True

def zip(*args):
    """
    Yield tuples that aggregate one element from each of the given
    iterables, stopping as soon as any of them is exhausted.
    With no arguments nothing is yielded.
    """
    if staticlen(args) == 0:
        yield from List[int]()
    else:
        iters = tuple(iter(i) for i in args)
        done = False
        while not done:
            # Check every iterator before pulling from any of them, so no
            # element is consumed for a tuple that cannot be completed.
            for i in iters:
                if i.done():
                    done = True
            if not done:
                yield tuple(i.next() for i in iters)
        for i in iters:
            i.destroy()

def filter(f: Callable[[T], bool], x: Generator[T], T: type) -> Generator[T]:
    """
    Yield every a from the iterable x for which f(a) is true
    """
    for a in x:
        if f(a):
            yield a

def map(f, *args):
    """
    Yield f applied to the items of the given iterables.

    With one iterable, f receives each item; with several, the iterables
    are zipped and f receives one item from each as separate arguments.
    Calling with no iterables is a compile-time error.
    """
    if staticlen(args) == 0:
        compile_error("map() expects at least one iterator")
    elif staticlen(args) == 1:
        for a in args[0]:
            yield f(a)
    else:
        for a in zip(*args):
            yield f(*a)

def enumerate(x, start: int = 0):
    """
    Yield (count, value) tuples for the values obtained from iterating
    over x, with count beginning at start (which defaults to 0)
    """
    i = start
    for a in x:
        yield (i, a)
        i += 1

def staticenumerate(tup):
    """
    Return a tuple of (index, element) pairs for the elements of tup,
    with indices starting at 0
    """
    i = -1
    return tuple(((i := i + 1), t) for t in tup)
    # NOTE(review): unreachable bare expression kept from the original;
    # presumably intentional for the compiler - confirm before removing.
    i

def echo(x):
    """
    Print and return argument
    """
    print x
    return x

def reversed(x):
    """
    Return an iterator that accesses x in the reverse order.

    Uses x.__reversed__() when available; otherwise indexes x from
    len(x) - 1 down to 0.
    """
    if hasattr(x, "__reversed__"):
        return x.__reversed__()
    else:
        i = x.__len__() - 1
        while i >= 0:
            yield x[i]
            i -= 1

def round(x, n=0):
    """
    Return the x rounded off to the given
    n digits after the decimal point.
    """
    # Scale by 10**n, round to an integral value, then scale back.
    nx = float.__pow__(10.0, n)
    return float.__round__(x * nx) / nx

def _sum_start(x, start):
    """
    Return the initial accumulator for sum(): an int start is converted to
    float when x iterates over floats; otherwise start is returned as is.
    """
    if isinstance(x.__iter__(), Generator[float]) and isinstance(start, int):
        return float(start)
    else:
        return start

def sum(x, start=0):
    """
    Return the sum of the items added together from x,
    beginning with start. Booleans are counted as 1 or 0.
    """
    s = _sum_start(x, start)

    for a in x:
        # don't use += to avoid calling iadd
        if isinstance(a, bool):
            s = s + (1 if a else 0)
        else:
            s = s + a

    return s

def repr(x):
    """Return the string representation of x"""
    return x.__repr__()

def _int_format(a: int, base: int, prefix: str = ""):
    """
    Format the integer a in the given base (2, 8, 10 or 16) as
    [-]<prefix><digits>, using lowercase letters for digits above 9.
    """
    assert base == 2 or base == 8 or base == 10 or base == 16
    # Digit alphabet; the final character is the minus sign.
    chars = "0123456789abcdef-"

    # Count how many digits a needs in this base (0 when a == 0).
    b = a
    digits = 0
    while b != 0:
        digits += 1
        b //= base

    # One extra byte for '-' when a < 0, or for the lone '0' when a == 0.
    sz = digits + (1 if a <= 0 else 0) + len(prefix)
    p = Ptr[byte](sz)
    q = p

    if a < 0:
        q[0] = chars[-1].ptr[0]
        q += 1

    if prefix:
        str.memcpy(q, prefix.ptr, len(prefix))
        q += len(prefix)

    if digits != 0:
        # Fill digits from least significant to most significant, walking
        # the write pointer backwards from the last digit position.
        b = a
        q += digits - 1
        while b != 0:
            q[0] = chars.ptr[abs(b % base)]
            q += -1
            b //= base
    else:
        q[0] = chars.ptr[0]

    return str(p, sz)

def bin(n):
    """Return n.__index__() formatted in base 2 with a "0b" prefix"""
    return _int_format(n.__index__(), 2, "0b")

def oct(n):
    """Return n.__index__() formatted in base 8 with a "0o" prefix"""
    return _int_format(n.__index__(), 8, "0o")

def hex(n):
    """Return n.__index__() formatted in base 16 with a "0x" prefix"""
    return _int_format(n.__index__(), 16, "0x")

def pow(base: float, exp: float):
    """Return base raised to the power exp"""
    return base ** exp

@overload
def pow(base: int, exp: int, mod: Optional[int] = None):
    """
    Return base raised to the power exp, reduced modulo mod when given.

    Uses square-and-multiply on the bits of exp. Raises ValueError for a
    negative exp or when mod is 0.
    """
    if exp < 0:
        raise ValueError("pow() negative int exponent not supported")

    if mod is not None:
        if mod == 0:
            raise ValueError("pow() 3rd argument cannot be 0")
        base %= mod

    result = 1
    while exp > 0:
        # Multiply in the current power of base when this bit of exp is set.
        if exp & 1:
            x = result * base
            result = x % mod if mod is not None else x
        y = base * base
        base = y % mod if mod is not None else y
        exp >>= 1
    return result % mod if mod is not None else result

@extend
class int:
    def _from_str(s: str, base: int):
        """
        Parse s as an integer in the given base.

        base must be 0 or in 2..36. With base 0 the base is taken from a
        "0b"/"0o"/"0x" prefix (either case, after an optional sign), else
        10. For bases 2, 8 and 16 a matching prefix is optional. Surrounding
        whitespace is ignored. Raises ValueError for an invalid base or when
        s is not entirely a valid literal.
        """
        def parse_error(s: str, base: int):
            raise ValueError(
                f"invalid literal for int() with base {base}: {s.__repr__()}"
            )

        if base < 0 or base > 36 or base == 1:
            raise ValueError("int() base must be >= 2 and <= 36, or 0")

        # Keep the caller's string and base for error messages.
        s0 = s
        base0 = base
        s = s.strip()
        n = len(s)
        negate = False

        if base == 0:
            # skip leading sign
            o = 0
            if n >= 1 and (s.ptr[0] == byte(43) or s.ptr[0] == byte(45)):
                o = 1

            # detect base from prefix
            if n >= o + 1 and s.ptr[o] == byte(48):  # '0'
                # NOTE(review): a bare "0" with base 0 reaches parse_error
                # here, whereas CPython accepts int("0", 0) - confirm intended.
                if n < o + 2:
                    parse_error(s0, base)

                if s.ptr[o + 1] == byte(98) or s.ptr[o + 1] == byte(66):  # 'b'/'B'
                    base = 2
                elif s.ptr[o + 1] == byte(111) or s.ptr[o + 1] == byte(79):  # 'o'/'O'
                    base = 8
                elif s.ptr[o + 1] == byte(120) or s.ptr[o + 1] == byte(88):  # 'x'/'X'
                    base = 16
                else:
                    parse_error(s0, base)
            else:
                base = 10

        if base == 2 or base == 8 or base == 16:
            if base == 2:
                C_LOWER = byte(98)  # 'b'
                C_UPPER = byte(66)  # 'B'
            elif base == 8:
                C_LOWER = byte(111)  # 'o'
                C_UPPER = byte(79)  # 'O'
            else:
                C_LOWER = byte(120)  # 'x'
                C_UPPER = byte(88)  # 'X'

            def check_digit(d: byte, base: int):
                if base == 2:
                    return d == byte(48) or d == byte(49)
                elif base == 8:
                    return byte(48) <= d <= byte(55)
                elif base == 16:
                    return ((byte(48) <= d <= byte(57)) or
                            (byte(97) <= d <= byte(102)) or
                            (byte(65) <= d <= byte(70)))
                return False

            # When a base prefix is present, the character right after it
            # must be a digit of that base. This keeps a second sign or other
            # junk after the prefix away from the runtime parser, which is
            # only handed the text following the prefix.
            if (n >= 4 and
                (s.ptr[0] == byte(43) or s.ptr[0] == byte(45)) and
                s.ptr[1] == byte(48) and
                (s.ptr[2] == C_LOWER or s.ptr[2] == C_UPPER)):  # '+0b' etc.
                if not check_digit(s.ptr[3], base):
                    parse_error(s0, base0)
                negate = (s.ptr[0] == byte(45))
                s = str(s.ptr + 3, n - 3)
            elif (n >= 3 and
                  s.ptr[0] == byte(48) and
                  (s.ptr[1] == C_LOWER or s.ptr[1] == C_UPPER)):  # '0b' etc.
                # The prefix occupies indices 0-1, so the first digit is at
                # index 2 (index 3 would be out of bounds when n == 3).
                if not check_digit(s.ptr[2], base):
                    parse_error(s0, base0)
                s = str(s.ptr + 2, n - 2)

        end = cobj()
        result = _C.seq_int_from_str(s, __ptr__(end), i32(base))
        n = len(s)

        # Accept only if the parser's end pointer is one past the last
        # character of a non-empty string.
        if n == 0 or end != s.ptr + n:
            parse_error(s0, base0)

        if negate:
            result = -result

        return result

@extend
class float:
    def _from_str(s: str) -> float:
        """
        Parse s as a float, ignoring trailing whitespace.

        Raises ValueError when s is empty after stripping or when the
        runtime parser's end pointer is not at the end of the text.
        """
        s0 = s
        s = s.rstrip()
        n = len(s)
        end = cobj()
        result = _C.seq_float_from_str(s, __ptr__(end))

        if n == 0 or end != s.ptr + n:
            raise ValueError(f"could not convert string to float: {s0.__repr__()}")

        return result

@extend
class complex:
    def _from_str(v: str) -> complex:
        """
        Parse v as a complex number.

        Accepts a real part, an imaginary part suffixed with 'j'/'J', or
        both joined by a sign, optionally wrapped in parentheses and
        surrounded by whitespace. Raises ValueError for malformed input.
        """
        def parse_error():
            raise ValueError("complex() arg is a malformed string")

        n = len(v)
        s = v.ptr
        x = 0.0  # real part
        y = 0.0  # imaginary part
        z = 0.0  # first number parsed, before its role is known
        got_bracket = False
        end = cobj()
        i = 0

        # skip leading whitespace
        while i < n and str._isspace(s[i]):
            i += 1

        if i < n and s[i] == byte(40):  # '('
            got_bracket = True
            i += 1
            while i < n and str._isspace(s[i]):
                i += 1

        z = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))

        if end != s + i:
            # A number was parsed: "<float>", "<float>j" or "<float><sign>...j".
            i = end - s

            if i < n and (s[i] == byte(43) or s[i] == byte(45)):  # '+' '-'
                # "<float><sign><float>j" or "<float><sign>j"
                x = z
                y = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))
                if end != s + i:
                    i = end - s
                else:
                    # A bare sign means an imaginary part of +1 or -1.
                    y = 1.0 if s[i] == byte(43) else -1.0
                    i += 1

                if not (i < n and (s[i] == byte(106) or s[i] == byte(74))):  # 'j' 'J'
                    parse_error()

                i += 1
            elif i < n and (s[i] == byte(106) or s[i] == byte(74)):  # 'j' 'J'
                # "<float>j": purely imaginary
                i += 1
                y = z
            else:
                # "<float>": purely real
                x = z
        else:
            # No number parsed: only "j", "+j" or "-j" are valid here.
            if i < n and (s[i] == byte(43) or s[i] == byte(45)):  # '+' '-'
                y = 1.0 if s[i] == byte(43) else -1.0
                i += 1
            else:
                y = 1.0

            if not (i < n and (s[i] == byte(106) or s[i] == byte(74))):  # 'j' 'J'
                parse_error()

            i += 1

        while i < n and str._isspace(s[i]):
            i += 1

        if got_bracket:
            # A missing ')' at end of input is caught by the final
            # i != n check, since i is advanced past n below.
            if i < n and s[i] != byte(41):  # ')'
                parse_error()
            i += 1
            while i < n and str._isspace(s[i]):
                i += 1

        # Anything left over is trailing garbage.
        if i != n:
            parse_error()

        return complex(x, y)

@extend
class float32:
    def _from_str(s: str) -> float32:
        """Parse s with float._from_str and convert the result to float32"""
        return float32(float._from_str(s))

@extend
class float16:
    def _from_str(s: str) -> float16:
        """Parse s with float._from_str and convert the result to float16"""
        return float16(float._from_str(s))

@extend
class bfloat16:
    def _from_str(s: str) -> bfloat16:
        """Parse s with float._from_str and convert the result to bfloat16"""
        return bfloat16(float._from_str(s))

@extend
class complex64:
    def _from_str(s: str) -> complex64:
        """Parse s with complex._from_str and convert the result to complex64"""
        return complex64(complex._from_str(s))

def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()):
    """
    Print a display form of x without a trailing newline.

    Nothing is printed for None. When s is "jupyter" and x provides
    _repr_mimebundle_, the first MIME entry of the bundle is printed in a
    NUL-delimited "__codon/mime__" frame; otherwise x.__repr__() is used,
    falling back to x.__str__().
    """
    if isinstance(x, None):
        return
    if hasattr(x, "_repr_mimebundle_") and s == "jupyter":
        d = x._repr_mimebundle_(bundle)
        # TODO: pick appropriate mime
        mime = next(d.keys()) # just pick first
        print(f"\x00\x00__codon/mime__\x00{mime}\x00{d[mime]}", end='')
    elif hasattr(x, "__repr__"):
        print(x.__repr__(), end='')
    elif hasattr(x, "__str__"):
        print(x.__str__(), end='')