# Copyright (C) 2022-2024 Exaloop Inc.

from internal.gc import (
    alloc, alloc_atomic, alloc_atomic_uncollectable,
    free, sizeof, register_finalizer
)
from internal.static import vars_types, tuple_type, vars as _vars, fn_overloads, fn_can_call


def vars(obj, with_index: Static[bool] = False):
    """Thin wrapper that forwards to `internal.static.vars`."""
    return _vars(obj, with_index)


# Global table of per-class vtables; allocated by `class_init_vtables`.
__vtables__ = Ptr[Ptr[cobj]]()
__vtable_size__ = 0


@extend
class __internal__:
    def yield_final(val):
        pass

    def yield_in_no_suspend(T: type) -> T:
        pass

    # Returns the object pointer itself as a raw byte pointer.
    @pure
    @derives
    @llvm
    def class_raw_ptr(obj) -> Ptr[byte]:
        ret ptr %obj

    # Returns element 0 of the {ptr, ptr} pair representing an RTTI object
    # (the pair is built as `(obj, rtti)` in `class_alloc`).
    @pure
    @derives
    @llvm
    def class_raw_rtti_ptr(obj) -> Ptr[byte]:
        %0 = extractvalue {ptr, ptr} %obj, 0
        ret ptr %0

    # Returns element 1 of the {ptr, ptr} pair representing an RTTI object.
    @pure
    @derives
    @llvm
    def class_raw_rtti_rtti(obj: T, T: type) -> Ptr[byte]:
        %0 = extractvalue {ptr, ptr} %obj, 1
        ret ptr %0

    def class_alloc(T: type) -> T:
        """Allocates a new reference (class) type"""
        sz = sizeof(tuple(T))
        # Use `alloc_atomic` when the type reports `__contents_atomic__`.
        obj = alloc_atomic(sz) if T.__contents_atomic__ else alloc(sz)
        if __has_rtti__(T):
            register_finalizer(obj)
            rtti = RTTI(T.__id__).__raw__()
            return __internal__.to_class_ptr_rtti((obj, rtti), T)
        else:
            register_finalizer(obj)
            return __internal__.to_class_ptr(obj, T)

    def class_ctr(T: type, *args, **kwargs) -> T:
        """Shorthand for `t = T.__new__(); t.__init__(*args, **kwargs); t`"""
        return T(*args, **kwargs)

    def class_init_vtables():
        """
        Create a global vtable.
        """
        global __vtables__
        sz = __vtable_size__ + 1
        p = alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj]))
        __vtables__ = Ptr[Ptr[cobj]](p)
        __internal__.class_populate_vtables()

    def _print(a):
        """Print `a` via `seq_print`, preferring `__repr__` over `__str__`."""
        from C import seq_print(str)
        if hasattr(a, '__repr__'):
            seq_print(a.__repr__())
        else:
            seq_print(a.__str__())

    def class_populate_vtables() -> None:
        """
        Populate content of vtables. Compiler generated.
        Corresponds to:
            for each realized class C:
                __internal__.class_set_rtti_vtable(<id of C>, <sz of C's vtable> + 1, T=C)
                for each fn F in C's vtable:
                    __internal__.class_set_rtti_vtable_fn(
                        <id of C>, <fid of F>, Function(<f>).__raw__(), T=C
                    )
        """
        pass

    def class_get_rtti_vtable(obj) -> Ptr[cobj]:
        """Look up the vtable for `obj` through the id stored in its RTTI."""
        if not __has_rtti__(type(obj)):
            compile_error("class is not polymorphic")
        rtti = __internal__.class_raw_rtti_rtti(obj)
        id = __internal__.to_class_ptr(rtti, RTTI).id
        return __vtables__[id]

    def class_set_rtti_vtable(id: int, sz: int, T: type):
        """Allocate the vtable for class id `id` and record its typeinfo."""
        if not __has_rtti__(T):
            compile_error("class is not polymorphic")
        p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj))
        __vtables__[id] = Ptr[cobj](p)
        # Slot 0 holds the typeinfo (see `class_set_typeinfo`).
        __internal__.class_set_typeinfo(__vtables__[id], id)

    def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type):
        """Store function pointer `f` in slot `fid` of the vtable for `id`."""
        if not __has_rtti__(T):
            compile_error("class is not polymorphic")
        __vtables__[id][fid] = f

    def class_set_typeinfo(p: Ptr[cobj], typeinfo: T, T: type) -> None:
        """Box `typeinfo` in a fresh one-element buffer and store it at `p[0]`."""
        i = Ptr[T](1)
        i[0] = typeinfo
        p[0] = i.as_byte()

    def class_get_typeinfo(p) -> int:
        """Read the integer stored behind slot 0 of the table `p` points at."""
        c = Ptr[Ptr[cobj]](p.__raw__())
        vt = c[0]
        return Ptr[int](vt[0])[0]

    @inline
    def class_base_derived_dist(B: type, D: type) -> int:
        """Calculates the byte distance of base class B and derived class D.
        Compiler generated."""
        return 0

    @inline
    def class_base_to_derived(b: B, B: type, D: type) -> D:
        """Reinterpret base object `b` as `D` by subtracting the B-to-D distance."""
        if not (__has_rtti__(D) and __has_rtti__(B)):
            compile_error("classes are not polymorphic")
        off = __internal__.class_base_derived_dist(B, D)
        ptr = __internal__.class_raw_rtti_ptr(b) - off
        pr = __internal__.class_raw_rtti_rtti(b)
        return __internal__.to_class_ptr_rtti((ptr, pr), D)

    def class_copy(obj: T, T: type) -> T:
        """Allocate a new `T` and memcpy the contents of `obj` into it."""
        p = __internal__.class_alloc(T)
        str.memcpy(p.__raw__(), obj.__raw__(), sizeof(tuple(T)))
        return p

    def class_super(obj, B: type, change_rtti: Static[int] = 0) -> B:
        """View `obj` as its base `B`; optionally replace the RTTI with B's."""
        D = type(obj)
        if not __has_rtti__(D):  # static inheritance
            return __internal__.to_class_ptr(obj.__raw__(), B)
        else:
            if not __has_rtti__(B):
                compile_error("classes are not polymorphic")
            off = __internal__.class_base_derived_dist(B, D)
            ptr = __internal__.class_raw_rtti_ptr(obj) + off
            pr = __internal__.class_raw_rtti_rtti(obj)
            if change_rtti:
                pr = RTTI(B.__id__).__raw__()
            return __internal__.to_class_ptr_rtti((ptr, pr), B)

    # Unions

    def get_union_tag(u, tag: Static[int]):  # compiler-generated
        pass

    # Builds a union value of type U whose field 0 (the tag) is `tag`.
    @llvm
    def union_set_tag(tag: byte, U: type) -> U:
        %0 = insertvalue {=U} undef, i8 %tag, 0
        ret {=U} %0

    # Returns a pointer to field 1 (the payload) of the union, typed as Ptr[T].
    @llvm
    def union_get_data_ptr(ptr: Ptr[U], U: type, T: type) -> Ptr[T]:
        %0 = getelementptr inbounds {=U}, ptr %ptr, i64 0, i32 1
        ret ptr %0

    # Returns field 0 (the tag) of the union.
    @llvm
    def union_get_tag(u: U, U: type) -> byte:
        %0 = extractvalue {=U} %u, 0
        ret i8 %0

    def union_get_data(u, T: type) -> T:
        """Read the union payload as a `T`."""
        return __internal__.union_get_data_ptr(__ptr__(u), T=T)[0]

    def union_make(tag: int, value, U: type) -> U:
        """Create a `U` with the given tag and store `value` as its payload."""
        u = __internal__.union_set_tag(byte(tag), U)
        __internal__.union_get_data_ptr(__ptr__(u), T=type(value))[0] = value
        return u

    def new_union(value, U: type) -> U:
        """Construct a `U` from `value`, matching it against U's member types."""
        for tag, T in vars_types(U, with_index=True):
            if isinstance(value, T):
                return __internal__.union_make(tag, value, U)
            if isinstance(value, Union[T]):
                return __internal__.union_make(tag, __internal__.get_union(value, T), U)
        # TODO: make this static!
        raise TypeError("invalid union constructor")

    def get_union(union, T: type) -> T:
        """Extract the payload as `T`; raises TypeError if the tag does not match."""
        for tag, TU in vars_types(union, with_index=True):
            if isinstance(TU, T):
                if __internal__.union_get_tag(union) == tag:
                    return __internal__.union_get_data(union, TU)
        raise TypeError(f"invalid union getter for type '{T.__class__.__name__}'")

    def _union_member_helper(union, member: Static[str]) -> Union:
        for tag, T in vars_types(union, with_index=True):
            if hasattr(T, member):
                if __internal__.union_get_tag(union) == tag:
                    return getattr(__internal__.union_get_data(union, T), member)
        raise TypeError(f"invalid union call '{member}'")

    def union_member(union, member: Static[str]):
        """Get attribute `member` from the active union alternative."""
        t = __internal__._union_member_helper(union, member)
        # A single-alternative result is unwrapped to its only member.
        if staticlen(t) == 1:
            return __internal__.get_union_tag(t, 0)
        else:
            return t

    def _union_call_helper(union, args, kwargs) -> Union:
        for tag, T in vars_types(union, with_index=True):
            if fn_can_call(T, *args, **kwargs):
                if __internal__.union_get_tag(union) == tag:
                    return __internal__.union_get_data(union, T)(*args, **kwargs)
            elif hasattr(T, '__call__'):
                if fn_can_call(T.__call__, *args, **kwargs):
                    if __internal__.union_get_tag(union) == tag:
                        return __internal__.union_get_data(union, T).__call__(*args, **kwargs)
        raise TypeError("cannot call union " + union.__class__.__name__)

    def union_call(union, args, kwargs):
        """Call the active union alternative with `args` / `kwargs`."""
        t = __internal__._union_call_helper(union, args, kwargs)
        # A single-alternative result is unwrapped to its only member.
        if staticlen(t) == 1:
            return __internal__.get_union_tag(t, 0)
        else:
            return t

    def union_str(union):
        """Stringify the active alternative via `__str__`, else `__repr__`;
        returns '' when no alternative matches."""
        for tag, T in vars_types(union, with_index=True):
            if hasattr(T, '__str__'):
                if __internal__.union_get_tag(union) == tag:
                    return __internal__.union_get_data(union, T).__str__()
            elif hasattr(T, '__repr__'):
                if __internal__.union_get_tag(union) == tag:
                    return __internal__.union_get_data(union, T).__repr__()
        return ''

    # Tuples

    def namedkeys(N: Static[int]):
        pass

    # Spills the tuple to a stack slot and loads the element at index `idx`,
    # addressing the slot as an array of E. No bounds check is done here.
    @pure
    @derives
    @llvm
    def _tuple_getitem_llvm(t: T, idx: int, T: type, E: type) -> E:
        %x = alloca {=T}
        store {=T} %t, ptr %x
        %p = getelementptr {=E}, ptr %x, i64 %idx
        %v = load {=E}, ptr %p
        ret {=E} %v

    def tuple_fix_index(idx: int, len: int) -> int:
        """Normalize a possibly negative index; raises IndexError if out of range."""
        if idx < 0:
            idx += len
        if idx < 0 or idx >= len:
            raise IndexError("tuple index out of range")
        return idx

    def tuple_getitem(t: T, idx: int, T: type, E: type) -> E:
        """Bounds-checked runtime indexing of tuple `t` with element type `E`."""
        return __internal__._tuple_getitem_llvm(
            t, __internal__.tuple_fix_index(idx, staticlen(t)), T, E
        )

    # Returns `t` unchanged, typed as U.
    @llvm
    def tuple_cast_unsafe(t, U: type) -> U:
        ret {=U} %t

    # Returns the raw pointer `p`, typed as T.
    @pure
    @derives
    @llvm
    def fn_new(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    # Returns `fn` as a raw byte pointer.
    @pure
    @derives
    @llvm
    def fn_raw(fn: T, T: type) -> Ptr[byte]:
        ret ptr %fn

    # Sign-extends an F-bit integer to T bits.
    @pure
    @llvm
    def int_sext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = sext i{=F} %what to i{=T}
        ret i{=T} %0

    # Zero-extends an F-bit integer to T bits.
    @pure
    @llvm
    def int_zext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = zext i{=F} %what to i{=T}
        ret i{=T} %0

    # Truncates an F-bit integer to T bits.
    @pure
    @llvm
    def int_trunc(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = trunc i{=F} %what to i{=T}
        ret i{=T} %0

    def seq_assert(file: str, line: int, msg: str) -> AssertionError:
        """Build (but do not raise) the AssertionError for a failed assert."""
        s = f": {msg}" if msg else ""
        s = f"Assert failed{s} ({file}:{line.__repr__()})"
        return AssertionError(s)

    def seq_assert_test(file: str, line: int, msg: str):
        """Print a test-failure message via `seq_print`; does not raise."""
        from C import seq_print(str)
        s = f": {msg}" if msg else ""
        s = f"\033[1;31mTEST FAILED:\033[0m {file} (line {line}){s}\n"
        seq_print(s)

    def check_errno(prefix: str):
        """Raise OSError(prefix + msg) when `seq_check_errno` returns a message."""
        @pure
        @C
        def seq_check_errno() -> str:
            pass

        msg = seq_check_errno()
        if msg:
            raise OSError(prefix + msg)

    # Empty Optional for a reference type: a null pointer.
    @pure
    @llvm
    def opt_ref_new(T: type) -> Optional[T]:
        ret ptr null

    def opt_ref_new_rtti(T: type) -> Optional[T]:
        """Empty Optional for an RTTI type: null object pointer plus an RTTI
        record whose `id` is set to `T.__id__`."""
        obj = Ptr[byte]()
        rsz = sizeof(tuple(T))
        rtti = alloc_atomic(rsz) if RTTI.__contents_atomic__ else alloc(rsz)
        __internal__.to_class_ptr(rtti, RTTI).id = T.__id__
        return __internal__.opt_ref_new_arg(__internal__.to_class_ptr_rtti((obj, rtti), T))

    # Wraps a by-value `what` as { i1 true, what }.
    @pure
    @derives
    @llvm
    def opt_tuple_new_arg(what: T, T: type) -> Optional[T]:
        %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1
        ret { i1, {=T} } %0

    # Wraps a reference `what`: the pointer itself is the Optional.
    @pure
    @derives
    @llvm
    def opt_ref_new_arg(what: T, T: type) -> Optional[T]:
        ret ptr %what

    # Wraps an RTTI reference `what`: the { ptr, ptr } pair is the Optional.
    @pure
    @derives
    @llvm
    def opt_ref_new_arg_rtti(what: T, T: type) -> Optional[T]:
        ret { ptr, ptr } %what

    # Truthiness of a by-value Optional: its i1 flag (field 0).
    @pure
    @llvm
    def opt_tuple_bool(what: Optional[T], T: type) -> bool:
        %0 = extractvalue { i1, {=T} } %what, 0
        %1 = zext i1 %0 to i8
        ret i8 %1

    # Truthiness of a reference Optional: pointer is non-null.
    @pure
    @llvm
    def opt_ref_bool(what: Optional[T], T: type) -> bool:
        %0 = icmp ne ptr %what, null
        %1 = zext i1 %0 to i8
        ret i8 %1

    @pure
    def opt_ref_bool_rtti(what: Optional[T], T: type) -> bool:
        """Truthiness of an RTTI Optional: its object pointer is non-null."""
        # `class_raw_rtti_ptr` requires the object; it was previously called
        # with no argument, leaving `what` unused.
        return __internal__.class_raw_rtti_ptr(what) != cobj()

    # Unwraps a by-value Optional: returns field 1 without checking the flag.
    @pure
    @derives
    @llvm
    def opt_tuple_invert(what: Optional[T], T: type) -> T:
        %0 = extractvalue { i1, {=T} } %what, 1
        ret {=T} %0

    # Unwraps a reference Optional: returns the pointer as-is.
    @pure
    @derives
    @llvm
    def opt_ref_invert(what: Optional[T], T: type) -> T:
        ret ptr %what

    # Unwraps an RTTI reference Optional: returns the pair as-is.
    @pure
    @derives
    @llvm
    def opt_ref_invert_rtti(what: Optional[T], T: type) -> T:
        ret { ptr, ptr } %what

    # Returns the raw pointer `p`, typed as class T.
    @pure
    @derives
    @llvm
    def to_class_ptr(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    # Returns the (object, rtti) pointer pair `p`, typed as class T.
    @pure
    @derives
    @llvm
    def to_class_ptr_rtti(p: Tuple[Ptr[byte], Ptr[byte]], T: type) -> T:
        ret { ptr, ptr } %p

    def _tuple_offsetof(x, field: Static[int]) -> int:
        """Byte offset of field number `field` within the layout of `type(x)`."""
        # Computed as (address of field) - (address of struct) on a stack slot.
        @pure
        @llvm
        def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
            %a = alloca {=T}
            %b = getelementptr inbounds {=T}, ptr %a, i64 0, i32 {=idx}
            %base = ptrtoint ptr %a to i64
            %elem = ptrtoint ptr %b to i64
            %offset = sub i64 %elem, %base
            ret i64 %offset

        return _llvm_offsetof(type(x), field, type(x[field]))

    def raw_type_str(p: Ptr[byte], name: str) -> str:
        """Format `p` as '<[name] at [pointer repr]>'."""
        pstr = p.__repr__()
        # '<[name] at [pstr]>'
        total = 1 + name.len + 4 + pstr.len + 1
        buf = Ptr[byte](total)
        where = 0
        buf[where] = byte(60)  # '<'
        where += 1
        str.memcpy(buf + where, name.ptr, name.len)
        where += name.len
        buf[where] = byte(32)  # ' '
        where += 1
        buf[where] = byte(97)  # 'a'
        where += 1
        buf[where] = byte(116)  # 't'
        where += 1
        buf[where] = byte(32)  # ' '
        where += 1
        str.memcpy(buf + where, pstr.ptr, pstr.len)
        where += pstr.len
        buf[where] = byte(62)  # '>'
        # The pointer repr has been copied into `buf`; release its buffer.
        free(pstr.ptr)
        return str(buf, total)

    def tuple_str(strs: Ptr[str], names: Ptr[str], n: int) -> str:
        """Join `n` element strings as '(a, b)' or, for non-empty names,
        '(name: a, ...)'."""
        # special case of 1-element plain tuple: format as "(x,)"
        if n == 1 and names[0].len == 0:
            total = strs[0].len + 3
            buf = Ptr[byte](total)
            buf[0] = byte(40)  # '('
            str.memcpy(buf + 1, strs[0].ptr, strs[0].len)
            buf[total - 2] = byte(44)  # ','
            buf[total - 1] = byte(41)  # ')'
            return str(buf, total)

        # First pass: compute the exact output length.
        total = 2  # one for each of '(' and ')'
        i = 0
        while i < n:
            total += strs[i].len
            if names[i].len:
                total += names[i].len + 2  # extra : and space
            if i < n - 1:
                total += 2  # ", "
            i += 1

        # Second pass: fill the buffer.
        buf = Ptr[byte](total)
        where = 0
        buf[where] = byte(40)  # '('
        where += 1
        i = 0
        while i < n:
            s = names[i]
            l = s.len
            if l:
                str.memcpy(buf + where, s.ptr, l)
                where += l
                buf[where] = byte(58)  # ':'
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            s = strs[i]
            l = s.len
            str.memcpy(buf + where, s.ptr, l)
            where += l
            if i < n - 1:
                buf[where] = byte(44)  # ','
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            i += 1
        buf[where] = byte(41)  # ')'
        return str(buf, total)

    def undef(v, s):
        """Raise NameError for name `s` when `v` is falsy."""
        if not v:
            raise NameError(f"name '{s}' is not defined")

    @__hidden__
    def set_header(e, func, file, line, col, cause):
        """Fill in the location fields of exception `e` (and `cause` when it is
        not None), then return `e`."""
        if not isinstance(e, BaseException):
            compile_error("exceptions must derive from BaseException")
        e.func = func
        e.file = file
        e.line = line
        e.col = col
        if cause is not None:
            e.cause = cause
        return e

    def kwargs_get(kw, key: Static[str], default):
        """Return `kw.<key>` when present, otherwise `default`."""
        if hasattr(kw, key):
            return getattr(kw, key)
        else:
            return default


@extend
class __magic__:
    # always present
    def tuplesize(T: type) -> int:
        return (t() for t in vars_types(T)).__elemsize__

    # @dataclass parameter: init=True for tuples;
    # always present for reference types only
    def new(T: type) -> T:
        """Create a new reference (class) type"""
        return __internal__.class_alloc(T)

    # init is compiler-generated when init=True for reference types
    # def init(self, a1, ..., aN): ...

    # always present for reference types only
    def raw(obj) -> Ptr[byte]:
        if __has_rtti__(type(obj)):
            return __internal__.class_raw_rtti_ptr(obj)
        else:
            return __internal__.class_raw_ptr(obj)

    # always present for reference types only
    def dict(slf) -> List[str]:
        d = List[str](staticlen(slf))
        for k, _ in vars(slf):
            d.append(k)
        return d

    # always present for tuple types only
    def len(slf) -> int:
        return staticlen(slf)

    # always present for tuple types only
    def add(slf, obj):
        if not isinstance(obj, Tuple):
            compile_error("can only concatenate tuple to tuple")
        return (*slf, *obj)

    # always present for tuple types only
    def mul(slf, i: Static[int]):
        # Static recursion: i copies of slf are concatenated at compile time.
        if i < 1:
            return ()
        elif i == 1:
            return slf
        else:
            return (*(__magic__.mul(slf, i - 1)), *slf)

    # always present for tuples
    def contains(slf, what) -> bool:
        for _, v in vars(slf):
            # Only compare against members of the same type as `what`.
            if isinstance(what, type(v)):
                if what == v:
                    return True
        return False

    # @dataclass parameter: container=True
    def getitem(slf, index: int):
        if staticlen(slf) == 0:
            __internal__.tuple_fix_index(index, 0)  # raise exception
        else:
            return __internal__.tuple_getitem(slf, index, type(slf), tuple_type(slf, 0))

    # @dataclass parameter: container=True
    def iter(slf):
        if staticlen(slf) == 0:
            if int(0): yield 0  # set generator type without yielding anything
        for _, v in vars(slf):
            yield v

    # @dataclass parameter: order=True or eq=True
    def eq(slf, obj) -> bool:
        for k, v in vars(slf):
            if not (v == getattr(obj, k)):
                return False
        return True

    # @dataclass parameter: order=True or eq=True
    def ne(slf, obj) -> bool:
        return not (slf == obj)

    # @dataclass parameter: order=True
    def lt(slf, obj) -> bool:
        # Field-by-field comparison: the first differing field decides.
        for k, v in vars(slf):
            z = getattr(obj, k)
            if v < z:
                return True
            if not (v == z):
                return False
        return False

    # @dataclass parameter: order=True
    def le(slf, obj) -> bool:
        for k, v in vars(slf):
            z = getattr(obj, k)
            if v < z:
                return True
            if not (v == z):
                return False
        return True

    # @dataclass parameter: order=True
    def gt(slf, obj) -> bool:
        for k, v in vars(slf):
            z = getattr(obj, k)
            if z < v:
                return True
            if not (v == z):
                return False
        return False

    # @dataclass parameter: order=True
    def ge(slf, obj) -> bool:
        for k, v in vars(slf):
            z = getattr(obj, k)
            if z < v:
                return True
            if not (v == z):
                return False
        return True

    # @dataclass parameter: hash=True
    def hash(slf) -> int:
        # Folds each field's hash into a running seed.
        seed = 0
        for _, v in vars(slf):
            seed = seed ^ ((v.__hash__() + 2_654_435_769) + ((seed << 6) + (seed >> 2)))
        return seed

    # @dataclass parameter: pickle=True
    def pickle(slf, dest: Ptr[byte]) -> None:
        for _, v in vars(slf):
            v.__pickle__(dest)

    # @dataclass parameter: pickle=True
    def unpickle(src: Ptr[byte], T: type) -> T:
        if isinstance(T, ByVal):
            return __internal__.tuple_cast_unsafe(tuple(type(t).__unpickle__(src) for t in vars_types(T)), T)
        else:
            obj = T.__new__()
            for k, v in vars(obj):
                setattr(obj, k, type(v).__unpickle__(src))
            return obj

    # @dataclass parameter: python=True
    def to_py(slf) -> Ptr[byte]:
        o = pyobj._tuple_new(staticlen(slf))
        for i, _, v in vars(slf, with_index=True):
            pyobj._tuple_set(o, i, v.__to_py__())
        return o

    # @dataclass parameter: python=True
    def from_py(src: Ptr[byte], T: type) -> T:
        if isinstance(T, ByVal):
            return __internal__.tuple_cast_unsafe(tuple(
                type(t).__from_py__(pyobj._tuple_get(src, i))
                for i, t in vars_types(T, with_index=True)
            ), T)
        else:
            obj = T.__new__()
            for i, k, v in vars(obj, with_index=True):
                setattr(obj, k, type(v).__from_py__(pyobj._tuple_get(src, i)))
            return obj

    # @dataclass parameter: gpu=True
    def to_gpu(slf, cache):
        return __internal__.class_to_gpu(slf, cache)

    # @dataclass parameter: gpu=True
    def from_gpu(slf: T, other: T, T: type):
        __internal__.class_from_gpu(slf, other)

    # @dataclass parameter: gpu=True
    def from_gpu_new(other: T, T: type) -> T:
        return __internal__.class_from_gpu_new(other)

    # @dataclass parameter: repr=True
    def repr(slf) -> str:
        if staticlen(slf) == 0:
            return "()"
        a = __array__[str](staticlen(slf))
        n = __array__[str](staticlen(slf))
        for i, k, v in vars(slf, with_index=True):
            a[i] = v.__repr__()
            # Plain tuples print without field names.
            if isinstance(slf, Tuple):
                n[i] = ""
            else:
                n[i] = k
        return __internal__.tuple_str(a.ptr, n.ptr, staticlen(slf))

    # @dataclass parameter: repr=True
    def str(slf) -> str:
        return slf.__repr__()


@extend
class Function:
    @pure
    @llvm
    def __new__(what: Ptr[byte]) -> Function[T, TR]:
        ret ptr %what

    def __new__(what: Function[T, TR]) -> Function[T, TR]:
        return what

    @pure
    @llvm
    def __raw__(self) -> Ptr[byte]:
        ret ptr %self

    def __repr__(self) -> str:
        return __internal__.raw_type_str(self.__raw__(), "function")

    @llvm
    def __call_internal__(self: Function[T, TR], args: T) -> TR:
        noop  # compiler will populate this one

    def __call__(self, *args) -> TR:
        return Function.__call_internal__(self, args)


@tuple
class PyObject:
    refcnt: int
    pytype: Ptr[byte]


@tuple
class PyWrapper[T]:
    head: PyObject
    data: T


@extend
class RTTI:
    def __new__() -> RTTI:
        return __magic__.new(RTTI)

    def __init__(self, i: int):
        self.id = i

    def __raw__(self):
        return __internal__.class_raw_ptr(self)


@extend
class ellipsis:
    def __repr__(self):
        return 'Ellipsis'

    def __eq__(self, other: ellipsis):
        return True

    def __ne__(self, other: ellipsis):
        return False

    def __hash__(self):
        return 269626442  # same as CPython


# Allocate and populate the global vtables at module load.
__internal__.class_init_vtables()


class __cast__:
    def cast(obj: T, T: type) -> Generator[T]:
        return obj.__iter__()

    def cast(obj: int) -> float:
        return float(obj)

    def cast(obj: T, T: type) -> Optional[T]:
        return Optional[T](obj)

    def cast(obj: Optional[T], T: type) -> T:
        return obj.unwrap()

    def cast(obj: T, T: type) -> pyobj:
        return obj.__to_py__()

    def cast(obj: pyobj, T: type) -> T:
        return T.__from_py__(obj)

    # Function[[T...], R]
    #   ExternFunction[[T...], R]
    #   CodonFunction[[T...], R]
    #   Partial[foo, [T...], R]

    # function into partial (if not Function) / fn(foo) -> fn(foo(...))
    # empty partial (!!) into Function[]
    # union extract
    # any into Union[]
    # derived to base

    def conv_float(obj: float) -> int:
        return int(obj)


@extend
class TypeWrap:
    def __new__(T: type) -> TypeWrap[T]:
        return __internal__.tuple_cast_unsafe((), TypeWrap[T])

    def __call_no_self__(*args, **kwargs) -> T:
        return T(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> T:
        return T(*args, **kwargs)

    def __repr__(self):
        # Python-style class repr; an empty string carries no information.
        return f"<class '{T.__name__}'>"

    @property
    def __name__(self):
        return T.__name__