from internal.gc import alloc_atomic, free

# Minimal `object` type for this module. It defines only __repr__.
# NOTE(review): the @tuple decorator presumably makes this a tuple-style
# (value) type -- confirm against the compiler documentation.
@tuple
class object:
    def __repr__(self):
        """
        Return the fixed placeholder string '<object>'.
        """
        return '<object>'

def id(x):
    """
    Return an integer identity for x.

    For ByRef instances this is the integer value of x.__raw__();
    for every other type the result is 0.
    """
    # Keep the if/else form: the non-ByRef branch must not reference
    # x.__raw__().
    if isinstance(x, ByRef):
        raw = x.__raw__()
        return int(raw)
    else:
        return 0

# Stream handle returned by _C.seq_stdout(); used as print()'s default `file`.
_stdout = _C.seq_stdout()
def print(*args, sep: str = ' ', end: str = '\n', file=_stdout, flush: bool = False):
    """
    Write the string form of every positional argument to file.

    Consecutive items are separated by sep (nothing is written between
    items when sep is empty) and end is written after the last item.
    file is either a Ptr[byte], used directly, or an object whose fp
    attribute supplies the handle. When flush is true the handle is
    flushed with _C.fflush after writing.
    """
    fp = cobj()
    if isinstance(file, Ptr[byte]):
        fp = file
    else:
        fp = file.fp
    # Becomes true once the first item has been written, so that the
    # separator only appears between items.
    need_sep = False
    for item in args:
        if need_sep and sep:
            _C.seq_print_full(sep, fp)
        _C.seq_print_full(str(item), fp)
        need_sep = True
    _C.seq_print_full(end, fp)
    if flush:
        _C.fflush(fp)

def min(*args):
    """
    Return the smallest of the given values.

    - No arguments: raises ValueError("empty sequence").
    - One argument that has __iter__: returns the smallest item it
      produces; raises ValueError("empty sequence") if it produces none.
    - Two arguments: returns the first unless the second is strictly
      smaller (ties keep the first).
    - Otherwise: scans all arguments and returns the first smallest one.
    """
    if staticlen(args) == 0:
        raise ValueError("empty sequence")
    elif staticlen(args) == 1 and hasattr(args[0], "__iter__"):
        it = args[0].__iter__()
        if it.done():
            # Nothing to compare: release the iterator before failing.
            it.destroy()
            raise ValueError("empty sequence")
        best = it.next()
        while not it.done():
            candidate = it.next()
            if candidate < best:
                best = candidate
        it.destroy()
        return best
    elif staticlen(args) == 2:
        lhs, rhs = args
        return lhs if lhs <= rhs else rhs
    else:
        best = args[0]
        for candidate in args:
            if candidate < best:
                best = candidate
        return best

def max(*args):
    """
    Return the largest of the given values.

    - No arguments: raises ValueError("empty sequence").
    - One argument that has __iter__: returns the largest item it
      produces; raises ValueError("empty sequence") if it produces none.
    - Two arguments: returns the first unless the second is strictly
      larger (ties keep the first).
    - Otherwise: scans all arguments and returns the first largest one.
    """
    if staticlen(args) == 0:
        raise ValueError("empty sequence")
    elif staticlen(args) == 1 and hasattr(args[0], "__iter__"):
        it = args[0].__iter__()
        if it.done():
            # Nothing to compare: release the iterator before failing.
            it.destroy()
            raise ValueError("empty sequence")
        best = it.next()
        while not it.done():
            candidate = it.next()
            if candidate > best:
                best = candidate
        it.destroy()
        return best
    elif staticlen(args) == 2:
        lhs, rhs = args
        return lhs if lhs >= rhs else rhs
    else:
        best = args[0]
        for candidate in args:
            if candidate > best:
                best = candidate
        return best

def len(x):
    """
    Return the result of x.__len__()
    """
    size = x.__len__()
    return size

def iter(x):
    """
    Return the iterator produced by x.__iter__()
    """
    it = x.__iter__()
    return it

def copy(x):
    """
    Return the result of x.__copy__()
    """
    duplicate = x.__copy__()
    return duplicate

def abs(x):
    """
    Return the result of x.__abs__()
    """
    magnitude = x.__abs__()
    return magnitude

def hash(x):
    """
    Return the result of x.__hash__()
    """
    value = x.__hash__()
    return value

def ord(s: str):
    """
    Return the integer value of the single character in s.

    The value is read from the first element of s.ptr. Raises
    TypeError when the length of s is not exactly 1.
    """
    n = len(s)
    if n != 1:
        raise TypeError("ord() expected a character, but string of length " + str(n) + " found")
    return int(s.ptr[0])

def divmod(a, b):
    """
    Return the quotient and remainder of a divided by b.

    Delegates to a.__divmod__(b) when a provides it; otherwise the
    result is the tuple (a // b, a % b).
    """
    if hasattr(a, "__divmod__"):
        return a.__divmod__(b)
    else:
        quotient = a // b
        remainder = a % b
        return (quotient, remainder)

def chr(i: int):
    """
    Return a one-character string whose single element is byte(i)
    """
    # One-element buffer that backs the returned string.
    buf = cobj(1)
    buf[0] = byte(i)
    return str(buf, 1)

def next[T](g: Generator[T], default: Optional[T] = None):
    """
    Return the next item from g
    """
    if g.done():
        if default:
            return ~default
        else:
            raise StopIteration()
    return g.next()

def any(x):
    """
    Return True as soon as an item of x tests true;
    return False if none does (including when x is empty)
    """
    for item in x:
        if not item:
            continue
        return True
    return False

def all(x):
    """
    Return False as soon as an item of x tests false;
    return True if none does (including when x is empty)
    """
    for item in x:
        if item:
            continue
        return False
    return True

def zip(*args):
    """
    Yield tuples holding one item from each of the given iterables.

    Before every tuple each iterator is asked whether it is done;
    iteration ends as soon as any of them is. All iterators are
    destroyed once the loop finishes. With no arguments nothing is
    yielded.
    """
    if staticlen(args) == 0:
        yield from List[int]()
    else:
        its = tuple(iter(a) for a in args)
        running = True
        while running:
            # Every iterator is checked, even after one reports done.
            for it in its:
                if it.done():
                    running = False
            if running:
                yield tuple(it.next() for it in its)
        for it in its:
            it.destroy()

def filter[T](f: Callable[[T], bool], x: Generator[T]):
    """
    Returns all a from the iterable x that are filtered by f
    """
    for a in x:
        if f(a):
            yield a

def map(f, *args):
    """
    Yield the result of calling f on successive items.

    With one iterable, f receives each item in turn. With several,
    the iterables are combined with zip and f receives one item from
    each as separate arguments. Calling map with no iterable is a
    compile error.
    """
    if staticlen(args) == 0:
        compile_error("map() expects at least one iterator")
    elif staticlen(args) == 1:
        for item in args[0]:
            yield f(item)
    else:
        for group in zip(*args):
            yield f(*group)

def enumerate(x, start: int = 0):
    """
    Yield (index, item) pairs for the items of x, where index
    begins at start (0 by default) and grows by one per item
    """
    idx = start
    for item in x:
        yield (idx, item)
        idx += 1

def echo(x):
    """
    Print x with the print statement, then return x unchanged
    so the call can be used inside a larger expression
    """
    print x
    return x

def reversed(x):
    """
    Iterate over x from its last element to its first.

    Uses x.__reversed__() when x provides it; otherwise indexes x
    from x.__len__() - 1 down to 0.
    """
    if hasattr(x, "__reversed__"):
        return x.__reversed__()
    else:
        # The length is read once, up front; walk the indices downwards.
        remaining = x.__len__()
        while remaining > 0:
            remaining -= 1
            yield x[remaining]

def round(x, n=0) -> float:
    """
    Return x rounded to n digits after the decimal point.

    n defaults to 0 so that round(x) rounds to the nearest whole
    value, as Python's built-in does. The result is always a float.
    The computation scales x by 10**n, rounds with float.__round__
    and scales back.
    """
    scale = float.__pow__(10.0, n)
    return float.__round__(x * scale) / scale

def sum(xi):
    """
    Return the items of xi added together.

    The first item seeds the total and every later item is added to
    it with +=. The iterator is destroyed before returning. For an
    empty xi no explicit value is returned.
    """
    it = iter(xi)
    if it.done():
        it.destroy()
    else:
        total = it.next()
        while not it.done():
            total += it.next()
        it.destroy()
        return total

def repr(x):
    """
    Return the result of x.__repr__()
    """
    text = x.__repr__()
    return text

def _int_format(a: int, base: int, prefix: str = ''):
    """
    Format the integer a in the given base as a string.

    The output layout is: an optional '-' sign (when a < 0), then
    prefix, then the digits of a in lowercase. Zero is rendered as
    the single digit '0' after the prefix.

    base must be 2, 8, 10 or 16 (enforced with an assert).
    """
    assert base == 2 or base == 8 or base == 10 or base == 16
    # Digit alphabet; the last character is the minus sign.
    chars = '0123456789abcdef-'

    # First pass: count how many digits a needs in this base.
    # NOTE(review): for negative a this loop only terminates if //
    # rounds toward zero -- confirm integer division semantics.
    b = a
    digits = 0
    while b != 0:
        digits += 1
        b //= base

    # One extra byte for the sign (a < 0) or for the lone '0' (a == 0).
    sz = digits + (1 if a <= 0 else 0) + len(prefix)
    p = Ptr[byte](sz)
    q = p

    if a < 0:
        q[0] = chars[-1].ptr[0]
        q += 1

    if prefix:
        str.memcpy(q, prefix.ptr, len(prefix))
        q += len(prefix)

    if digits != 0:
        # Second pass: write digits from least to most significant,
        # moving the cursor backwards from the last digit slot.
        b = a
        q += digits - 1
        while b != 0:
            # abs() maps a negative remainder to its digit index.
            q[0] = chars.ptr[abs(b % base)]
            q += -1
            b //= base
    else:
        q[0] = chars.ptr[0]

    return str(p, sz)

def bin(n):
    """
    Return the base-2 representation of n.__index__() prefixed with '0b'
    """
    value = n.__index__()
    return _int_format(value, 2, '0b')

def oct(n):
    """
    Return the base-8 representation of n.__index__() prefixed with '0o'
    """
    value = n.__index__()
    return _int_format(value, 8, '0o')

def hex(n):
    """
    Return the base-16 representation of n.__index__() prefixed with '0x'
    """
    value = n.__index__()
    return _int_format(value, 16, '0x')

@extend
class int:
    def _from_str(s: str, base: int):
        from C import strtoll(cobj, Ptr[cobj], i32) -> int
        if base < 0 or base > 36 or base == 1:
            raise ValueError("int() base must be >= 2 and <= 36, or 0")

        buf = __array__[byte](32)
        n = len(s)
        need_dyn_alloc = (n >= len(buf))

        p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
        str.memcpy(p, s.ptr, n)
        p[n] = byte(0)

        end = cobj()
        result = strtoll(p, __ptr__(end), i32(base))

        if need_dyn_alloc:
            free(p)

        if end != p + n:
            raise ValueError("invalid literal for int() with base " + str(base) + ": " + s)

        return result