# (c) 2022 Exaloop Inc. All rights reserved.

from internal.gc import free


# Declaration of the external C function `seq_check_errno` (bound through @C
# and marked @pure). It takes no arguments and returns a `str`; the `pass`
# body is only a placeholder since the implementation lives outside this file.
# `__internal__.check_errno` in this file treats a non-empty result as an
# error message.
@pure
@C
def seq_check_errno() -> str:
    pass


from C import seq_print(str)
from C import exit(int)


# Extends the `__internal__` class with low-level helpers: raw pointer and
# function casts, integer width conversions, Optional construction and
# inspection, tuple indexing, assertion/errno reporting, and manual string
# building for reprs.
#
# Methods decorated with @llvm have bodies written directly in LLVM IR. Inside
# those bodies `{=X}` stands for the type (or static integer) bound to the
# parameter named X. Do not place comments or docstrings inside such bodies.
@extend
class __internal__:
    # Returns `obj` unchanged, typed as a raw byte pointer (the IR simply
    # returns the incoming value as an i8*).
    @pure
    @llvm
    def class_raw(obj) -> Ptr[byte]:
        ret i8* %obj

    # Unchecked tuple element load: stores `t` into a stack slot, reinterprets
    # that slot as a pointer to E, and loads the E located `idx` elements in.
    # No bounds checking is done here.
    # NOTE(review): this appears to assume that T is laid out as consecutive
    # values of type E (i.e. a homogeneous tuple) -- confirm against callers.
    @pure
    @llvm
    def _tuple_getitem_llvm(t: T, idx: int, T: type, E: type) -> E:
        %x = alloca {=T}
        store {=T} %t, {=T}* %x
        %p0 = bitcast {=T}* %x to {=E}*
        %p1 = getelementptr {=E}, {=E}* %p0, i64 %idx
        %v  = load {=E}, {=E}* %p1
        ret {=E} %v

    # Normalizes an index against a sequence length.
    #   idx: requested index; a negative value is offset by `len` once.
    #   len: number of elements.
    # Returns the index in the range [0, len).
    # Raises IndexError when the (adjusted) index is still outside that range.
    def tuple_fix_index(idx: int, len: int) -> int:
        if idx < 0:
            idx += len
        if idx < 0 or idx >= len:
            # Note: `idx` has already been adjusted above, so for a negative
            # input the message shows the shifted value, not the original one.
            raise IndexError("tuple index " + str(idx) + " out of range 0.." + str(len))
        return idx

    # Bounds-checked tuple element access: validates/normalizes `idx` against
    # the static length of `t` via tuple_fix_index (which may raise
    # IndexError), then performs the raw load.
    def tuple_getitem(t: T, idx: int, T: type, E: type) -> E:
        return __internal__._tuple_getitem_llvm(
            t, __internal__.tuple_fix_index(idx, staticlen(t)), T, E
        )

    # Reinterprets a raw byte pointer as a value of type T (plain bitcast).
    @pure
    @llvm
    def fn_new(ptr: Ptr[byte], T: type) -> T:
        %0 = bitcast i8* %ptr to {=T}
        ret {=T} %0

    # Inverse of fn_new: reinterprets a value of type T as a raw byte pointer
    # (plain bitcast).
    @pure
    @llvm
    def fn_raw(fn: T, T: type) -> Ptr[byte]:
        %0 = bitcast {=T} %fn to i8*
        ret i8* %0

    # Sign-extends `what` from an F-bit integer to a T-bit integer.
    @pure
    @llvm
    def int_sext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = sext i{=F} %what to i{=T}
        ret i{=T} %0

    # Zero-extends `what` from an F-bit integer to a T-bit integer.
    @pure
    @llvm
    def int_zext(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = zext i{=F} %what to i{=T}
        ret i{=T} %0

    # Truncates `what` from an F-bit integer to a T-bit integer.
    @pure
    @llvm
    def int_trunc(what, F: Static[int], T: Static[int]) -> Int[T]:
        %0 = trunc i{=F} %what to i{=T}
        ret i{=T} %0

    # Builds (but does not raise) the AssertionError for a failed assert.
    # The message has the form "Assert failed[: msg] (file:line)"; the
    # ": msg" part is included only when `msg` is non-empty.
    def seq_assert(file: str, line: int, msg: str) -> AssertionError:
        s = "Assert failed"
        if msg:
            s += ": " + msg
        s += " (" + file + ":" + line.__repr__() + ")"
        return AssertionError(s)

    # Reports a failed test assertion by printing one line through seq_print:
    # a "TEST FAILED:" tag wrapped in ANSI escape sequences, followed by the
    # file, the line number and, when non-empty, `msg`. Nothing is raised and
    # nothing is returned.
    def seq_assert_test(file: str, line: int, msg: str) -> void:
        s = "\033[1;31mTEST FAILED:\033[0m " + file + " (line " + str(line) + ")"
        if msg:
            s += ": " + msg
        seq_print(s + "\n")

    # Calls seq_check_errno() and, if it returns a non-empty string, raises
    # OSError with `prefix` prepended to that string. Otherwise does nothing.
    def check_errno(prefix: str) -> void:
        msg = seq_check_errno()
        if msg:
            raise OSError(prefix + msg)

    # Empty Optional in the { i1, T } representation: the flag is false and
    # the payload is left undefined.
    @pure
    @llvm
    def opt_tuple_new(T: type) -> Optional[T]:
        ret { i1, {=T} } { i1 false, {=T} undef }

    # Empty Optional in the pointer representation: a null i8*.
    @pure
    @llvm
    def opt_ref_new(T: type) -> Optional[T]:
        ret i8* null

    # Optional holding `what` in the { i1, T } representation: the flag is
    # true and `what` is inserted as field 1.
    @pure
    @llvm
    def opt_tuple_new_arg(what: T, T: type) -> Optional[T]:
        %0 = insertvalue { i1, {=T} } { i1 true, {=T} undef }, {=T} %what, 1
        ret { i1, {=T} } %0

    # Optional holding `what` in the pointer representation: `what` itself is
    # returned as an i8*.
    @pure
    @llvm
    def opt_ref_new_arg(what: T, T: type) -> Optional[T]:
        ret i8* %what

    # Truth value of a { i1, T } Optional: extracts the flag (field 0) and
    # widens it to the i8 used as the bool return value.
    @pure
    @llvm
    def opt_tuple_bool(what: Optional[T], T: type) -> bool:
        %0 = extractvalue { i1, {=T} } %what, 0
        %1 = zext i1 %0 to i8
        ret i8 %1

    # Truth value of a pointer Optional: true exactly when the pointer is not
    # null.
    @pure
    @llvm
    def opt_ref_bool(what: Optional[T], T: type) -> bool:
        %0 = icmp ne i8* %what, null
        %1 = zext i1 %0 to i8
        ret i8 %1

    # Unwraps a { i1, T } Optional by extracting the payload (field 1). The
    # flag is not inspected, so the result for an empty Optional is whatever
    # the payload field holds.
    @pure
    @llvm
    def opt_tuple_invert(what: Optional[T], T: type) -> T:
        %0 = extractvalue { i1, {=T} } %what, 1
        ret {=T} %0

    # Unwraps a pointer Optional by returning the pointer as-is; no null
    # check is performed.
    @pure
    @llvm
    def opt_ref_invert(what: Optional[T], T: type) -> T:
        ret i8* %what

    # Reinterprets a raw byte pointer as a value of type T (plain bitcast;
    # the IR is the same as fn_new's).
    @pure
    @llvm
    def to_class_ptr(ptr: Ptr[byte], T: type) -> T:
        %0 = bitcast i8* %ptr to {=T}
        ret {=T} %0

    # Returns the offset of field number `field` within the type of `x`.
    # The nested LLVM helper allocates a stack slot of that type, takes the
    # address of the requested field with getelementptr, converts both the
    # slot address and the field address to integers, and returns their
    # difference. Only the static type of `x` and of `x[field]` are passed to
    # the helper; the value of `x` is not.
    @pure
    def _tuple_offsetof(x, field: Static[int]) -> int:
        @llvm
        def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
            %a = alloca {=T}
            %b = getelementptr inbounds {=T}, {=T}* %a, i64 0, i32 {=idx}
            %base = ptrtoint {=T}* %a to i64
            %elem = ptrtoint {=TE}* %b to i64
            %offset = sub i64 %elem, %base
            ret i64 %offset

        return _llvm_offsetof(type(x), field, type(x[field]))

    # Builds the string "<NAME at PTR>", where NAME is `name` and PTR is
    # `p.__repr__()`, by writing bytes directly into a buffer of exactly the
    # required size.
    def raw_type_str(p: Ptr[byte], name: str) -> str:
        pstr = p.__repr__()
        # '<[name] at [pstr]>'
        # Size: '<' (1) + name + " at " (4) + pstr + '>' (1).
        total = 1 + name.len + 4 + pstr.len + 1
        buf = Ptr[byte](total)
        # `where` is the write cursor into `buf`.
        where = 0
        buf[where] = byte(60)  # '<'
        where += 1
        str.memcpy(buf + where, name.ptr, name.len)
        where += name.len
        buf[where] = byte(32)  # ' '
        where += 1
        buf[where] = byte(97)  # 'a'
        where += 1
        buf[where] = byte(116)  # 't'
        where += 1
        buf[where] = byte(32)  # ' '
        where += 1
        str.memcpy(buf + where, pstr.ptr, pstr.len)
        where += pstr.len
        buf[where] = byte(62)  # '>'
        # The repr bytes have been copied into `buf`, so its storage is
        # released here.
        # NOTE(review): assumes p.__repr__() returns a buffer nothing else
        # references -- confirm.
        free(pstr.ptr)
        return str(buf, total)

    # Builds the string "(e0, e1, ...)" for a tuple-like value.
    #   strs:  pointer to `n` already-formatted element strings.
    #   names: pointer to `n` element names; a non-empty name is emitted as
    #          "name: " in front of its element, an empty one adds nothing.
    #   n:     number of elements.
    # Works in two passes: the first computes the exact output length, the
    # second writes the bytes into a buffer of that length.
    def tuple_str(strs: Ptr[str], names: Ptr[str], n: int) -> str:
        # Pass 1: compute the total number of bytes needed.
        total = 2  # one for each of '(' and ')'
        i = 0
        while i < n:
            total += strs[i].len
            if names[i].len:
                total += names[i].len + 2  # extra : and space
            if i < n - 1:
                total += 2  # ", "
            i += 1
        # Pass 2: write the pieces; `where` is the write cursor into `buf`.
        buf = Ptr[byte](total)
        where = 0
        buf[where] = byte(40)  # '('
        where += 1
        i = 0
        while i < n:
            # Optional "name: " prefix.
            s = names[i]
            l = s.len
            if l:
                str.memcpy(buf + where, s.ptr, l)
                where += l
                buf[where] = byte(58)  # ':'
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            # The element's own text.
            s = strs[i]
            l = s.len
            str.memcpy(buf + where, s.ptr, l)
            where += l
            # Separator after every element except the last.
            if i < n - 1:
                buf[where] = byte(44)  # ','
                where += 1
                buf[where] = byte(32)  # ' '
                where += 1
            i += 1
        buf[where] = byte(41)  # ')'
        return str(buf, total)


# Record pairing a module name with a file; instances render as
# "<module 'NAME' from 'FILE'>".
@dataclass(init=True)
class Import:
    # Module name, shown first in the repr.
    name: str
    # File, shown after "from" in the repr.
    file: str

    def __repr__(self) -> str:
        # Format once into a local, then hand it back.
        rendered = f"<module '{self.name}' from '{self.file}'>"
        return rendered