# (c) 2022 Exaloop Inc. All rights reserved.

from sys import stderr


def time() -> float:
    """
    Return the runtime clock reading from ``_C.seq_time()`` as a float,
    scaled down by 1e9 (presumably nanoseconds -> seconds, cf. ``time_ns``).
    """
    raw_ticks = _C.seq_time()
    return raw_ticks / 1e9


def time_ns() -> int:
    """
    Return the unscaled integer reading of ``_C.seq_time()``
    (nanoseconds, going by this function's name).
    """
    raw_ticks = _C.seq_time()
    return raw_ticks


def monotonic() -> float:
    """
    Return the reading of ``_C.seq_time_monotonic()`` as a float,
    scaled down by 1e9 (presumably nanoseconds -> seconds).
    """
    raw_ticks = _C.seq_time_monotonic()
    return raw_ticks / 1e9


def monotonic_ns() -> int:
    """
    Return the unscaled integer reading of ``_C.seq_time_monotonic()``
    (nanoseconds, going by this function's name).
    """
    raw_ticks = _C.seq_time_monotonic()
    return raw_ticks


def perf_counter() -> float:
    """
    Return the reading of ``_C.seq_time_highres()`` as a float,
    scaled down by 1e9 (presumably nanoseconds -> seconds).
    """
    raw_ticks = _C.seq_time_highres()
    return raw_ticks / 1e9


def perf_counter_ns() -> int:
    """
    Return the unscaled integer reading of ``_C.seq_time_highres()``
    (nanoseconds, going by this function's name).
    """
    raw_ticks = _C.seq_time_highres()
    return raw_ticks


def sleep(secs: float) -> void:
    """
    Suspend execution by handing ``secs`` to ``_C.seq_sleep``.

    Raises:
        ValueError: if ``secs`` is negative.
    """
    is_negative = secs < 0
    if is_negative:
        raise ValueError("sleep length must be non-negative")
    _C.seq_sleep(secs)


class TimeInterval:
    """
    Utility class for timing Seq code.

    Can be used as a context manager: the clock is restarted on entry and a
    report is printed to stderr on exit. ``tick`` prints an intermediate
    report and restarts the clock.
    """

    start: int  # reading of _C.seq_time() at the start of the current interval
    msg: str  # default label for reports; empty means the generic "Block"

    def __init__(self) -> void:
        """Start the clock immediately with an empty label."""
        self.start = _C.seq_time()
        self.msg = ""

    def __enter__(self) -> void:
        """Restart the clock when the ``with`` block is entered."""
        self.start = _C.seq_time()

    def __exit__(self) -> void:
        """Print the report for the finished block to stderr."""
        print(self.report(self.msg), file=stderr)

    def report(self, msg="", memory=False) -> str:
        """
        Build the report line ``"<label> took <elapsed>s"``.

        The label is ``msg`` when it is non-empty, otherwise the interval's
        own ``self.msg``, otherwise the generic ``"Block"``. (Previously the
        ``msg`` argument was silently ignored.)

        ``memory`` is accepted for interface compatibility but is currently
        unused; memory reporting is disabled.
        """
        label = "Block"
        if msg:
            label = str(msg)
        elif self.msg:
            label = self.msg
        return f"{label} took {self.elapsed()}s"

    def elapsed(self) -> float:
        """
        Return the time since ``self.start`` as a float: the difference of
        two ``_C.seq_time()`` readings divided by 1e9.
        """
        return float(_C.seq_time() - self.start) / 1e9

    def tick(self, msg, memory=False) -> void:
        """
        Print a report labelled ``msg`` for the interval so far to stderr,
        then restart the clock.

        (Previously the report was computed and then discarded, so nothing
        was ever emitted.)
        """
        print(self.report(msg), file=stderr)
        self.start = _C.seq_time()


def timing(msg: str = "") -> TimeInterval:
    """
    Create a ``TimeInterval`` labelled ``msg`` for use in a ``with`` block.

    The interval is constructed field-wise with ``start`` set to 0;
    ``TimeInterval.__enter__`` sets the real start time when the ``with``
    block is entered.

    Example usage:

    .. code-block:: python

        from time import timing
        with timing('foo function'):
            foo()  # prints runtime of foo
    """
    interval = TimeInterval(0, msg)
    return interval


@tuple
class struct_time:
    # Compact time record. Pointers to instances are handed to the runtime
    # (seq_localtime / seq_gmtime / seq_mktime elsewhere in this file), so
    # the field order and widths below must not be changed.
    _year: i16  # year minus 1900
    _yday: i16  # day of year, zero-based
    _sec: i8
    _min: i8
    _hour: i8
    _mday: i8
    _mon: i8  # month, zero-based
    _wday: i8  # day of week, Sunday = 0
    _isdst: i8

    def _wday_adjust_monday_start(wday: int) -> int:
        """Convert a weekday index from (sunday=0) to (monday=0)."""
        return wday - 1 if wday >= 1 else 6

    def _wday_adjust_sunday_start(wday: int) -> int:
        """Convert a weekday index from (monday=0) to (sunday=0)."""
        return 0 if wday >= 6 else wday + 1

    def __new__(
        year: int,
        mon: int,
        mday: int,
        hour: int,
        min: int,
        sec: int,
        wday: int,
        yday: int,
        isdst: int,
    ) -> struct_time:
        """
        Build a record from public ``tm_*``-style values, converting each
        to its stored form: year offset from 1900, zero-based month and
        day of year, Sunday-based weekday.
        """
        stored_year = i16(year - 1900)
        stored_yday = i16(yday - 1)
        stored_mon = i8(mon - 1)
        stored_wday = i8(struct_time._wday_adjust_sunday_start(wday))
        return struct_time(
            stored_year,
            stored_yday,
            i8(sec),
            i8(min),
            i8(hour),
            i8(mday),
            stored_mon,
            stored_wday,
            i8(isdst),
        )

    @property
    def tm_year(self) -> int:
        """Full year (stored value plus 1900)."""
        return 1900 + int(self._year)

    @property
    def tm_yday(self) -> int:
        """Day of the year, one-based."""
        return 1 + int(self._yday)

    @property
    def tm_sec(self) -> int:
        """Seconds field, widened to ``int``."""
        seconds = int(self._sec)
        return seconds

    @property
    def tm_min(self) -> int:
        """Minutes field, widened to ``int``."""
        minutes = int(self._min)
        return minutes

    @property
    def tm_hour(self) -> int:
        """Hours field, widened to ``int``."""
        hours = int(self._hour)
        return hours

    @property
    def tm_mday(self) -> int:
        """Day-of-month field, widened to ``int``."""
        day = int(self._mday)
        return day

    @property
    def tm_mon(self) -> int:
        """Month, one-based."""
        return 1 + int(self._mon)

    @property
    def tm_wday(self) -> int:
        """Day of the week with Monday = 0."""
        sunday_based = int(self._wday)
        return struct_time._wday_adjust_monday_start(sunday_based)

    @property
    def tm_isdst(self) -> int:
        """Daylight-saving flag, widened to ``int``."""
        flag = int(self._isdst)
        return flag


def localtime(secs: int = -1) -> struct_time:
    """
    Fill a fresh ``struct_time`` via ``_C.seq_localtime`` and return it.

    ``secs`` is forwarded to the runtime unchanged; the default is -1
    (presumably interpreted by the runtime as "now" -- confirm).

    Raises:
        OSError: if the runtime call reports failure.
    """
    result = struct_time()
    ok = _C.seq_localtime(secs, __ptr__(result).as_byte())
    if ok:
        return result
    raise OSError("localtime failed")


def gmtime(secs: int = -1) -> struct_time:
    """
    Fill a fresh ``struct_time`` via ``_C.seq_gmtime`` and return it.

    ``secs`` is forwarded to the runtime unchanged; the default is -1
    (presumably interpreted by the runtime as "now" -- confirm).

    Raises:
        OSError: if the runtime call reports failure.
    """
    tm = struct_time()
    worked = _C.seq_gmtime(secs, __ptr__(tm).as_byte())
    if not worked:
        # The message used to say "localtime failed" (copy-paste slip),
        # which pointed at the wrong function when debugging.
        raise OSError("gmtime failed")
    return tm


def mktime(t) -> int:
    """
    Convert a time record to an integer via ``_C.seq_mktime``.

    ``t`` may be a ``struct_time``, which is passed to the runtime by
    pointer as-is, or any value that can be unpacked into the
    ``struct_time`` constructor (e.g. a 9-element tuple -- assumed from
    the ``struct_time(*t)`` call).
    """
    if isinstance(t, struct_time):
        return _C.seq_mktime(__ptr__(t).as_byte())
    else:
        # Build a struct_time first so there is a named value to take a
        # pointer to.
        tm = struct_time(*t)
        return _C.seq_mktime(__ptr__(tm).as_byte())


# pytime.h funcs

# Rounding modes understood by _round() below.
_ROUND_HALF_EVEN = 0  # nearest value, ties go to the even neighbour
_ROUND_CEILING = 1  # ceil
_ROUND_FLOOR = 2  # floor
_ROUND_UP = 3  # ceil for values >= 0, floor otherwise (away from zero)

# Bounds used by the saturating helpers (_add, _mul) and the range check in
# _double_to_denominator.
# NOTE(review): _MIN is written as 0x8000000000000000; this presumably relies
# on 64-bit signed int semantics to denote the most negative value -- confirm.
_MIN = 0x8000000000000000
_MAX = 0x7FFFFFFFFFFFFFFF


def _overflow() -> void:
    """Raise the ``OverflowError`` used for out-of-range timestamps."""
    error = OverflowError("timestamp too large")
    raise error


def _add(t1: int, t2: int) -> int:
    """
    Saturating addition: return ``t1 + t2``, clamped to ``_MAX`` / ``_MIN``
    when the sum would fall outside that range.
    """
    if t2 > 0:
        # Adding a positive amount can only exceed the upper bound.
        if t1 > _MAX - t2:
            return _MAX
    elif t2 < 0:
        # Adding a negative amount can only fall below the lower bound.
        if t1 < _MIN - t2:
            return _MIN
    return t1 + t2


def _mul_check_overflow(a: int, b: int) -> bool:
    """
    Return True when ``a * b`` would fall outside ``[_MIN, _MAX]``.

    ``b`` is expected to be non-negative; a zero ``b`` never overflows.
    """
    if b == 0:
        return False
    # assert b > 0
    return (a < _MIN // b) or (a > _MAX // b)


def _mul(t: int, k: int) -> int:
    """
    Saturating multiplication: return ``t * k``, or ``_MAX`` / ``_MIN``
    (chosen by the sign of ``t``) when the product would overflow.

    ``k`` is expected to be non-negative.
    """
    # assert k >= 0
    if not _mul_check_overflow(t, k):
        return t * k
    if t >= 0:
        return _MAX
    return _MIN


def _muldiv(ticks: int, mul: int, div: int) -> int:
    """
    Compute ``ticks * mul // div`` while limiting intermediate overflow.

    ``ticks`` is split into a quotient and a remainder with respect to
    ``div``; each part is scaled separately with the saturating ``_mul``
    and the two results are combined with the saturating ``_add``.

    Args:
        ticks: value to scale.
        mul: multiplier (expected non-negative, as required by ``_mul``).
        div: divisor (must be non-zero).
    """
    # Integer (floor) division is required here: `/` is true division and
    # yields a float, which breaks the integer-only _mul/_add helpers.
    intpart = ticks // div
    ticks %= div
    remaining = _mul(ticks, mul) // div
    return _add(_mul(intpart, mul), remaining)


def _round_half_even(x: float) -> float:
    from math import fabs

    rounded = x.__round__()
    if fabs(x - rounded) == 0.5:
        rounded = 2.0 * (x / 2.0).__round__()
    return rounded


def _round(x: float, mode: int) -> float:
    """
    Round ``x`` according to ``mode`` (one of the ``_ROUND_*`` constants).

    An unrecognised mode returns ``x`` unchanged.
    """
    if mode == _ROUND_HALF_EVEN:
        return _round_half_even(x)
    if mode == _ROUND_CEILING:
        return x.__ceil__()
    if mode == _ROUND_FLOOR:
        return x.__floor__()
    if mode == _ROUND_UP:
        # Away from zero: up for non-negative values, down for negative.
        if x >= 0:
            return x.__ceil__()
        return x.__floor__()
    return x


def _double_to_denominator(d: float, idenominator: int, mode: int) -> Tuple[int, int]:
    """
    Split ``d`` into whole units and a fraction counted in
    ``1 / idenominator`` steps, rounding the fraction with ``mode``.

    Returns:
        ``(sec, numerator)``: the integral part and the rounded, scaled
        fractional part.

    Raises:
        OverflowError: (via ``_overflow``) if the integral part lies
            outside ``[_MIN, _MAX]``.
    """
    from math import modf

    scale = float(idenominator)
    frac, whole = modf(d)

    frac = _round(frac * scale, mode)
    # Rounding can push the fraction out of [0, scale); carry or borrow one
    # whole unit to bring it back.
    if frac >= scale:
        whole += 1.0
        frac -= scale
    elif frac < 0.0:
        whole -= 1.0
        frac += scale
    # assert 0.0 <= frac < scale

    if whole < _MIN or whole > _MAX:
        _overflow()

    # assert 0 <= int(frac) < idenominator
    return int(whole), int(frac)


def _time_to_timespec(t: float, mode: int) -> Tuple[int, int]:
    """
    Split ``t`` into an integral part and a fraction counted in
    billionths (denominator 1000000000), rounded with ``mode``.
    """
    per_unit = 1000000000
    return _double_to_denominator(t, per_unit, mode)


def _time_to_timeval(t: float, mode: int) -> Tuple[int, int]:
    """
    Split ``t`` into an integral part and a fraction counted in
    millionths (denominator 1000000), rounded with ``mode``.
    """
    per_unit = 1000000
    return _double_to_denominator(t, per_unit, mode)