# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

from internal.gc import (
    alloc, alloc_atomic, alloc_uncollectable, alloc_atomic_uncollectable,
    free, atomic, sizeof, register_finalizer
)

# Module-level integer, initialized to zero.
# NOTE(review): nothing in the visible part of this file reads or writes it;
# presumably the compiler maintains it — confirm before renaming or removing.
__vtable_size__ = 0

# Low-level helpers used by the compiler and the standard library, added to the
# built-in `__internal__` class through `@extend`. Methods whose body is only
# `pass` are placeholders (several are explicitly marked "Compiler generated"
# in their docstrings); functions decorated with `@llvm` have bodies written in
# LLVM IR rather than in the host language.
@extend
class __internal__:
    # Placeholder with an empty body.
    # NOTE(review): presumably replaced by the compiler — confirm.
    def yield_final(val):
        pass

    # Placeholder with an empty body.
    # NOTE(review): presumably replaced by the compiler — confirm.
    def yield_in_no_suspend(T: type) -> T:
        pass

    # Returns the object's pointer unchanged, typed as Ptr[byte].
    @pure
    @derives
    @llvm
    def class_raw(obj) -> Ptr[byte]:
        ret ptr %obj

    def class_alloc(T: type) -> T:
        """
        Allocates a new reference (class) type.

        The allocation size is `sizeof(tuple(T))`. `alloc_atomic` is used when
        `T.__contents_atomic__` is true and `alloc` otherwise. The new memory is
        passed to `register_finalizer` and returned as a `T`.
        """
        sz = sizeof(tuple(T))
        p = alloc_atomic(sz) if T.__contents_atomic__ else alloc(sz)
        register_finalizer(p)
        return __internal__.to_class_ptr(p, T)

    def class_new(T: type) -> T:
        """
        Create a new reference (class) type.

        Allocates through `class_alloc` and then calls `class_set_obj_vtable`
        on the result before returning it.
        """
        pf = __internal__.class_alloc(T)
        __internal__.class_set_obj_vtable(pf)
        return pf

    def class_ctr(T: type, *args, **kwargs) -> T:
        """Shorthand for `t = T.__new__(); t.__init__(*args, **kwargs); t`"""
        return T(*args, **kwargs)

    def class_set_obj_vtable(pf: T, T: type) -> None:
        """
        Initialize a vtable of a T() object. Compiler generated.
        Corresponds to:
            pf.__vtable__ = __vtables__[pf.__vtable_id__]
        """
        pass

    def class_init_vtables() -> Ptr[Ptr[cobj]]:
        """
        Create a global vtable. Compiler generated.
        Corresponds to:
            return __internal__.class_make_n_vtables(<number of class realizations>)
        """
        pass

    def class_make_n_vtables(sz: int) -> Ptr[Ptr[cobj]]:
        """
        Create a global vtable.

        Allocates room for `sz` entries of type `Ptr[cobj]`, hands the table to
        `class_populate_vtables` to be filled in, and returns it.
        """
        # NOTE(review): the table comes from alloc_atomic_uncollectable although
        # its slots are pointers (Ptr[cobj]); confirm that the per-class tables
        # stored by class_populate_vtables remain reachable for the collector.
        p = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])))
        __internal__.class_populate_vtables(p)
        return p

    def class_populate_vtables(p: Ptr[Ptr[cobj]]) -> None:
        """
        Populate content of vtables. Compiler generated.
        Corresponds to:
            for each realized class C:
                p.__setitem__(<C's realization ID>, Ptr[cobj](<C's vtable size> + 1))
                __internal__.class_set_typeinfo(p[<C's realization ID>], <C's realization ID>)
                for each fn F in C's vtable:
                    p[<C's realization ID>].__setitem__(<F's vtable ID>, Function(<instantiated F>).__raw__())
        """
        pass

    def class_set_typeinfo(p: Ptr[cobj], typeinfo: T, T: type) -> None:
        """
        Store `typeinfo` in a freshly allocated one-element buffer and make
        slot 0 of the vtable `p` point at that buffer.
        """
        i = Ptr[T](1)
        i[0] = typeinfo
        p[0] = i.as_byte()

    def class_get_typeinfo(p) -> int:
        """
        Read the integer type info of object `p`.

        Follows the layout written by `class_set_typeinfo`: the first word of
        the object is treated as its vtable pointer, and slot 0 of that vtable
        is treated as a pointer to the integer that is returned.
        """
        c = Ptr[Ptr[cobj]](p.__raw__())
        vt = c[0]
        return Ptr[int](vt[0])[0]

    @inline
    def class_base_derived_dist(B: type, D: type) -> int:
        """Calculates the byte distance of base class B and derived class D. Compiler generated."""
        return 0

    def class_copy(obj: T, T: type) -> T:
        """
        Return a new `T` whose `sizeof(tuple(T))` bytes are copied from `obj`.
        """
        p = __internal__.class_alloc(T)
        str.memcpy(p.__raw__(), obj.__raw__(), sizeof(tuple(T)))
        return p

    def class_super(obj: D, B: type, D: type) -> B:
        """
        Produce a `B` view of the derived object `obj`.

        The base part is located at `class_base_derived_dist(B, D)` bytes past
        the start of `obj`, copied with `class_copy`, and the copy's vtable is
        then reset via `class_set_obj_vtable`. The returned object is the copy,
        not an alias of `obj`.
        """
        pf = __internal__.to_class_ptr(obj.__raw__() + __internal__.class_base_derived_dist(B, D), B)
        pn = __internal__.class_copy(pf)
        # Replace vtables
        __internal__.class_set_obj_vtable(pn)  # replace vtables to point to its vtables!
        return pn

    # Unions

    # Builds a value of union type U with field 0 (the tag) set to `tag`;
    # everything else is left undefined.
    @llvm
    def union_set_tag(tag: byte, U: type) -> U:
        %0 = insertvalue {=U} undef, i8 %tag, 0
        ret {=U} %0

    # Returns the address of field 1 (the data slot) of the union at `ptr`,
    # typed as Ptr[T].
    @llvm
    def union_get_data_ptr(ptr: Ptr[U], U: type, T: type) -> Ptr[T]:
        %0 = getelementptr inbounds {=U}, ptr %ptr, i64 0, i32 1
        ret ptr %0

    # Returns field 0 (the tag) of the union value `u`.
    @llvm
    def union_get_tag(u: U, U: type) -> byte:
        %0 = extractvalue {=U} %u, 0
        ret i8 %0

    def union_get_data(u, T: type) -> T:
        """Read the data slot of union `u`, interpreted as a `T`."""
        return __internal__.union_get_data_ptr(__ptr__(u), T=T)[0]

    def union_make(tag: int, value, U: type) -> U:
        """
        Build a `U` whose tag is `byte(tag)` and whose data slot holds `value`
        (stored using the static type of `value`).
        """
        u = __internal__.union_set_tag(byte(tag), U)
        __internal__.union_get_data_ptr(__ptr__(u), T=type(value))[0] = value
        return u

    # Placeholder with an empty body.
    # NOTE(review): presumably generated by the compiler — confirm.
    def new_union(value, U: type) -> U:
        pass

    # Placeholder with an empty body.
    # NOTE(review): presumably generated by the compiler — confirm.
    def get_union(union, T: type) -> T:
        pass

    # Placeholder with an empty body.
    # NOTE(review): presumably generated by the compiler — confirm.
    def get_union_first(union):
        pass

    # Placeholder with an empty body.
    # NOTE(review): presumably generated by the compiler — confirm.
    def _get_union_method(union, method: Static[str], *args, **kwargs):
        pass

    def get_union_method(union, method: Static[str], *args, **kwargs):
        """
        Forward to `_get_union_method`; when its result has a static length of
        one, return `get_union_first` of that result, otherwise the result
        itself.
        """
        t = __internal__._get_union_method(union, method, *args, **kwargs)
        if staticlen(t) == 1:
            return __internal__.get_union_first(t)
        return t

    # Tuples

    # Spills the tuple `t` to a stack slot, indexes that slot as an array of E
    # and loads element `idx`. No bounds check is performed here.
    @pure
    @derives
    @llvm
    def _tuple_getitem_llvm(t: T, idx: int, T: type, E: type) -> E:
        %x = alloca {=T}
        store {=T} %t, ptr %x
        %p = getelementptr {=E}, ptr %x, i64 %idx
        %v = load {=E}, ptr %p
        ret {=E} %v

    def tuple_fix_index(idx: int, len: int) -> int:
        """
        Normalize a possibly negative tuple index.

        Args:
            idx: index as written by the caller; negative values count from
                the end.
            len: number of elements in the tuple.

        Returns:
            The equivalent index in `0 .. len - 1`.

        Raises:
            IndexError: if the normalized index falls outside the tuple. The
                message reports the index the caller passed.
        """
        # Normalize into a separate local so that the error message below shows
        # the caller's index rather than the shifted one.
        pos = idx + len if idx < 0 else idx
        if pos < 0 or pos >= len:
            raise IndexError(f"tuple index {idx} out of range 0..{len}")
        return pos

    def tuple_getitem(t: T, idx: int, T: type, E: type) -> E:
        """
        Return element `idx` of tuple `t` as an `E`, after validating and
        normalizing the index with `tuple_fix_index`.
        """
        return __internal__._tuple_getitem_llvm(
            t, __internal__.tuple_fix_index(idx, staticlen(t)), T, E
        )

    def tuple_add(t, i):
        """
        Concatenate tuples `t` and `i`; anything other than a tuple for `i` is
        rejected with a compile error.
        """
        if isinstance(i, Tuple):
            return (*t, *i)
        else:
            compile_error("can only concatenate tuple to tuple")

    def tuple_mul(t, i: Static[int]):
        """
        Repeat tuple `t` `i` times. Counts below one give the empty tuple; the
        repetition is built by recursing on `i - 1`.
        """
        if i < 1:
            return ()
        elif i == 1:
            return t
        else:
            return (*(__internal__.tuple_mul(t, i - 1)), *t)

    # ...

    # Reinterprets the raw pointer `p` as a value of type T.
    @pure
    @derives
    @llvm
    def fn_new(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    # Reinterprets `fn` as a raw byte pointer.
    @pure
    @derives
    @llvm
    def fn_raw(fn: T, T: type) -> Ptr[byte]:
        ret ptr %fn

    # Sign-extends an F-bit integer to T bits.
    @pure
    @llvm
    def int_sext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = sext i{=F} %what to i{=T}
        ret i{=T} %0

    # Zero-extends an F-bit integer to T bits.
    @pure
    @llvm
    def int_zext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = zext i{=F} %what to i{=T}
        ret i{=T} %0

    # Truncates an F-bit integer to T bits.
    @pure
    @llvm
    def int_trunc(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = trunc i{=F} %what to i{=T}
        ret i{=T} %0

    def seq_assert(file: str, line: int, msg: str) -> AssertionError:
        """
        Build (but do not raise) the AssertionError for a failed assertion at
        `file`:`line`; a non-empty `msg` is included after a colon.
        """
        s = f": {msg}" if msg else ""
        s = f"Assert failed{s} ({file}:{line.__repr__()})"
        return AssertionError(s)

    def seq_assert_test(file: str, line: int, msg: str):
        """
        Print a "TEST FAILED" line for `file`/`line` through the C function
        `seq_print`; a non-empty `msg` is appended after a colon. Nothing is
        raised.
        """
        from C import seq_print(str)
        s = f": {msg}" if msg else ""
        s = f"\033[1;31mTEST FAILED:\033[0m {file} (line {line}){s}\n"
        seq_print(s)

    def check_errno(prefix: str):
        """
        Call the C function `seq_check_errno` and, if it returns a non-empty
        message, raise `OSError(prefix + msg)`.
        """
        @pure
        @C
        def seq_check_errno() -> str:
            pass

        msg = seq_check_errno()
        if msg:
            raise OSError(prefix + msg)

    # Empty Optional[T] in the (flag, value) representation: flag false, value
    # undefined.
    @pure
    @llvm
    def opt_tuple_new(T: type) -> Optional[T]:
        ret { i1, {=T} } { i1 false, {=T} undef }

    # Empty Optional[T] in the pointer representation: a null pointer.
    @pure
    @llvm
    def opt_ref_new(T: type) -> Optional[T]:
        ret ptr null

    # Optional[T] holding `what` in the (flag, value) representation: flag true.
    @pure
    @derives
    @llvm
    def opt_tuple_new_arg(what: T, T: type) -> Optional[T]:
        %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1
        ret { i1, {=T} } %0

    # Optional[T] holding `what` in the pointer representation: the pointer itself.
    @pure
    @derives
    @llvm
    def opt_ref_new_arg(what: T, T: type) -> Optional[T]:
        ret ptr %what

    # Truth value of a (flag, value) optional: the flag, widened to i8.
    @pure
    @llvm
    def opt_tuple_bool(what: Optional[T], T: type) -> bool:
        %0 = extractvalue { i1, {=T} } %what, 0
        %1 = zext i1 %0 to i8
        ret i8 %1

    # Truth value of a pointer optional: whether the pointer is non-null.
    @pure
    @llvm
    def opt_ref_bool(what: Optional[T], T: type) -> bool:
        %0 = icmp ne ptr %what, null
        %1 = zext i1 %0 to i8
        ret i8 %1

    # Extracts the value field of a (flag, value) optional without looking at
    # the flag.
    @pure
    @derives
    @llvm
    def opt_tuple_invert(what: Optional[T], T: type) -> T:
        %0 = extractvalue { i1, {=T} } %what, 1
        ret {=T} %0

    # Returns the pointer of a pointer optional as a T, without a null check.
    @pure
    @derives
    @llvm
    def opt_ref_invert(what: Optional[T], T: type) -> T:
        ret ptr %what

    # Reinterprets the raw pointer `p` as a reference of class type T.
    @pure
    @derives
    @llvm
    def to_class_ptr(p: Ptr[byte], T: type) -> T:
        ret ptr %p

    def _tuple_offsetof(x, field: Static[int]) -> int:
        """
        Return the byte offset of field number `field` inside the type of `x`.
        """
        # Takes the address of a stack slot of type T and of its field `idx`,
        # and returns the difference of the two addresses.
        @pure
        @llvm
        def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
            %a = alloca {=T}
            %b = getelementptr inbounds {=T}, ptr %a, i64 0, i32 {=idx}
            %base = ptrtoint ptr %a to i64
            %elem = ptrtoint ptr %b to i64
            %offset = sub i64 %elem, %base
            ret i64 %offset

        return _llvm_offsetof(type(x), field, type(x[field]))

    def raw_type_str(p: Ptr[byte], name: str) -> str:
        """
        Build the string `<NAME at ADDR>`, where NAME is `name` and ADDR is
        `p.__repr__()`.

        The result is assembled byte by byte in a buffer of exactly the needed
        size; the temporary ADDR string's buffer is freed once it has been
        copied.
        """
        pstr = p.__repr__()
        # '<[name] at [pstr]>'
        total = 1 + name.len + 4 + pstr.len + 1
        buf = Ptr[byte](total)
        where = 0
        buf[where] = byte(60)  # '<'
        where += 1
        str.memcpy(buf + where, name.ptr, name.len)
        where += name.len
        buf[where] = byte(32)  # ' '
        where += 1
        buf[where] = byte(97)  # 'a'
        where += 1
        buf[where] = byte(116)  # 't'
        where += 1
        buf[where] = byte(32)  # ' '
        where += 1
        str.memcpy(buf + where, pstr.ptr, pstr.len)
        where += pstr.len
        buf[where] = byte(62)  # '>'
        free(pstr.ptr)
        return str(buf, total)

    def tuple_str(strs: Ptr[str], names: Ptr[str], n: int) -> str:
        """
        Format `n` items as `(a, b, ...)`.

        `strs[i]` is the text of item `i`. When `names[i]` is non-empty the
        item is written as `name: text`. Items are separated by ", ".

        A first pass computes the exact output length; a second pass fills a
        buffer of that size.
        """
        total = 2  # one for each of '(' and ')'
        i = 0
        while i < n:
            total += strs[i].len
            if names[i].len:
                total += names[i].len + 2  # extra : and space
            if i < n - 1:
                total += 2  # ", "
            i += 1
        buf = Ptr[byte](total)
        where = 0
        buf[where] = byte(40)  # '('
        where += 1
        i = 0
        while i < n:
            s = names[i]
            l = s.len
            if l:
                str.memcpy(buf + where, s.ptr, l)
                where += l
                buf[where] = byte(58)  # ':'
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            s = strs[i]
            l = s.len
            str.memcpy(buf + where, s.ptr, l)
            where += l
            if i < n - 1:
                buf[where] = byte(44)  # ','
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            i += 1
        buf[where] = byte(41)  # ')'
        return str(buf, total)

    def undef(v, s):
        """Raise NameError naming variable `s` when the flag `v` is falsy."""
        if not v:
            raise NameError(f"variable '{s}' not yet defined")

    # Records the location (`func`, `file`, `line`, `col`) on exception `e` and
    # returns it. Passing something that is not a BaseException is a compile
    # error.
    @__hidden__
    def set_header(e, func, file, line, col):
        if not isinstance(e, BaseException):
            compile_error("exceptions must derive from BaseException")

        e.func = func
        e.file = file
        e.line = line
        e.col = col
        return e

    # TODO: keep the methods below commented out until static `for` lands
    # @pure
    # @llvm
    # def __raw__(self) -> Ptr[byte]:
    #     ret i8* %self
    # def __iter__(self):
    #     for _, i in self.__static_list__:
    #         yield i
    # def __contains__(self, what):
    #     for _, i in self.__static_list__:
    #         if isinstance(what, type(i)):
    #             if what == i:
    #                 return True
    #     return False
    # def __eq__(self, obj) -> bool:
    #     for ii, i in self.__static_list__:
    #         if not i == obj.__static_list__[ii]:
    #             return False
    #     return True
    # def __ne__(self, obj) -> bool:
    #     return not self == obj
    # def __lt__(self, obj) -> bool:
    #     for ii, i in self.__static_list__:
    #         if i < obj.__static_list__[ii]:
    #             return True
    #         elif not i == obj.__static_list__[ii]:
    #             return False
    #     return False
    # def __le__(self, obj) -> bool:
    #     return self < obj or self == obj
    # def __ge__(self, obj) -> bool:
    #     return not self < obj
    # def __gt__(self, obj) -> bool:
    #     return not self < obj and not self == obj
    # def __hash__(self) -> int:
    #     seed = 0
    #     for _, i in self.__static_list__:
    #         seed = seed ^ ((i.__hash__() + 2654435769) + ((seed << 6) + (seed >> 2)))
    #     return seed
    # def __pickle__(self, dest: Ptr[byte]) -> None:
    #     for _, i in self.__static_list__:
    #         i.__pickle__(dest)
    # # __unpickle
    # def __len__(self):
    #     return staticlen(self)
    # def __to_py__(self) -> cobj:
    #     o = pyobj._tuple_new(staticlen(self))
    #     for ii, i in self.__static_list__:
    #         pyobj._tuple_set(o, ii + 1, i.__to_py__())
    #     return o
    # # __from_py__
    # def __repr__(self) -> str:
    #     a = __array__[str](staticlen(self))
    #     n = __array__[str](staticlen(self))
    #     for ii, i in self.__static_list__:
    #         a.__setitem__(ii, i.__repr__())
    #         n.__setitem__(ii, i.__static_name__)
    #     return __internal__.tuple_str(a.ptr, n.ptr, staticlen(self))
    # def __dict__(self) -> str:
    #     d = List[str](staticlen(self))
    #     for _, i in self.__static_list__:
    #         d.append(i.__static_name)
    #     return d
    # def __add__(self, obj):
    #     return (*self, *obj)

# Record describing an imported module: its name and the file it came from.
@dataclass(init=True)
@tuple
class Import:
    name: str
    file: str

    def __repr__(self) -> str:
        # Rendered as: <module 'NAME' from 'FILE'>
        module_name = self.name
        module_file = self.file
        return f"<module '{module_name}' from '{module_file}'>"

# Extensions to the built-in `Function` type: construction from a raw pointer
# or another function value, conversion back to a raw pointer, a printable
# representation, and the call operator.
@extend
class Function:
    # Overload 1: treat the raw pointer `what` as a function value (the pointer
    # is returned unchanged).
    @pure
    @llvm
    def __new__(what: Ptr[byte]) -> Function[T, TR]:
        ret ptr %what

    # Overload 2: constructing from an existing function value returns it as is.
    def __new__(what: Function[T, TR]) -> Function[T, TR]:
       return what

    # Returns the function's pointer unchanged, typed as Ptr[byte].
    @pure
    @llvm
    def __raw__(self) -> Ptr[byte]:
        ret ptr %self

    # Formats the function through `__internal__.raw_type_str` with the label
    # "function" and the function's raw pointer.
    def __repr__(self) -> str:
       return __internal__.raw_type_str(self.__raw__(), "function")

    # Performs the actual call with the arguments packed in the tuple `args`.
    # The body below is a stub; per its inline note the compiler fills it in.
    @llvm
    def __call_internal__(self: Function[T, TR], args: T) -> TR:
        noop  # compiler will populate this one

    # Call operator: collects the positional arguments into a tuple and
    # forwards them to `__call_internal__`.
    def __call__(self, *args) -> TR:
        return Function.__call_internal__(self, args)

# Build the table of vtables once, when this module is initialized.
__vtables__ = __internal__.class_init_vtables()
# Dummy function whose only job is to reference `__vtables__`; per the inline
# note this reference is what makes the variable global.
def _____(): __vtables__  # make it global!


# Two-field record: a reference count followed by a raw type pointer.
# NOTE(review): the field order looks like it mirrors the CPython object header
# (ob_refcnt, ob_type) — confirm before reordering or adding fields.
@tuple
class PyObject:
    refcnt: int  # reference count
    pytype: Ptr[byte]  # untyped pointer to the object's type


# Generic record that places a `PyObject` header in front of a payload of
# type T.
@tuple
class PyWrapper[T]:
    head: PyObject  # header fields come first
    data: T  # the wrapped value