# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

from internal.gc import realloc, free

class File:
    # Thin wrapper around a C stdio stream.
    #
    # NOTE: the order of the first two fields is load-bearing. `_getline`
    # passes getline raw pointers to `sz` (object offset 0) and `buf` (object
    # offset 8), so they must stay first and in this order.
    sz: int  # capacity of `buf`, maintained by getline
    buf: Ptr[byte]  # line buffer filled by getline; released in close()
    fp: cobj  # underlying C stream handle; null once closed

    def __init__(self, fp: cobj):
        """
        Wrap an already-open C stream handle.
        """
        self.fp = fp
        self._reset()

    def __init__(self, path: str, mode: str):
        """
        Open `path` using the fopen-style `mode` string.

        Raises IOError if the file cannot be opened.
        """
        self.fp = _C.fopen(path.c_str(), mode.c_str())
        if not self.fp:
            raise IOError(f"file {path} could not be opened")
        self._reset()

    def _errcheck(self, msg: str):
        """
        Raise IOError mentioning `msg` if the stream's error indicator is set.
        """
        err = int(_C.ferror(self.fp))
        if err:
            raise IOError(f"file I/O error: {msg}")

    def __enter__(self):
        pass

    def __exit__(self):
        """
        Close the file when leaving a `with` block.
        """
        self.close()

    def __iter__(self) -> Generator[str]:
        """
        Yield the lines of the file, each as an independent copy.
        """
        # `_iter` yields views into the shared line buffer, which the next
        # getline call overwrites; copy so callers can keep the strings.
        for a in self._iter():
            yield a.__ptrcopy__()

    def readlines(self) -> List[str]:
        """
        Return all remaining lines as a list.
        """
        return [l for l in self]

    def write(self, s: str):
        """
        Write the bytes of `s` to the file.

        Raises IOError if the file is closed or the stream reports an error.
        """
        self._ensure_open()
        _C.fwrite(s.ptr, 1, len(s), self.fp)
        self._errcheck("error in write")

    def __file_write_gen__(self, g: Generator[T], T: type):
        """
        Write `str(s)` for every item `s` produced by `g`.
        """
        for s in g:
            self.write(str(s))

    def read(self, sz: int = -1) -> str:
        """
        Read up to `sz` bytes from the current position and return them.

        A negative `sz` (the default) reads everything up to end of file.
        Raises IOError if the file is closed or the stream reports an error.
        """
        self._ensure_open()
        if sz < 0:
            return self._read_all()
        buf = Ptr[byte](sz)
        ret = _C.fread(buf, 1, sz, self.fp)
        self._errcheck("error in read")
        return str(buf, ret)

    def _read_all(self) -> str:
        """
        Read from the current position until end of file.

        The remaining size reported by ftell/fseek is used only as an initial
        capacity hint; the data itself is read until fread comes up short, so
        streams that cannot report a position (ftell < 0) are still read in
        full instead of yielding an empty string.
        """
        SEEK_SET = 0
        SEEK_END = 2
        cap = 4096
        cur = _C.ftell(self.fp)
        if cur >= 0:
            _C.fseek(self.fp, 0, i32(SEEK_END))
            end = _C.ftell(self.fp)
            _C.fseek(self.fp, cur, i32(SEEK_SET))
            # One spare byte lets the first fread come up short at EOF, so a
            # regular file is read with a single allocation and a single call.
            # A non-positive or small remainder just keeps the default size.
            if end - cur >= cap:
                cap = end - cur + 1
        data = Ptr[byte](cap)
        n = 0
        while True:
            want = cap - n
            got = _C.fread(data + n, 1, want, self.fp)
            n += got
            if got < want:
                # Short read: end of file or error (checked below).
                break
            data = realloc(data, 2 * cap, cap)
            cap *= 2
        self._errcheck("error in read")
        return str(data, n)

    def tell(self) -> int:
        """
        Return the current stream position.

        Raises IOError if the file is closed or the position is unavailable.
        """
        self._ensure_open()
        ret = _C.ftell(self.fp)
        self._errcheck("error in tell")
        if ret < 0:
            # A negative result means ftell failed; do not hand it back to
            # the caller as if it were a valid offset.
            raise IOError("file I/O error: error in tell")
        return ret

    def seek(self, offset: int, whence: int):
        """
        Move the stream position to `offset`, interpreted according to
        `whence` (passed straight to fseek).

        Raises IOError if the file is closed or the stream reports an error.
        """
        self._ensure_open()
        _C.fseek(self.fp, offset, i32(whence))
        self._errcheck("error in seek")

    def flush(self):
        """
        Flush buffered output to the underlying stream.

        Raises IOError if the file is closed or the stream reports an error.
        """
        self._ensure_open()
        _C.fflush(self.fp)
        self._errcheck("error in flush")

    def close(self):
        """
        Close the stream and release the line buffer. Safe to call repeatedly.
        """
        if self.fp:
            _C.fclose(self.fp)
            self.fp = cobj()
        if self.buf:
            # The line buffer is filled in by getline, hence C free.
            _C.free(self.buf)
            self._reset()

    def _ensure_open(self):
        """
        Raise IOError if the file has been closed.
        """
        if not self.fp:
            raise IOError("I/O operation on closed file")

    def _reset(self):
        """
        Clear the line buffer pointer and its recorded capacity.
        """
        self.buf = Ptr[byte]()
        self.sz = 0

    def _getline(self) -> int:
        """
        Read the next line into `self.buf` and return its length, or -1 when
        no further line could be read.
        """
        # getline updates `buf` and `sz` in place through raw field pointers:
        # `sz` lives at object offset 0 and `buf` at offset 8.
        return _C.getline(
            Ptr[Ptr[byte]](self.__raw__() + 8), Ptr[int](self.__raw__()), self.fp
        )

    def _iter(self) -> Generator[str]:
        """
        Yield each line (newline included) as a view into the shared line
        buffer; the view is only valid until the next line is read.
        """
        self._ensure_open()
        while True:
            rd = self._getline()
            if rd != -1:
                yield str(self.buf, rd)
            else:
                break

    def _iter_trim_newline(self) -> Generator[str]:
        """
        Like `_iter`, but with a single trailing newline removed from each line.
        """
        self._ensure_open()
        while True:
            rd = self._getline()
            if rd != -1:
                if self.buf[rd - 1] == byte(10):  # '\n'
                    rd -= 1
                yield str(self.buf, rd)
            else:
                break

def _gz_errcheck(stream: cobj):
    """
    Raise IOError carrying the message reported by gzerror for `stream`.
    Does nothing when gzerror returns a null or empty message.
    """
    code = i32(0)
    text = _C.gzerror(stream, __ptr__(code))
    if not text:
        return
    if not text[0]:
        return
    detail = str(text, _C.strlen(text))
    raise IOError(f"zlib error: {detail}")

class gzFile:
    # Wrapper around a zlib gz stream offering the same interface as `File`.
    sz: int  # capacity of `buf`
    buf: Ptr[byte]  # line buffer used by `_getline`; released in close()
    fp: cobj  # underlying gz stream handle; null once closed

    def __init__(self, fp: cobj):
        """
        Wrap an already-open gz stream handle.
        """
        self.fp = fp
        self._reset()

    def __init__(self, path: str, mode: str):
        """
        Open `path` using the gzopen-style `mode` string.

        Raises IOError if the file cannot be opened.
        """
        self.fp = _C.gzopen(path.c_str(), mode.c_str())
        if not self.fp:
            raise IOError(f"file {path} could not be opened")
        self._reset()

    def _getline(self) -> int:
        """
        Read the next line into `self.buf`, growing the buffer as needed.

        Returns the line length (newline included when present), or -1 when
        nothing could be read. Raises IOError if zlib reports an error.
        """
        if not self.buf:
            self.sz = 128
            self.buf = Ptr[byte](self.sz)

        offset = 0
        while True:
            if not _C.gzgets(self.fp, self.buf + offset, i32(self.sz - offset)):
                _gz_errcheck(self.fp)
                if offset == 0:
                    return -1
                # Data was read earlier in this call: return the final line,
                # which has no trailing newline.
                break

            # The length of what gzgets stored is found with strlen, so a NUL
            # byte inside a line cuts the measured length short.
            offset += _C.strlen(self.buf + offset)

            if self.buf[offset - 1] == byte(10):  # '\n'
                break

            # No newline yet: the line did not fit; double the buffer and
            # keep reading after what we already have.
            oldsz = self.sz
            self.sz *= 2
            self.buf = realloc(self.buf, self.sz, oldsz)

        return offset

    def __iter__(self) -> Generator[str]:
        """
        Yield the lines of the file, each as an independent copy.
        """
        # `_iter` yields views into the shared line buffer, which the next
        # `_getline` call overwrites; copy so callers can keep the strings.
        for a in self._iter():
            yield a.__ptrcopy__()

    def __enter__(self):
        pass

    def __exit__(self):
        """
        Close the file when leaving a `with` block.
        """
        self.close()

    def close(self):
        """
        Close the stream and release the line buffer. Safe to call repeatedly.
        """
        if self.fp:
            _C.gzclose(self.fp)
            self.fp = cobj()
        if self.buf:
            # Allocated via Ptr[byte](...)/realloc in `_getline`, so it is
            # released with the matching `free` from internal.gc.
            free(self.buf)
            self._reset()

    def readlines(self) -> List[str]:
        """
        Return all remaining lines as a list.
        """
        return [l for l in self]

    def write(self, s: str):
        """
        Write the bytes of `s` to the gz stream.

        Raises IOError if the file is closed or zlib reports an error.
        """
        self._ensure_open()
        _C.gzwrite(self.fp, s.ptr, u32(len(s)))
        _gz_errcheck(self.fp)

    def __file_write_gen__(self, g: Generator[T], T: type):
        """
        Write `str(s)` for every item `s` produced by `g`.
        """
        for s in g:
            self.write(str(s))

    def read(self, sz: int = -1) -> str:
        """
        Read up to `sz` uncompressed bytes and return them.

        A negative `sz` (the default) reads everything up to end of stream.
        Raises IOError if the file is closed or zlib reports an error.
        """
        self._ensure_open()
        if sz < 0:
            return self._read_all()
        buf = Ptr[byte](sz)
        ret = _C.gzread(self.fp, buf, u32(sz))
        _gz_errcheck(self.fp)
        return str(buf, int(ret))

    def _read_all(self) -> str:
        """
        Read uncompressed data until gzread returns nothing more.

        Uses gzread rather than the line reader so the result does not depend
        on strlen and is therefore safe for data containing NUL bytes.
        """
        MAX_REQUEST = 1 << 30  # keep every request well inside u32 range
        cap = 4096
        data = Ptr[byte](cap)
        n = 0
        while True:
            if n == cap:
                data = realloc(data, 2 * cap, cap)
                cap *= 2
            want = min(cap - n, MAX_REQUEST)
            got = int(_C.gzread(self.fp, data + n, u32(want)))
            _gz_errcheck(self.fp)
            if got <= 0:
                break
            n += got
        return str(data, n)

    def tell(self) -> int:
        """
        Return the current position as reported by gztell.

        Raises IOError if the file is closed or zlib reports an error.
        """
        self._ensure_open()
        ret = _C.gztell(self.fp)
        _gz_errcheck(self.fp)
        return ret

    def seek(self, offset: int, whence: int):
        """
        Move the stream position to `offset`, interpreted according to
        `whence` (passed straight to gzseek).

        Raises IOError if the file is closed or zlib reports an error.
        """
        self._ensure_open()
        _C.gzseek(self.fp, offset, i32(whence))
        _gz_errcheck(self.fp)

    def flush(self):
        """
        Flush pending output to the gz stream.

        Raises IOError if the file is closed or zlib reports an error.
        """
        # NOTE(review): flushing with Z_FINISH presumably completes the
        # current gzip member on every flush; confirm against the zlib manual
        # whether Z_SYNC_FLUSH was intended.
        Z_FINISH = 4
        self._ensure_open()
        _C.gzflush(self.fp, i32(Z_FINISH))
        _gz_errcheck(self.fp)

    def _iter(self) -> Generator[str]:
        """
        Yield each line (newline included) as a view into the shared line
        buffer; the view is only valid until the next line is read.
        """
        self._ensure_open()
        while True:
            rd = self._getline()
            if rd != -1:
                yield str(self.buf, rd)
            else:
                break

    def _iter_trim_newline(self) -> Generator[str]:
        """
        Like `_iter`, but with a single trailing newline removed from each line.
        """
        self._ensure_open()
        while True:
            rd = self._getline()
            if rd != -1:
                if self.buf[rd - 1] == byte(10):  # '\n'
                    rd -= 1
                yield str(self.buf, rd)
            else:
                break

    def _ensure_open(self):
        """
        Raise IOError if the file has been closed.
        """
        if not self.fp:
            raise IOError("I/O operation on closed file")

    def _reset(self):
        """
        Clear the line buffer pointer and its recorded capacity.
        """
        self.buf = Ptr[byte]()
        self.sz = 0

def open(path: str, mode: str = "r") -> File:
    """
    Open `path` as a `File` using the given mode string (default "r").
    """
    handle = File(path, mode)
    return handle

def gzopen(path: str, mode: str = "r") -> gzFile:
    """
    Open `path` as a `gzFile` using the given mode string (default "r").
    """
    handle = gzFile(path, mode)
    return handle

def is_binary(path: str) -> bool:
    """
    Heuristically decide whether the file at `path` is binary.

    Inspects up to the first 1024 bytes and returns True as soon as one of
    them falls outside the set of byte values considered textual.
    """
    # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391
    # The heuristic can produce both false positives and false negatives, but
    # it works for the large majority of files.
    control_codes = {7, 8, 9, 10, 12, 13, 27}
    printable = set(iter(range(0x20, 0x100))) - {0x7F}
    allowed = control_codes | printable
    with open(path, "rb") as f:
        head = f.read(1024)
        for ch in head:
            if ord(ch) not in allowed:
                return True
        return False