# (c) 2022 Exaloop Inc. All rights reserved.

from internal.gc import alloc_atomic, free
from internal.types.optional import unwrap


@tuple
class object:
    # Root `object` type, declared as a Codon @tuple class with no fields,
    # so instances carry no data of their own.
    def __repr__(self) -> str:
        # Every instance renders identically since there is no state to show.
        return "<object>"


def id(x) -> int:
    """
    Return an integer identity for x.

    For ByRef (reference) values this is int(x.__raw__()), presumably the
    object's address; every other value reports 0.
    """
    if isinstance(x, ByRef):
        # __raw__() exposes the underlying raw pointer, converted to an int.
        return int(x.__raw__())
    else:
        # NOTE(review): all non-ByRef values share id 0, so id() cannot
        # distinguish them — confirm this is acceptable for callers.
        return 0


# Default output stream for print(); fetched once from the runtime at module load.
_stdout = _C.seq_stdout()


def print(*args, sep: str = " ", end: str = "\n", file=_stdout, flush: bool = False):
    """
    Write str() of each argument to file, separated by sep and terminated by end.

    file is either a raw cobj stream or an object exposing a `.fp` stream.
    An empty sep writes nothing between arguments. When flush is true the
    stream is fflush'ed after end has been written.
    """
    out = cobj()
    if isinstance(file, cobj):
        out = file
    else:
        out = file.fp
    first = True
    for item in args:
        # The separator goes before every argument except the first.
        if not first and sep:
            _C.seq_print_full(sep, out)
        _C.seq_print_full(str(item), out)
        first = False
    _C.seq_print_full(end, out)
    if flush:
        _C.fflush(out)


def min(*args):
    """
    Return the smallest item.

    With a single iterable argument, return its smallest element, raising
    ValueError("empty sequence") when it yields nothing. With two or more
    arguments, return the smallest of them. Replacement happens only on a
    strict '<', so ties keep the earliest candidate.
    """
    if staticlen(args) == 0:
        raise ValueError("empty sequence")
    elif staticlen(args) == 1 and hasattr(args[0], "__iter__"):
        gen = args[0].__iter__()
        if gen.done():
            gen.destroy()
            raise ValueError("empty sequence")
        best = gen.next()
        while not gen.done():
            cand = gen.next()
            if cand < best:
                best = cand
        gen.destroy()
        return best
    elif staticlen(args) == 2:
        lhs, rhs = args
        return lhs if lhs <= rhs else rhs
    else:
        smallest = args[0]
        for item in args:
            if item < smallest:
                smallest = item
        return smallest


def max(*args):
    """
    Return the largest item.

    With a single iterable argument, return its largest element, raising
    ValueError("empty sequence") when it yields nothing. With two or more
    arguments, return the largest of them. Replacement happens only on a
    strict '>', so ties keep the earliest candidate.
    """
    if staticlen(args) == 0:
        raise ValueError("empty sequence")
    elif staticlen(args) == 1 and hasattr(args[0], "__iter__"):
        gen = args[0].__iter__()
        if gen.done():
            gen.destroy()
            raise ValueError("empty sequence")
        best = gen.next()
        while not gen.done():
            cand = gen.next()
            if cand > best:
                best = cand
        gen.destroy()
        return best
    elif staticlen(args) == 2:
        lhs, rhs = args
        return lhs if lhs >= rhs else rhs
    else:
        largest = args[0]
        for item in args:
            if item > largest:
                largest = item
        return largest


def len(x) -> int:
    """Return the number of items in x, as reported by x.__len__()."""
    count = x.__len__()
    return count


def iter(x):
    """Obtain an iterator over x by calling x.__iter__()."""
    it = x.__iter__()
    return it


def copy(x):
    """Return the copy of x produced by its __copy__ method."""
    duplicate = x.__copy__()
    return duplicate


def abs(x):
    """Return the magnitude of x as computed by its __abs__ method."""
    magnitude = x.__abs__()
    return magnitude


def hash(x) -> int:
    """Return the hash value of x as computed by its __hash__ method."""
    digest = x.__hash__()
    return digest


def ord(s: str) -> int:
    """
    Return the integer value of the single byte in s.

    s must have length exactly 1 (len() counts bytes of the underlying
    buffer here); any other length raises TypeError.
    """
    n = len(s)
    if n == 1:
        return int(s.ptr[0])
    raise TypeError(
        f"ord() expected a character, but string of length {n} found"
    )


def divmod(a, b):
    """
    Return a.__divmod__(b) when a defines it; otherwise the pair (a // b, a % b).
    """
    if hasattr(a, "__divmod__"):
        return a.__divmod__(b)
    else:
        quotient = a // b
        remainder = a % b
        return (quotient, remainder)


def chr(i: int) -> str:
    """
    Return a one-character string whose single byte has value i.

    The result is built from a 1-byte buffer, so only values that fit in a
    byte are representable.

    Raises:
        ValueError: if i is outside range(256). Previously such values were
        silently truncated by byte(i), e.g. chr(300) gave ','.
    """
    if i < 0 or i > 255:
        raise ValueError("chr() arg not in range(256)")
    p = cobj(1)
    p[0] = byte(i)
    return str(p, 1)


def next(g: Generator[T], default: Optional[T] = None, T: type) -> T:
    """
    Advance generator g and return its next item.

    If g is exhausted, return default when one was supplied; otherwise
    raise StopIteration.
    """
    if g.done():
        # Test for presence, not truthiness: a falsy default such as 0, ""
        # or False is still a value the caller asked us to return.
        if default is not None:
            return unwrap(default)
        raise StopIteration()
    return g.next()


def any(x: Generator[T], T: type) -> bool:
    """
    Return True as soon as some element of x is truthy, and False if none
    is (including when x is empty).
    """
    for item in x:
        if not item:
            continue
        return True
    return False


def all(x: Generator[T], T: type) -> bool:
    """
    Return False at the first falsy element of x, and True otherwise
    (including when x is empty).
    """
    for item in x:
        if item:
            continue
        return False
    return True


def zip(*args):
    """
    Yield tuples aggregating one element from each of the given iterables.

    Iteration stops as soon as any iterable is exhausted. With no
    arguments, nothing is yielded. All iterators are destroyed afterwards.
    """
    if staticlen(args) == 0:
        yield from List[int]()
    else:
        iters = tuple(iter(i) for i in args)
        done = False
        while not done:
            for it in iters:
                # done() is what advances each iterator (every next() below is
                # paired with exactly one done()). Once one iterator reports
                # exhaustion, stop probing the rest so they are not advanced an
                # extra step, which would run side effects CPython's zip never runs.
                if not done and it.done():
                    done = True
            if not done:
                yield tuple(it.next() for it in iters)
        for it in iters:
            it.destroy()


def filter(f: Callable[[T], bool], x: Generator[T], T: type) -> Generator[T]:
    """Lazily yield the elements of x for which predicate f returns True."""
    for item in x:
        keep = f(item)
        if keep:
            yield item


def map(f, *args):
    """
    Lazily apply f across the given iterables.

    With one iterable, yield f(item) for each item. With several, zip them
    together and yield f(*items) for each aggregated tuple. Supplying no
    iterable is a compile-time error.
    """
    if staticlen(args) == 0:
        compile_error("map() expects at least one iterator")
    elif staticlen(args) == 1:
        for item in args[0]:
            yield f(item)
    else:
        for items in zip(*args):
            yield f(*items)


def enumerate(x, start: int = 0):
    """Yield (index, item) pairs for the items of x, numbering from start."""
    idx = start - 1
    for item in x:
        idx += 1
        yield (idx, item)


def echo(x):
    """Print x, then hand it back unchanged (handy for inline debugging)."""
    print(x)
    return x


def reversed(x):
    """
    Return an iterator over x in reverse order.

    Delegates to x.__reversed__() when x defines it; otherwise yields
    x[len-1], x[len-2], ..., x[0] using x.__len__() and indexing.
    """
    if hasattr(x, "__reversed__"):
        # NOTE(review): this function also contains `yield`, so handing back
        # __reversed__()'s result via `return` relies on the compiler
        # discarding the statically-false branch — confirm before restructuring.
        return x.__reversed__()
    else:
        # Walk indices from the last element down to 0.
        i = x.__len__() - 1
        while i >= 0:
            yield x[i]
            i -= 1


def round(x, n=0):
    """
    Round x to n digits after the decimal point.

    Scales x by 10**n, rounds with float.__round__, then divides the scale
    back out. Negative n rounds to the left of the decimal point.
    """
    scale = float.__pow__(10.0, n)
    scaled = x * scale
    return float.__round__(scaled) / scale


def sum(xi):
    """
    Return the sum of the items of xi.

    The first item seeds the accumulator and the rest are added with +=,
    so no numeric start value is assumed (any type with += works).
    """
    x = iter(xi)
    if not x.done():
        s = x.next()
        while not x.done():
            s += x.next()
        x.destroy()
        return s
    else:
        # NOTE(review): an empty iterable reaches the end without an explicit
        # return value (CPython's sum returns 0 here) — confirm what callers
        # observe in this case.
        x.destroy()


def repr(x):
    """Return the printable representation of x given by its __repr__ method."""
    text = x.__repr__()
    return text


def _trunc_divmod(b: int, base: int):
    """
    Return (q, r) with q = b / base truncated toward zero and r = b - q*base.

    Gives the same answer whether int // and % round toward zero or toward
    negative infinity. Uses % directly rather than b - q*base, so INT_MIN
    does not overflow. Assumes base > 0.
    """
    q = b // base
    r = b % base
    if b < 0 and r > 0:
        # Floor semantics gave a non-negative remainder; shift to truncation.
        q += 1
        r -= base
    return q, r


def _int_format(a: int, base: int, prefix: str = ""):
    """
    Format integer a in the given base, e.g. (-10, 16, "0x") -> "-0xa".

    Layout: an optional '-' sign, then prefix, then the digits of |a|, most
    significant first. Zero is rendered as a single "0". Only bases 2, 8,
    10 and 16 are supported (asserted).
    """
    assert base == 2 or base == 8 or base == 10 or base == 16
    chars = "0123456789abcdef-"

    # Count the digits of |a|. The quotient is truncated toward zero: with
    # floor division a negative b would get stuck at -1 and never reach 0.
    digits = 0
    b = a
    while b != 0:
        digits += 1
        b = _trunc_divmod(b, base)[0]

    # One extra byte for either the '-' sign (a < 0) or the lone "0" (a == 0).
    sz = digits + (1 if a <= 0 else 0) + len(prefix)
    p = Ptr[byte](sz)
    cur = p

    if a < 0:
        cur[0] = chars[-1].ptr[0]
        cur += 1

    if prefix:
        str.memcpy(cur, prefix.ptr, len(prefix))
        cur += len(prefix)

    if digits != 0:
        # Fill digits right-to-left, starting at the last byte of the buffer.
        cur += digits - 1
        b = a
        while b != 0:
            b, r = _trunc_divmod(b, base)
            # r carries the sign of a, so take its magnitude as the digit index.
            cur[0] = chars.ptr[abs(r)]
            cur += -1
    else:
        cur[0] = chars.ptr[0]

    return str(p, sz)


def bin(n):
    """Return n.__index__() written in base 2 with a "0b" prefix."""
    value = n.__index__()
    return _int_format(value, 2, "0b")


def oct(n):
    """Return n.__index__() written in base 8 with a "0o" prefix."""
    value = n.__index__()
    return _int_format(value, 8, "0o")


def hex(n):
    """Return n.__index__() written in base 16 with a "0x" prefix."""
    value = n.__index__()
    return _int_format(value, 16, "0x")


@extend
class int:
    def _from_str(s: str, base: int):
        """
        Parse s as an integer in the given base using C's strtoll.

        base must be 0 or lie in 2..36; any other value raises ValueError.
        s must be non-empty and strtoll must consume all of it, otherwise
        ValueError("invalid literal for int() with base ...") is raised.
        """
        from C import strtoll(cobj, Ptr[cobj], i32) -> int

        if base < 0 or base > 36 or base == 1:
            raise ValueError("int() base must be >= 2 and <= 36, or 0")

        # strtoll needs a NUL-terminated copy of s: use the fixed 32-byte
        # buffer when s fits (with room for the terminator), else allocate.
        buf = __array__[byte](32)
        n = len(s)
        need_dyn_alloc = n >= len(buf)

        p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
        str.memcpy(p, s.ptr, n)
        p[n] = byte(0)

        end = cobj()
        result = strtoll(p, __ptr__(end), i32(base))

        # Decide validity while p is still live. An empty string must be
        # rejected explicitly: strtoll consumes nothing and leaves end == p,
        # which equals p + n when n == 0 and would otherwise parse as 0.
        consumed_all = n > 0 and end == p + n

        if need_dyn_alloc:
            free(p)

        if not consumed_all:
            raise ValueError(
                f"invalid literal for int() with base {base}: {s}"
            )

        return result


def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()):
    """
    Print x for the JIT front end, without a trailing newline.

    When s is "jupyter" and x defines _repr_mimebundle_, the first entry of
    the returned mapping is printed behind a NUL-delimited "__codon/mime__"
    header carrying its mime type. Otherwise x.__repr__() is printed, or
    x.__str__() if there is no __repr__. None-typed values print nothing.
    """
    # NOTE(review): bundle defaults to a Set[str]() expression; whether that
    # default is shared across calls depends on Codon's default-argument
    # semantics — confirm before anything mutates it.
    if isinstance(x, None):
        return
    if hasattr(x, "_repr_mimebundle_") and s == "jupyter":
        d = x._repr_mimebundle_(bundle)
        # TODO: pick appropriate mime
        # NOTE(review): next() raises StopIteration if d is empty.
        mime = next(d.keys()) # just pick first
        print(f"\x00\x00__codon/mime__\x00{mime}\x00{d[mime]}", end='')
    elif hasattr(x, "__repr__"):
        print(x.__repr__(), end='')
    elif hasattr(x, "__str__"):
        print(x.__str__(), end='')