# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

from internal.gc import (
    alloc, alloc_atomic, alloc_atomic_uncollectable,
    free, sizeof, register_finalizer
)
from internal.static import vars_types, tuple_type, vars as _vars


def vars(obj, with_index: Static[int] = 0):
    # Thin module-level wrapper: forwards both arguments unchanged to
    # `internal.static.vars` (imported above under the name `_vars`).
    # Callers in this file unpack its items as `(name, value)` pairs, or as
    # `(index, name, value)` triples when `with_index` is set.
    return _vars(obj, with_index)


# Global table of per-class vtables, indexed by realization ID. It starts out as a
# default-constructed pointer and is allocated by `__internal__.class_init_vtables()`.
__vtables__ = Ptr[Ptr[cobj]]()
# Read by `class_init_vtables()`, which reserves `__vtable_size__ + 1` slots.
# NOTE(review): nothing in this file assigns a non-zero value; presumably the compiler
# sets it — confirm.
__vtable_size__ = 0


@extend
class __internal__:
    def yield_final(val):
        # The body is intentionally empty in this file.
        # NOTE(review): presumably a compiler-handled intrinsic — confirm.
        pass

    def yield_in_no_suspend(T: type) -> T:
        # Declared to return a `T` but has an empty body in this file.
        # NOTE(review): presumably a compiler-handled intrinsic — confirm.
        pass

    # Reinterprets `obj` as a raw byte pointer: the IR returns `%obj` itself as a `ptr`,
    # so this form is for objects represented by a single pointer.
    @pure
    @derives
    @llvm
    def class_raw_ptr(obj) -> Ptr[byte]:
        ret ptr %obj

    # Returns element 0 of `obj`, which the IR treats as a `{ptr, ptr}` pair.
    # `class_alloc` builds such pairs as (object pointer, RTTI pointer), so element 0
    # is the object pointer.
    @pure
    @derives
    @llvm
    def class_raw_rtti_ptr(obj) -> Ptr[byte]:
        %0 = extractvalue {ptr, ptr} %obj, 0
        ret ptr %0

    # Returns element 1 of `obj`, which the IR treats as a `{ptr, ptr}` pair.
    # `class_alloc` builds such pairs as (object pointer, RTTI pointer), so element 1
    # is the RTTI pointer.
    @pure
    @derives
    @llvm
    def class_raw_rtti_rtti(obj: T, T: type) -> Ptr[byte]:
        %0 = extractvalue {ptr, ptr} %obj, 1
        ret ptr %0

    def class_alloc(T: type) -> T:
        """
        Allocate storage for a new instance of reference (class) type `T`.

        The block is `sizeof(tuple(T))` bytes, taken from `alloc_atomic` when
        `T.__contents_atomic__` is set and from `alloc` otherwise, and is handed to
        `register_finalizer`. Types with RTTI are returned as an
        (object pointer, RTTI pointer) pair; all other types as a bare pointer.
        """
        nbytes = sizeof(tuple(T))
        mem = alloc_atomic(nbytes) if T.__contents_atomic__ else alloc(nbytes)
        register_finalizer(mem)  # needed on both paths, so done once up front
        if __has_rtti__(T):
            type_info = RTTI(T.__id__).__raw__()
            return __internal__.to_class_ptr_rtti((mem, type_info), T)
        else:
            return __internal__.to_class_ptr(mem, T)

    def class_ctr(T: type, *args, **kwargs) -> T:
        """Construct a `T` from the given arguments; shorthand for `__new__` followed by `__init__`."""
        instance = T(*args, **kwargs)
        return instance

    def class_init_vtables():
        """
        Allocate the global vtable table and fill it in.

        Reserves `__vtable_size__ + 1` `Ptr[cobj]` slots with
        `alloc_atomic_uncollectable`, stores the result in `__vtables__`, then calls
        `class_populate_vtables()`.
        """
        global __vtables__
        slot_count = __vtable_size__ + 1
        storage = alloc_atomic_uncollectable(slot_count * sizeof(Ptr[cobj]))
        __vtables__ = Ptr[Ptr[cobj]](storage)
        __internal__.class_populate_vtables()

    def class_populate_vtables() -> None:
        """
        Populate the contents of the vtables.

        The body below is only a placeholder; the real body is compiler generated and
        corresponds to:
            for each realized class C:
                __internal__.class_set_rtti_vtable(<C's realization ID>, <C's vtable size> + 1, T=C)
                for each fn F in C's vtable:
                    __internal__.class_set_rtti_vtable_fn(
                        <C's realization ID>, <F's vtable ID>, Function(<instantiated F>).__raw__(), T=C
                    )
        """
        pass

    def class_get_rtti_vtable(obj) -> Ptr[cobj]:
        """
        Look up the vtable of the polymorphic object `obj`.

        Reads the `id` field of the RTTI record attached to `obj` and uses it as an
        index into `__vtables__`. Compilation fails for types without RTTI.
        """
        if not __has_rtti__(type(obj)):
            compile_error("class is not polymorphic")
        info = __internal__.to_class_ptr(__internal__.class_raw_rtti_rtti(obj), RTTI)
        return __vtables__[info.id]

    def class_set_rtti_vtable(id: int, sz: int, T: type):
        """
        Allocate the vtable for the class with realization ID `id` and record its type info.

        `sz + 1` `cobj` slots are reserved, the table is stored in `__vtables__[id]`,
        and slot 0 is filled in by `class_set_typeinfo`. Compilation fails if `T` has
        no RTTI.
        """
        if not __has_rtti__(T):
            compile_error("class is not polymorphic")
        # `__vtables__` itself comes from `alloc_atomic_uncollectable`, i.e. memory the
        # collector does not scan for pointers. A vtable taken from the ordinary
        # collectable allocator would be referenced only from that unscanned table and
        # could be reclaimed while still in use, so allocate it the same way as the
        # outer table.
        p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj))
        __vtables__[id] = Ptr[cobj](p)
        __internal__.class_set_typeinfo(__vtables__[id], id)

    def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type):
        """Store function pointer `f` in slot `fid` of the vtable registered under `id`.

        Compilation fails if `T` has no RTTI.
        """
        if not __has_rtti__(T):
            compile_error("class is not polymorphic")
        table = __vtables__[id]
        table[fid] = f

    def class_set_typeinfo(p: Ptr[cobj], typeinfo: T, T: type) -> None:
        """Copy `typeinfo` into a freshly allocated one-element `Ptr[T]` cell and point `p[0]` at it."""
        cell = Ptr[T](1)
        cell[0] = typeinfo
        p[0] = cell.as_byte()

    def class_get_typeinfo(p) -> int:
        """
        Read back the integer stored by `class_set_typeinfo`.

        `p.__raw__()` is treated as pointing at a vtable pointer; slot 0 of that vtable
        is in turn treated as a pointer to an `int`, which is returned.
        """
        vtable = Ptr[Ptr[cobj]](p.__raw__())[0]
        return Ptr[int](vtable[0])[0]

    @inline
    def class_base_derived_dist(B: type, D: type) -> int:
        """Calculates the byte distance of base class B and derived class D. Compiler generated."""
        # Placeholder value; per the docstring the compiler supplies the real body.
        return 0

    @inline
    def class_base_to_derived(b: B, B: type, D: type) -> D:
        """
        Reinterpret base-class reference `b` as a reference to derived class `D`.

        The object pointer is moved back by `class_base_derived_dist(B, D)` bytes and
        the RTTI pointer of `b` is kept. Both types must have RTTI, otherwise
        compilation fails.
        """
        if not (__has_rtti__(D) and __has_rtti__(B)):
            compile_error("classes are not polymorphic")
        shift = __internal__.class_base_derived_dist(B, D)
        derived_ptr = __internal__.class_raw_rtti_ptr(b) - shift
        info = __internal__.class_raw_rtti_rtti(b)
        return __internal__.to_class_ptr_rtti((derived_ptr, info), D)

    def class_copy(obj: T, T: type) -> T:
        """Return a new `T` whose `sizeof(tuple(T))` content bytes are copied from `obj`."""
        clone = __internal__.class_alloc(T)
        nbytes = sizeof(tuple(T))
        str.memcpy(clone.__raw__(), obj.__raw__(), nbytes)
        return clone

    def class_super(obj, B: type, change_rtti: Static[int] = 0) -> B:
        """
        View `obj` as an instance of its base class `B`.

        Without RTTI on the object's type (static inheritance) the raw pointer is
        simply retyped. With RTTI, the object pointer is advanced by
        `class_base_derived_dist(B, D)` bytes; the RTTI pointer is kept unless
        `change_rtti` is set, in which case a fresh `RTTI(B.__id__)` is attached.
        Compilation fails if the object's type has RTTI but `B` does not.
        """
        D = type(obj)
        if not __has_rtti__(D):  # static inheritance
            return __internal__.to_class_ptr(obj.__raw__(), B)
        else:
            if not __has_rtti__(B):
                compile_error("classes are not polymorphic")
            shift = __internal__.class_base_derived_dist(B, D)
            base_ptr = __internal__.class_raw_rtti_ptr(obj) + shift
            info = __internal__.class_raw_rtti_rtti(obj)
            if change_rtti:
                info = RTTI(B.__id__).__raw__()
            return __internal__.to_class_ptr_rtti((base_ptr, info), B)

    # Unions

    # Builds a value of union type `U` whose field 0 is `tag`; every other part of the
    # value is left `undef`.
    @llvm
    def union_set_tag(tag: byte, U: type) -> U:
        %0 = insertvalue {=U} undef, i8 %tag, 0
        ret {=U} %0

    # Returns the address of field 1 of the union that `ptr` points to, typed as Ptr[T].
    # `union_make` / `union_get_data` use this field as the payload slot.
    @llvm
    def union_get_data_ptr(ptr: Ptr[U], U: type, T: type) -> Ptr[T]:
        %0 = getelementptr inbounds {=U}, ptr %ptr, i64 0, i32 1
        ret ptr %0

    # Extracts field 0 of union value `u` (the i8 tag written by `union_set_tag`).
    @llvm
    def union_get_tag(u: U, U: type) -> byte:
        %0 = extractvalue {=U} %u, 0
        ret i8 %0

    def union_get_data(u, T: type) -> T:
        """Read the payload of union `u`, interpreting it as a `T`."""
        payload = __internal__.union_get_data_ptr(__ptr__(u), T=T)
        return payload[0]

    def union_make(tag: int, value, U: type) -> U:
        """Create a `U` whose tag is `byte(tag)` and whose payload slot holds `value`."""
        result = __internal__.union_set_tag(byte(tag), U)
        slot = __internal__.union_get_data_ptr(__ptr__(result), T=type(value))
        slot[0] = value
        return result

    def new_union(value, U: type) -> U:
        # Empty body in this file although a `U` is declared as the return type.
        # NOTE(review): presumably generated by the compiler — confirm.
        pass

    def get_union(union, T: type) -> T:
        # Empty body in this file although a `T` is declared as the return type.
        # NOTE(review): presumably generated by the compiler — confirm.
        pass

    def get_union_first(union):
        # Empty body in this file; `get_union_method` calls it on a one-element result.
        # NOTE(review): presumably generated by the compiler — confirm.
        pass

    def _get_union_method(union, method: Static[str], *args, **kwargs):
        # Empty body in this file; `get_union_method` applies `staticlen` to its result.
        # NOTE(review): presumably generated by the compiler — confirm.
        pass

    def get_union_method(union, method: Static[str], *args, **kwargs):
        """
        Call `method` on `union` through `_get_union_method`.

        When the result has static length 1 it is passed through `get_union_first`;
        otherwise it is returned as is.
        """
        results = __internal__._get_union_method(union, method, *args, **kwargs)
        if staticlen(results) == 1:
            return __internal__.get_union_first(results)
        return results

    # Tuples

    # Spills tuple `t` to a stack slot and loads one `E` from it, addressing the slot
    # as an array of `E` indexed by `idx`. No bounds check happens here; the caller in
    # this file (`tuple_getitem`) normalizes the index with `tuple_fix_index` first.
    # NOTE(review): this addressing presumably assumes every element of T has type E — confirm.
    @pure
    @derives
    @llvm
    def _tuple_getitem_llvm(t: T, idx: int, T: type, E: type) -> E:
        %x = alloca {=T}
        store {=T} %t, ptr %x
        %p = getelementptr {=E}, ptr %x, i64 %idx
        %v = load {=E}, ptr %p
        ret {=E} %v

    def tuple_fix_index(idx: int, len: int) -> int:
        """
        Normalize a tuple index against a tuple of length `len`.

        A negative `idx` counts from the end. Raises `IndexError` when the adjusted
        index is outside `0 <= idx < len`; the message reports the adjusted index.
        """
        if idx < 0:
            idx += len
        if not (0 <= idx < len):
            raise IndexError(f"tuple index {idx} out of range 0..{len}")
        return idx

    def tuple_getitem(t: T, idx: int, T: type, E: type) -> E:
        """Return element `idx` of tuple `t` as an `E`, after normalizing and range-checking the index."""
        fixed = __internal__.tuple_fix_index(idx, staticlen(t))
        return __internal__._tuple_getitem_llvm(t, fixed, T, E)

    # Reinterprets raw pointer `p` as a value of type `T`; the IR returns the pointer unchanged.
    @pure
    @derives
    @llvm
    def fn_new(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    # Inverse of `fn_new`: returns `fn` unchanged, typed as a raw byte pointer.
    @pure
    @derives
    @llvm
    def fn_raw(fn: T, T: type) -> Ptr[byte]:
        ret ptr %fn

    # Sign-extends `what` from an F-bit integer to a T-bit integer (LLVM `sext`).
    @pure
    @llvm
    def int_sext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = sext i{=F} %what to i{=T}
        ret i{=T} %0

    # Zero-extends `what` from an F-bit integer to a T-bit integer (LLVM `zext`).
    @pure
    @llvm
    def int_zext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = zext i{=F} %what to i{=T}
        ret i{=T} %0

    # Truncates `what` from an F-bit integer to a T-bit integer (LLVM `trunc`).
    @pure
    @llvm
    def int_trunc(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = trunc i{=F} %what to i{=T}
        ret i{=T} %0

    def seq_assert(file: str, line: int, msg: str) -> AssertionError:
        """
        Build (but do not raise) the `AssertionError` for a failed assertion.

        The message is `Assert failed (FILE:LINE)`, with `: MSG` inserted after
        `Assert failed` when `msg` is non-empty.
        """
        detail = f": {msg}" if msg else ""
        return AssertionError(f"Assert failed{detail} ({file}:{line.__repr__()})")

    def seq_assert_test(file: str, line: int, msg: str):
        """
        Print a red "TEST FAILED" line for `file`/`line` through the C function `seq_print`.

        `: MSG` is appended when `msg` is non-empty. Nothing is raised.
        """
        from C import seq_print(str)
        detail = f": {msg}" if msg else ""
        report = f"\033[1;31mTEST FAILED:\033[0m {file} (line {line}){detail}\n"
        seq_print(report)

    def check_errno(prefix: str):
        """
        Raise `OSError(prefix + message)` when the C helper `seq_check_errno` returns
        a non-empty message; otherwise do nothing.
        """
        @pure
        @C
        def seq_check_errno() -> str:
            pass

        err = seq_check_errno()
        if err:
            raise OSError(prefix + err)

    # Returns a null pointer as `Optional[T]`. `opt_ref_bool` treats a null pointer as
    # false, so this is the empty optional for single-pointer reference types.
    @pure
    @llvm
    def opt_ref_new(T: type) -> Optional[T]:
        ret ptr null

    def opt_ref_new_rtti(T: type) -> Optional[T]:
        """
        Build the empty `Optional[T]` for a class `T` that carries RTTI.

        The result is an (object pointer, RTTI pointer) pair in which the object
        pointer is default-constructed and the RTTI record holds `T.__id__`.
        """
        obj = Ptr[byte]()
        # The record allocated here is an RTTI, so it is sized as one. Sizing it from T
        # would under-allocate whenever T's contents are smaller than an RTTI record.
        rsz = sizeof(tuple(RTTI))
        rtti = alloc_atomic(rsz) if RTTI.__contents_atomic__ else alloc(rsz)
        __internal__.to_class_ptr(rtti, RTTI).id = T.__id__
        # The value is a `{ptr, ptr}` pair, so it needs the pair wrapper
        # (`opt_ref_new_arg_rtti`), not the single-pointer `opt_ref_new_arg`.
        return __internal__.opt_ref_new_arg_rtti(__internal__.to_class_ptr_rtti((obj, rtti), T))

    # Wraps `what` in the `{ i1, T }` optional representation with the flag set to true.
    @pure
    @derives
    @llvm
    def opt_tuple_new_arg(what: T, T: type) -> Optional[T]:
        %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1
        ret { i1, {=T} } %0

    # Returns single-pointer reference `what` unchanged, typed as `Optional[T]`.
    @pure
    @derives
    @llvm
    def opt_ref_new_arg(what: T, T: type) -> Optional[T]:
        ret ptr %what

    # Returns `{ ptr, ptr }` pair reference `what` unchanged, typed as `Optional[T]`.
    @pure
    @derives
    @llvm
    def opt_ref_new_arg_rtti(what: T, T: type) -> Optional[T]:
        ret { ptr, ptr } %what

    # Returns the presence flag (field 0 of the `{ i1, T }` optional), widened to i8.
    @pure
    @llvm
    def opt_tuple_bool(what: Optional[T], T: type) -> bool:
        %0 = extractvalue { i1, {=T} } %what, 0
        %1 = zext i1 %0 to i8
        ret i8 %1

    # True when the pointer held by the optional is non-null (result widened to i8).
    @pure
    @llvm
    def opt_ref_bool(what: Optional[T], T: type) -> bool:
        %0 = icmp ne ptr %what, null
        %1 = zext i1 %0 to i8
        ret i8 %1

    @pure
    def opt_ref_bool_rtti(what: Optional[T], T: type) -> bool:
        """
        Truth value of an optional holding an RTTI-carrying reference.

        True when the object pointer (element 0 of the pair) differs from a
        default-constructed `cobj`.
        """
        # The optional itself must be passed; `class_raw_rtti_ptr` takes one argument.
        return __internal__.class_raw_rtti_ptr(what) != cobj()

    # Extracts the payload (field 1) of the `{ i1, T }` optional. The presence flag is
    # not inspected here.
    @pure
    @derives
    @llvm
    def opt_tuple_invert(what: Optional[T], T: type) -> T:
        %0 = extractvalue { i1, {=T} } %what, 1
        ret {=T} %0

    # Returns the pointer held by the optional unchanged, typed as `T`. No null check here.
    @pure
    @derives
    @llvm
    def opt_ref_invert(what: Optional[T], T: type) -> T:
        ret ptr %what

    # Returns the `{ ptr, ptr }` pair held by the optional unchanged, typed as `T`.
    # No null check here.
    @pure
    @derives
    @llvm
    def opt_ref_invert_rtti(what: Optional[T], T: type) -> T:
        ret { ptr, ptr } %what

    # Retypes raw pointer `p` as a reference of single-pointer class type `T`.
    @pure
    @derives
    @llvm
    def to_class_ptr(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    # Retypes a pair of raw pointers as a reference of class type `T`. Callers in this
    # file pass (object pointer, RTTI pointer).
    @pure
    @derives
    @llvm
    def to_class_ptr_rtti(p: Tuple[Ptr[byte], Ptr[byte]], T: type) -> T:
        ret { ptr, ptr } %p

    def _tuple_offsetof(x, field: Static[int]) -> int:
        # Returns the byte offset of element number `field` inside the type of `x`.
        # The nested IR takes the address of a stack slot of that type and of its
        # `idx`-th element, and subtracts the two as integers.
        @pure
        @llvm
        def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
            %a = alloca {=T}
            %b = getelementptr inbounds {=T}, ptr %a, i64 0, i32 {=idx}
            %base = ptrtoint ptr %a to i64
            %elem = ptrtoint ptr %b to i64
            %offset = sub i64 %elem, %base
            ret i64 %offset

        # `TE` (the element type) is passed along but not referenced by the IR above.
        return _llvm_offsetof(type(x), field, type(x[field]))

    def raw_type_str(p: Ptr[byte], name: str) -> str:
        """
        Format a pointer as `<NAME at ADDR>`, where ADDR is `p.__repr__()`.

        The text is assembled by hand in a buffer of exactly the required length. The
        temporary ADDR string's buffer is released with `free` after it is copied.
        """
        addr = p.__repr__()
        size = 1 + name.len + 4 + addr.len + 1  # '<' + name + ' at ' + addr + '>'
        out = Ptr[byte](size)
        out[0] = byte(60)  # '<'
        pos = 1
        str.memcpy(out + pos, name.ptr, name.len)
        pos += name.len
        out[pos] = byte(32)  # ' '
        out[pos + 1] = byte(97)  # 'a'
        out[pos + 2] = byte(116)  # 't'
        out[pos + 3] = byte(32)  # ' '
        pos += 4
        str.memcpy(out + pos, addr.ptr, addr.len)
        out[pos + addr.len] = byte(62)  # '>'
        free(addr.ptr)
        return str(out, size)

    def tuple_str(strs: Ptr[str], names: Ptr[str], n: int) -> str:
        """
        Join `n` element strings into `(a, b, ...)`.

        Element `i` is written as `strs[i]`, prefixed by `names[i]` and `": "` whenever
        `names[i]` is non-empty. Elements are separated by `", "`.
        """
        # Pass 1: work out the exact output length.
        size = 2  # '(' and ')'
        k = 0
        while k < n:
            size += strs[k].len
            if names[k].len:
                size += names[k].len + 2  # ": "
            if k < n - 1:
                size += 2  # ", "
            k += 1

        # Pass 2: write the pieces into the buffer.
        out = Ptr[byte](size)
        out[0] = byte(40)  # '('
        pos = 1
        k = 0
        while k < n:
            label = names[k]
            if label.len:
                str.memcpy(out + pos, label.ptr, label.len)
                pos += label.len
                out[pos] = byte(58)  # ':'
                out[pos + 1] = byte(32)  # ' '
                pos += 2
            text = strs[k]
            str.memcpy(out + pos, text.ptr, text.len)
            pos += text.len
            if k < n - 1:
                out[pos] = byte(44)  # ','
                out[pos + 1] = byte(32)  # ' '
                pos += 2
            k += 1
        out[pos] = byte(41)  # ')'
        return str(out, size)

    def undef(v, s):
        """Raise `NameError` naming variable `s` when the flag `v` is falsy; otherwise do nothing."""
        if v:
            return
        raise NameError(f"variable '{s}' not yet defined")

    @__hidden__
    def set_header(e, func, file, line, col):
        # Stamps source-location fields onto exception object `e` and returns it.
        # Compilation fails when `e` is not a `BaseException`.
        if not isinstance(e, BaseException):
            compile_error("exceptions must derive from BaseException")

        e.func = func
        e.file = file
        e.line = line
        e.col = col
        return e


@extend
class __magic__:
    # always present
    def tuplesize(T: type) -> int:
        # Builds a tuple with one default-constructed value per entry of
        # `vars_types(T)` and returns its `__elemsize__`.
        return (t() for t in vars_types(T)).__elemsize__

    # @dataclass parameter: init=True for tuples;
    # always present for reference types only
    def new(T: type) -> T:
        """Create a new, uninitialized instance of reference (class) type `T` via `class_alloc`."""
        instance = __internal__.class_alloc(T)
        return instance

    # init is compiler-generated when init=True for reference types
    # def init(self, a1, ..., aN): ...

    # always present for reference types only
    def raw(obj) -> Ptr[byte]:
        """
        Return the raw object pointer of `obj`.

        Types with RTTI are (object pointer, RTTI pointer) pairs, so the first element
        is extracted; other types are a single pointer and are returned directly.
        """
        if not __has_rtti__(type(obj)):
            return __internal__.class_raw_ptr(obj)
        else:
            return __internal__.class_raw_rtti_ptr(obj)

    # always present for reference types only
    def dict(slf) -> List[str]:
        """
        Return the field names of `slf`, in `vars(slf)` order.

        The list is constructed with `staticlen(slf)` as its argument
        (presumably an initial capacity — confirm).
        """
        names = List[str](staticlen(slf))
        for field, _ in vars(slf):
            names.append(field)
        return names

    # always present for tuple types only
    def len(slf) -> int:
        # The length is the statically known field count of the tuple type.
        return staticlen(slf)

    # always present for tuple types only
    def add(slf, obj):
        """Concatenate tuple `slf` with tuple `obj`; compilation fails when `obj` is not a tuple."""
        if not isinstance(obj, Tuple):
            compile_error("can only concatenate tuple to tuple")
        joined = (*slf, *obj)
        return joined

    # always present for tuple types only
    def mul(slf, i: Static[int]):
        """
        Repeat tuple `slf` `i` times, where `i` is a compile-time constant.

        `i < 1` gives the empty tuple and `i == 1` gives `slf` itself. Larger values
        recurse on `i - 1` and append one more copy of `slf`.
        """
        if i < 1:
            return ()
        elif i == 1:
            return slf
        else:
            head = __magic__.mul(slf, i - 1)
            return (*head, *slf)

    # always present for tuples
    def contains(slf, what) -> bool:
        """
        Report whether some field of `slf` equals `what`.

        Only fields for which `isinstance(what, type(field))` holds are compared.
        """
        for _, member in vars(slf):
            if isinstance(what, type(member)):
                if what == member:
                    return True
        return False

    # @dataclass parameter: container=True
    def getitem(slf, index: int):
        """
        Return element `index` of `slf`, typed as the tuple's first element type.

        Indexing an empty tuple always raises `IndexError`. Otherwise the index is
        normalized and range-checked by `tuple_getitem`.
        """
        if staticlen(slf) != 0:
            return __internal__.tuple_getitem(slf, index, type(slf), tuple_type(slf, 0))
        else:
            raise IndexError(f"tuple index {index} out of range 0..0")

    # @dataclass parameter: container=True
    def iter(slf):
        # Yields each field value of `slf` in `vars(slf)` order.
        # For an empty tuple the loop below contributes no `yield`, so the never-taken
        # `if int(0): yield 0` is there to fix the generator's type.
        if staticlen(slf) == 0:
            if int(0): yield 0  # set generator type without yielding anything
        for _, v in vars(slf):
            yield v

    # @dataclass parameter: order=True or eq=True
    def eq(slf, obj) -> bool:
        """Field-wise equality: true when every field of `slf` `==` the same-named attribute of `obj`."""
        for field, mine in vars(slf):
            theirs = getattr(obj, field)
            if not (mine == theirs):
                return False
        return True

    # @dataclass parameter: order=True or eq=True
    def ne(slf, obj) -> bool:
        """Negation of `slf == obj`."""
        same = slf == obj
        return not same

    # @dataclass parameter: order=True
    def lt(slf, obj) -> bool:
        """
        Lexicographic `slf < obj` over the fields of `slf`.

        The first field where the two sides are not `==` decides the result; if all
        fields are equal the result is False.
        """
        for field, lhs in vars(slf):
            rhs = getattr(obj, field)
            if lhs < rhs:
                return True
            if not (lhs == rhs):
                return False
        return False

    # @dataclass parameter: order=True
    def le(slf, obj) -> bool:
        """
        Lexicographic `slf <= obj` over the fields of `slf`.

        The first field where the two sides are not `==` decides the result; if all
        fields are equal the result is True.
        """
        for field, lhs in vars(slf):
            rhs = getattr(obj, field)
            if lhs < rhs:
                return True
            if not (lhs == rhs):
                return False
        return True

    # @dataclass parameter: order=True
    def gt(slf, obj) -> bool:
        """
        Lexicographic `slf > obj` over the fields of `slf`, expressed with `<` only.

        The first field where the two sides are not `==` decides the result; if all
        fields are equal the result is False.
        """
        for field, lhs in vars(slf):
            rhs = getattr(obj, field)
            if rhs < lhs:
                return True
            if not (lhs == rhs):
                return False
        return False

    # @dataclass parameter: order=True
    def ge(slf, obj) -> bool:
        """
        Lexicographic `slf >= obj` over the fields of `slf`, expressed with `<` only.

        The first field where the two sides are not `==` decides the result; if all
        fields are equal the result is True.
        """
        for field, lhs in vars(slf):
            rhs = getattr(obj, field)
            if rhs < lhs:
                return True
            if not (lhs == rhs):
                return False
        return True

    # @dataclass parameter: hash=True
    def hash(slf) -> int:
        """
        Combine the hashes of all fields of `slf` into one value.

        Each field's `__hash__()` is offset by 2_654_435_769, added to shifted copies
        of the running value, and XOR-ed into it. The running value starts at 0.
        """
        acc = 0
        for _, v in vars(slf):
            mixed = (v.__hash__() + 2_654_435_769) + ((acc << 6) + (acc >> 2))
            acc = acc ^ mixed
        return acc

    # @dataclass parameter: pickle=True
    def pickle(slf, dest: Ptr[byte]) -> None:
        """Pickle every field of `slf` to `dest`, in `vars(slf)` order, via each field's `__pickle__`."""
        for _, member in vars(slf):
            member.__pickle__(dest)

    # @dataclass parameter: pickle=True
    def unpickle(src: Ptr[byte], T: type) -> T:
        """
        Rebuild a `T` from `src` by unpickling its fields one after another.

        By-value types are assembled as a tuple of unpickled fields. Reference types
        are created with `T.__new__()` and have each field assigned in turn.
        """
        if isinstance(T, ByVal):
            return tuple(type(t).__unpickle__(src) for t in vars_types(T))
        else:
            inst = T.__new__()
            for field, current in vars(inst):
                setattr(inst, field, type(current).__unpickle__(src))
            return inst

    # @dataclass parameter: python=True
    def to_py(slf) -> Ptr[byte]:
        """Convert `slf` to a Python tuple object holding each field's `__to_py__()` result."""
        py_tuple = pyobj._tuple_new(staticlen(slf))
        for pos, _, member in vars(slf, with_index=1):
            pyobj._tuple_set(py_tuple, pos, member.__to_py__())
        return py_tuple

    # @dataclass parameter: python=True
    def from_py(src: Ptr[byte], T: type) -> T:
        """
        Build a `T` from the Python tuple object `src`, converting item `i` with the
        `__from_py__` of field `i`'s type.

        By-value types are assembled as a tuple. Reference types are created with
        `T.__new__()` and have each field assigned in turn.
        """
        if isinstance(T, ByVal):
            return tuple(
                type(t).__from_py__(pyobj._tuple_get(src, i))
                for i, t in vars_types(T, with_index=1)
            )
        else:
            inst = T.__new__()
            for pos, field, current in vars(inst, with_index=True):
                item = pyobj._tuple_get(src, pos)
                setattr(inst, field, type(current).__from_py__(item))
            return inst

    # @dataclass parameter: gpu=True
    def to_gpu(slf, cache):
        # Forwards to `__internal__.class_to_gpu`, which is not defined in this file.
        return __internal__.class_to_gpu(slf, cache)

    # @dataclass parameter: gpu=True
    def from_gpu(slf: T, other: T, T: type):
        # Forwards to `__internal__.class_from_gpu`, which is not defined in this file.
        # Nothing is returned.
        __internal__.class_from_gpu(slf, other)

    # @dataclass parameter: gpu=True
    def from_gpu_new(other: T, T: type) -> T:
        """
        Return the `T` produced by `__internal__.class_from_gpu_new(other)`.

        That helper is not defined in this file.
        """
        # The signature promises a `T`, so the helper's result must be returned.
        return __internal__.class_from_gpu_new(other)

    # @dataclass parameter: repr=True
    def repr(slf) -> str:
        """
        Render `slf` as `(v1, v2, ...)` using each field's `__repr__()`.

        For types that are not plain `Tuple`s each value is prefixed with its field
        name. An empty type renders as `()`.
        """
        if staticlen(slf) == 0:
            return "()"
        values = __array__[str](staticlen(slf))
        labels = __array__[str](staticlen(slf))
        for pos, field, member in vars(slf, with_index=True):
            values[pos] = member.__repr__()
            if isinstance(slf, Tuple):
                labels[pos] = ""
            else:
                labels[pos] = field
        return __internal__.tuple_str(values.ptr, labels.ptr, staticlen(slf))

    # @dataclass parameter: repr=True
    def str(slf) -> str:
        """Return the same text as `slf.__repr__()`."""
        text = slf.__repr__()
        return text


@extend
class Import:
    def __repr__(self) -> str:
        """Describe the module as `<module 'NAME' from 'FILE'>`."""
        text = f"<module '{self.name}' from '{self.file}'>"
        return text

@extend
class Function:
    # Wraps a raw code pointer as a Function value; the IR returns the pointer unchanged.
    @pure
    @llvm
    def __new__(what: Ptr[byte]) -> Function[T, TR]:
        ret ptr %what

    # Identity overload: an existing Function value is returned as is.
    def __new__(what: Function[T, TR]) -> Function[T, TR]:
       return what

    # Exposes the underlying pointer as Ptr[byte]; the IR returns `self` unchanged.
    @pure
    @llvm
    def __raw__(self) -> Ptr[byte]:
        ret ptr %self

    # Renders through `__internal__.raw_type_str` with the name "function".
    def __repr__(self) -> str:
       return __internal__.raw_type_str(self.__raw__(), "function")

    # Placeholder body (`noop`); per the inline note the compiler fills in the real call.
    @llvm
    def __call_internal__(self: Function[T, TR], args: T) -> TR:
        noop  # compiler will populate this one

    # Collects the positional arguments into a tuple and forwards to `__call_internal__`.
    def __call__(self, *args) -> TR:
        return Function.__call_internal__(self, args)


# By-value record with an integer count followed by a raw type pointer.
# NOTE(review): presumably mirrors the header of a CPython object — confirm.
@tuple
class PyObject:
    refcnt: int
    pytype: Ptr[byte]


# By-value record: a `PyObject` header followed by a payload of type `T`.
@tuple
class PyWrapper[T]:
    head: PyObject
    data: T


@extend
class RTTI:
    # Allocates an uninitialized RTTI record through the generic class allocator.
    def __new__() -> RTTI:
        return __magic__.new(RTTI)
    # Stores `i` in the `id` field; callers in this file pass a type's `__id__`.
    def __init__(self, i: int):
        self.id = i
    # Returns the record's raw pointer.
    def __raw__(self):
        return __internal__.class_raw_ptr(self)

@extend
class ellipsis:
    def __repr__(self):
        return 'Ellipsis'
    # Comparison is only defined against another `ellipsis`, and any two compare equal.
    def __eq__(self, other: ellipsis):
        return True
    def __ne__(self, other: ellipsis):
        return False
    # Fixed hash value, consistent with all instances comparing equal.
    def __hash__(self):
        return 269626442  # same as CPython

# Executed when this module is loaded: allocates `__vtables__` and populates it.
__internal__.class_init_vtables()