from internal.file import _gz_errcheck
from internal.gc import sizeof, atomic

# Serialize `x` into `jar` by delegating to the value's own `__pickle__`
# method. `T` is the type of `x`; nothing is returned.
def pickle[T](x: T, jar: Jar):
    x.__pickle__(jar)

# Deserialize one value of type `T` from `jar` by calling `T.__unpickle__`
# and return its result. Callers in this file supply `T` as a trailing
# argument, e.g. `unpickle(jar, int)`.
def unpickle[T](jar: Jar):
    return T.__unpickle__(jar)

# Pickle `x` into the file object `f`, passing `f.fp` as the jar handed to
# `x.__pickle__`.
# NOTE(review): presumably `f` is a gzip file wrapper whose `fp` is the
# underlying handle — confirm against callers.
def dump[T](x: T, f):
    x.__pickle__(f.fp)

# Unpickle and return one value of type `T` from the file object `f`, passing
# `f.fp` as the jar handed to `T.__unpickle__`.
# NOTE(review): presumably `f` is a gzip file wrapper whose `fp` is the
# underlying handle — confirm against callers.
def load[T](f) -> T:
    return T.__unpickle__(f.fp)

# Write `n` bytes starting at `p` to `jar` through `_C.gzwrite`.
#
# A single `gzwrite` call is handed at most 0x7fffffff bytes, so larger
# buffers go out as a sequence of chunks. If a call returns anything other
# than the requested chunk size, `_gz_errcheck(jar)` is called first
# (presumably to surface the underlying gz error — confirm), and then an
# IOError carrying the returned status is raised.
def _write_raw(jar: Jar, p: cobj, n: int):
    MAX_CHUNK = 0x7fffffff
    cursor = p
    remaining = n
    while remaining > 0:
        # Clamp this call to the per-call limit.
        chunk = MAX_CHUNK
        if remaining < MAX_CHUNK:
            chunk = remaining
        written = int(_C.gzwrite(jar, cursor, u32(chunk)))
        if written != chunk:
            _gz_errcheck(jar)
            raise IOError("pickle error: gzwrite returned " + str(written))
        cursor += chunk
        remaining -= chunk

# Read exactly `n` bytes from `jar` into the buffer at `p` through
# `_C.gzread`.
#
# A single `gzread` call is asked for at most 0x7fffffff bytes, so larger
# reads are performed as a sequence of chunks. If a call returns anything
# other than the requested chunk size (including a short read),
# `_gz_errcheck(jar)` is called first (presumably to surface the underlying
# gz error — confirm), and then an IOError carrying the returned status is
# raised.
def _read_raw(jar: Jar, p: cobj, n: int):
    MAX_CHUNK = 0x7fffffff
    cursor = p
    remaining = n
    while remaining > 0:
        # Clamp this call to the per-call limit.
        chunk = MAX_CHUNK
        if remaining < MAX_CHUNK:
            chunk = remaining
        got = int(_C.gzread(jar, cursor, u32(chunk)))
        if got != chunk:
            _gz_errcheck(jar)
            raise IOError("pickle error: gzread returned " + str(got))
        cursor += chunk
        remaining -= chunk

# Write the in-memory representation of `x` to `jar`: take the address of the
# local value with `__ptr__`, view it as bytes, and emit `sizeof(T)` bytes
# through `_write_raw`.
def _write[T](jar: Jar, x: T):
    addr = __ptr__(x)
    _write_raw(jar, addr.as_byte(), sizeof(T))

# Read one `T` from `jar`: default-construct a value, overwrite its storage
# with `sizeof(T)` bytes taken from the stream via `_read_raw`, and return it.
def _read[T](jar: Jar):
    value = T()
    addr = __ptr__(value)
    _read_raw(jar, addr.as_byte(), sizeof(T))
    return value

# Extend core types to allow pickling

@extend
class int:
    # Write this integer to `jar` through `_write`, i.e. as the raw bytes of
    # the value.
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    # Read one integer back from `jar` through `_read`. Takes no `self`; it is
    # invoked on the type (see `unpickle`).
    def __unpickle__(jar: Jar):
        return _read(jar, int)

@extend
class float:
    # Write this float to `jar` through `_write`, i.e. as the raw bytes of the
    # value.
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    # Read one float back from `jar` through `_read`. Takes no `self`; it is
    # invoked on the type (see `unpickle`).
    def __unpickle__(jar: Jar):
        return _read(jar, float)

@extend
class bool:
    # Write this bool to `jar` through `_write`, i.e. as the raw bytes of the
    # value.
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    # Read one bool back from `jar` through `_read`. Takes no `self`; it is
    # invoked on the type (see `unpickle`).
    def __unpickle__(jar: Jar):
        return _read(jar, bool)

@extend
class byte:
    # Write this byte to `jar` through `_write`, i.e. as the raw bytes of the
    # value.
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    # Read one byte back from `jar` through `_read`. Takes no `self`; it is
    # invoked on the type (see `unpickle`).
    def __unpickle__(jar: Jar):
        return _read(jar, byte)

@extend
class str:
    # Serialize this string into `jar`.
    # Layout: the length (`self.len`) as a raw int, followed by that many
    # bytes taken from `self.ptr`.
    def __pickle__(self, jar: Jar):
        _write(jar, self.len)
        _write_raw(jar, self.ptr, self.len)

    # Read a string back from `jar`: first the length, then that many bytes
    # into a freshly allocated buffer which becomes the string's storage.
    # Takes no `self`; it is invoked on the type (see `unpickle`).
    #
    # Raises IOError if the stored length is negative, which can only come
    # from a corrupt or foreign stream; without this check the bad length
    # would flow into the allocation and into the resulting string.
    def __unpickle__(jar: Jar):
        n = _read(jar, int)
        if n < 0:
            raise IOError("pickle error: invalid string length " + str(n))
        p = Ptr[byte](n)
        _read_raw(jar, p, n)
        return str(p, n)

@extend
class List:
    # Serialize this list into `jar`.
    # Layout: the element count, then the elements. When `T` is atomic the
    # backing array is written as one raw block of `n * sizeof(T)` bytes;
    # otherwise each element is pickled individually, in order.
    def __pickle__(self, jar: Jar):
        n = len(self)
        pickle(n, jar)
        if atomic(T):
            _write_raw(jar, (self.arr.ptr).as_byte(), n * sizeof(T))
        else:
            for i in range(n):
                pickle(self.arr[i], jar)

    # Read a list back from `jar`, mirroring the layout `__pickle__` writes:
    # the element count, then either one raw block (atomic `T`) or `n`
    # individually unpickled elements. Takes no `self`; it is invoked on the
    # type (see `unpickle`).
    #
    # Raises IOError if the stored count is negative, which can only come
    # from a corrupt or foreign stream; without this check the bad count
    # would flow into the array allocation and the resulting list's length.
    def __unpickle__(jar: Jar):
        n = unpickle(jar, int)
        if n < 0:
            raise IOError("pickle error: invalid list length " + str(n))
        arr = Array[T](n)
        if atomic(T):
            _read_raw(jar, (arr.ptr).as_byte(), n * sizeof(T))
        else:
            for i in range(n):
                arr[i] = unpickle(jar, T)
        return List[T](arr, n)

@extend
class Dict:
    # Serialize this dictionary into `jar`.
    #
    # Both layouts start with `_n_buckets`.
    #   * `K` and `V` both atomic: the table is dumped as-is — `_size`,
    #     `_n_occupied`, `_upper_bound`, then the flag words, the key array
    #     and the value array (each array covering all `_n_buckets` slots).
    #   * otherwise: the entry count, then every key/value pair pickled
    #     individually in iteration order.
    def __pickle__(self, jar: Jar):
        import internal.khash as khash
        pickle(self._n_buckets, jar)
        if atomic(K) and atomic(V):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            # Flag array length in u32 words; an empty table has no flags.
            fsize = khash.__ac_fsize(self._n_buckets) if self._n_buckets > 0 else 0
            _write_raw(jar, self._flags.as_byte(), fsize * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
            _write_raw(jar, self._vals.as_byte(), self._n_buckets * sizeof(V))
        else:
            pickle(len(self), jar)
            for k, v in self.items():
                pickle(k, jar)
                pickle(v, jar)

    # Read a dictionary back from `jar`, mirroring the layout `__pickle__`
    # writes. Takes no `self`; it is invoked on the type (see `unpickle`).
    #
    # The first two ints (bucket count and size) are common to both layouts,
    # so they are read and validated once up front.
    #
    # Raises IOError if either header value is negative, which can only come
    # from a corrupt or foreign stream; without this check the bad values
    # would flow into the allocations / `resize` and the table's bookkeeping.
    def __unpickle__(jar: Jar):
        import internal.khash as khash
        d = Dict[K,V]()
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if n_buckets < 0 or size < 0:
            detail = "n_buckets=" + str(n_buckets) + ", size=" + str(size)
            raise IOError("pickle error: invalid dict header (" + detail + ")")

        if atomic(K) and atomic(V):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            # Flag array length in u32 words; an empty table has no flags.
            fsize = khash.__ac_fsize(n_buckets) if n_buckets > 0 else 0
            flags = Ptr[u32](fsize)
            keys = Ptr[K](n_buckets)
            vals = Ptr[V](n_buckets)
            _read_raw(jar, flags.as_byte(), fsize * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))
            _read_raw(jar, vals.as_byte(), n_buckets * sizeof(V))

            # Install the restored table directly into the fresh dictionary.
            d._n_buckets = n_buckets
            d._size = size
            d._n_occupied = n_occupied
            d._upper_bound = upper_bound
            d._flags = flags
            d._keys = keys
            d._vals = vals
        else:
            # Size the table to the pickled bucket count, then re-insert.
            d.resize(n_buckets)
            for _ in range(size):
                k = unpickle(jar, K)
                v = unpickle(jar, V)
                d[k] = v
        return d

@extend
class Set:
    # Serialize this set into `jar`.
    #
    # Both layouts start with `_n_buckets`.
    #   * `K` atomic: the table is dumped as-is — `_size`, `_n_occupied`,
    #     `_upper_bound`, then the flag words and the key array (covering all
    #     `_n_buckets` slots).
    #   * otherwise: the element count, then every element pickled
    #     individually in iteration order.
    def __pickle__(self, jar: Jar):
        import internal.khash as khash
        pickle(self._n_buckets, jar)
        if atomic(K):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            # Flag array length in u32 words; an empty table has no flags.
            fsize = khash.__ac_fsize(self._n_buckets) if self._n_buckets > 0 else 0
            _write_raw(jar, self._flags.as_byte(), fsize * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
        else:
            pickle(len(self), jar)
            for k in self:
                pickle(k, jar)

    # Read a set back from `jar`, mirroring the layout `__pickle__` writes.
    # Takes no `self`; it is invoked on the type (see `unpickle`).
    #
    # The first two ints (bucket count and size) are common to both layouts,
    # so they are read and validated once up front.
    #
    # Raises IOError if either header value is negative, which can only come
    # from a corrupt or foreign stream; without this check the bad values
    # would flow into the allocations / `resize` and the table's bookkeeping.
    def __unpickle__(jar: Jar):
        import internal.khash as khash
        s = Set[K]()
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if n_buckets < 0 or size < 0:
            detail = "n_buckets=" + str(n_buckets) + ", size=" + str(size)
            raise IOError("pickle error: invalid set header (" + detail + ")")

        if atomic(K):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            # Flag array length in u32 words; an empty table has no flags.
            fsize = khash.__ac_fsize(n_buckets) if n_buckets > 0 else 0
            flags = Ptr[u32](fsize)
            keys = Ptr[K](n_buckets)
            _read_raw(jar, flags.as_byte(), fsize * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))

            # Install the restored table directly into the fresh set.
            s._n_buckets = n_buckets
            s._size = size
            s._n_occupied = n_occupied
            s._upper_bound = upper_bound
            s._flags = flags
            s._keys = keys
        else:
            # Size the table to the pickled bucket count, then re-insert.
            s.resize(n_buckets)
            for _ in range(size):
                k = unpickle(jar, K)
                s.add(k)
        return s