import os
from internal.dlopen import dlext

# CPython C-API entry points used by this module.
#
# Each name is declared with its argument/return signature and is first bound
# to a Function built from `cobj()` (presumably a null pointer -- TODO confirm).
# init() rebinds every one of them to the address that `_dlsym` resolves in the
# Python shared library, so none of them may be called before init() has run.
#
# NOTE(review): the CPython documentation declares PyRun_SimpleString,
# PyObject_SetAttrString, PyTuple_SetItem, PyList_SetItem, PySet_Add and
# PyDict_SetItem as returning `int`; the return types declared here (`void` or
# `cobj`) differ. Callers in this file never inspect those results and rely on
# pyobj.exc_check() instead -- confirm before using the return values.

# String encoding and error retrieval.
PyUnicode_AsEncodedString = Function[[cobj, cobj, cobj], cobj](cobj())
PyBytes_AsString = Function[[cobj], cobj](cobj())
PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], void](cobj())

# Generic object protocol, reference counting and interpreter control.
PyObject_GetAttrString = Function[[cobj, cobj], cobj](cobj())
PyObject_GetAttr = Function[[cobj, cobj], cobj](cobj())
PyObject_Str = Function[[cobj], cobj](cobj())
PyRun_SimpleString = Function[[cobj], void](cobj())
Py_IncRef = Function[[cobj], void](cobj())
Py_DecRef = Function[[cobj], void](cobj())
PyObject_Call = Function[[cobj, cobj, cobj], cobj](cobj())
PyObject_SetAttrString = Function[[cobj, cobj, cobj], cobj](cobj())
PyObject_Length = Function[[cobj], int](cobj())
Py_Initialize = Function[[], void](cobj())
PyImport_ImportModule = Function[[cobj], cobj](cobj())

# Scalar conversions.
PyLong_FromLong = Function[[int], cobj](cobj())
PyLong_AsLong = Function[[cobj], int](cobj())
PyFloat_FromDouble = Function[[float], cobj](cobj())
PyFloat_AsDouble = Function[[cobj], float](cobj())
PyBool_FromLong = Function[[int], cobj](cobj())
PyObject_IsTrue = Function[[cobj], int](cobj())
PyUnicode_DecodeFSDefaultAndSize = Function[[cobj, int], cobj](cobj())

# Containers: tuple, list, set, dict.
PyTuple_New = Function[[int], cobj](cobj())
PyTuple_SetItem = Function[[cobj, int, cobj], void](cobj())
PyTuple_GetItem = Function[[cobj, int], cobj](cobj())
PyList_New = Function[[int], cobj](cobj())
PyList_SetItem = Function[[cobj, int, cobj], cobj](cobj())
PyList_GetItem = Function[[cobj, int], cobj](cobj())
PySet_New = Function[[cobj], cobj](cobj())
PySet_Add = Function[[cobj, cobj], cobj](cobj())
PyDict_New = Function[[], cobj](cobj())
PyDict_SetItem = Function[[cobj, cobj, cobj], cobj](cobj())
PyDict_Next = Function[[cobj, Ptr[int], Ptr[cobj], Ptr[cobj]], int](cobj())

# Iteration protocol.
PyObject_GetIter = Function[[cobj], cobj](cobj())
PyIter_Next = Function[[cobj], cobj](cobj())

# Modules already imported through pyobj._import, keyed by module name, so a
# repeated import of the same name returns the cached object.
_PY_MODULE_CACHE = Dict[str, pyobj]()

# Flipped to True at the end of init(); guards against binding the C-API
# symbols and calling Py_Initialize() more than once.
_PY_INITIALIZED = False
def init():
    # Bind every C-API function pointer declared at module level to its symbol
    # in the Python shared library, then start the interpreter.
    #
    # The library name is taken from the CODON_PYTHON environment variable and
    # defaults to 'libpython.' + dlext(). Calling init() again after a
    # successful run is a no-op.
    global _PY_INITIALIZED
    if _PY_INITIALIZED:
        return

    LD = os.getenv('CODON_PYTHON', default='libpython.' + dlext())

    # Each module-level placeholder is declared `global` and then rebound to
    # the symbol of the same name looked up in LD.
    global PyUnicode_AsEncodedString
    PyUnicode_AsEncodedString = _dlsym(LD, "PyUnicode_AsEncodedString")
    global PyBytes_AsString
    PyBytes_AsString = _dlsym(LD, "PyBytes_AsString")
    global PyErr_Fetch
    PyErr_Fetch = _dlsym(LD, "PyErr_Fetch")
    global PyObject_GetAttrString
    PyObject_GetAttrString = _dlsym(LD, "PyObject_GetAttrString")
    global PyObject_GetAttr
    PyObject_GetAttr = _dlsym(LD, "PyObject_GetAttr")
    global PyObject_Str
    PyObject_Str = _dlsym(LD, "PyObject_Str")
    global PyRun_SimpleString
    PyRun_SimpleString = _dlsym(LD, "PyRun_SimpleString")
    global Py_IncRef
    Py_IncRef = _dlsym(LD, "Py_IncRef")
    global Py_DecRef
    Py_DecRef = _dlsym(LD, "Py_DecRef")
    global PyObject_Call
    PyObject_Call = _dlsym(LD, "PyObject_Call")
    global PyObject_SetAttrString
    PyObject_SetAttrString = _dlsym(LD, "PyObject_SetAttrString")
    global PyObject_Length
    PyObject_Length = _dlsym(LD, "PyObject_Length")
    global Py_Initialize
    Py_Initialize = _dlsym(LD, "Py_Initialize")
    global PyImport_ImportModule
    PyImport_ImportModule = _dlsym(LD, "PyImport_ImportModule")
    global PyLong_FromLong
    PyLong_FromLong = _dlsym(LD, "PyLong_FromLong")
    global PyLong_AsLong
    PyLong_AsLong = _dlsym(LD, "PyLong_AsLong")
    global PyFloat_FromDouble
    PyFloat_FromDouble = _dlsym(LD, "PyFloat_FromDouble")
    global PyFloat_AsDouble
    PyFloat_AsDouble = _dlsym(LD, "PyFloat_AsDouble")
    global PyBool_FromLong
    PyBool_FromLong = _dlsym(LD, "PyBool_FromLong")
    global PyObject_IsTrue
    PyObject_IsTrue = _dlsym(LD, "PyObject_IsTrue")
    global PyUnicode_DecodeFSDefaultAndSize
    PyUnicode_DecodeFSDefaultAndSize = _dlsym(LD, "PyUnicode_DecodeFSDefaultAndSize")
    global PyTuple_New
    PyTuple_New = _dlsym(LD, "PyTuple_New")
    global PyTuple_SetItem
    PyTuple_SetItem = _dlsym(LD, "PyTuple_SetItem")
    global PyTuple_GetItem
    PyTuple_GetItem = _dlsym(LD, "PyTuple_GetItem")
    global PyList_New
    PyList_New = _dlsym(LD, "PyList_New")
    global PyList_SetItem
    PyList_SetItem = _dlsym(LD, "PyList_SetItem")
    global PyList_GetItem
    PyList_GetItem = _dlsym(LD, "PyList_GetItem")
    global PySet_New
    PySet_New = _dlsym(LD, "PySet_New")
    global PySet_Add
    PySet_Add = _dlsym(LD, "PySet_Add")
    global PyDict_New
    PyDict_New = _dlsym(LD, "PyDict_New")
    global PyDict_SetItem
    PyDict_SetItem = _dlsym(LD, "PyDict_SetItem")
    global PyDict_Next
    PyDict_Next = _dlsym(LD, "PyDict_Next")
    global PyObject_GetIter
    PyObject_GetIter = _dlsym(LD, "PyObject_GetIter")
    global PyIter_Next
    PyIter_Next = _dlsym(LD, "PyIter_Next")

    # Only mark the bridge as ready once the interpreter has been started.
    Py_Initialize()
    _PY_INITIALIZED = True

def ensure_initialized():
    # Lazily start the Python bridge: run init() if it has not completed yet.
    if not _PY_INITIALIZED:
        init()
        # Alternative kept for reference: fail loudly instead of initializing
        # on demand.
        # raise ValueError("Python not initialized; make sure to 'import python'")

# Extension of `pyobj`, a thin handle around a raw CPython object pointer
# (`self.p`). It adds attribute access, calls, iteration, conversion helpers,
# manual reference counting, and translation of pending Python exceptions
# into PyError.
@extend
class pyobj:
    # Wrap a raw object pointer; the reference count is not touched.
    def __new__(p: Ptr[byte]) -> pyobj:
        return (p, )

    # Look up attribute `name` on the wrapped object; raises PyError if the
    # lookup left a Python exception pending.
    def _getattr(self, name: str):
        return pyobj.exc_wrap(pyobj(PyObject_GetAttrString(self.p, name.c_str())))

    # Forward indexing to the Python-side `__getitem__`.
    def __getitem__(self, t):
        return self._getattr("__getitem__")(t)

    # Forward addition to the Python-side `__add__`.
    def __add__(self, t):
        return self._getattr("__add__")(t)

    # NOTE(review): despite the name this sets an *attribute* (it calls
    # PyObject_SetAttrString), and the C call's result is wrapped in a pyobj.
    def __setitem__(self, name: str, val: pyobj):
        return pyobj.exc_wrap(pyobj(PyObject_SetAttrString(self.p, name.c_str(), val.p)))

    # Length as reported by PyObject_Length; raises PyError on failure.
    def __len__(self):
        return pyobj.exc_wrap(PyObject_Length(self.p))

    # A pyobj is already a Python object, so both conversions are identity.
    def __to_py__(self):
        return self

    def __from_py__(self):
        return self

    # Call the Python-side `__str__` and convert the result to a native str.
    def __str__(self):
        return str.__from_py__(self._getattr("__str__").__call__())

    # Iterate over the wrapped object with PyObject_GetIter / PyIter_Next.
    # Raises ValueError when no iterator can be obtained, and PyError if
    # iteration ends with a Python exception pending.
    # NOTE(review): each yielded item is released as soon as the consumer
    # resumes the generator, so callers that keep an item must incref() it.
    # NOTE(review): the ValueError path does not clear the exception that
    # PyObject_GetIter presumably leaves pending -- confirm.
    def __iter__(self):
        it = PyObject_GetIter(self.p)
        if not it:
            raise ValueError("Python object is not iterable")
        while i := PyIter_Next(it):
            yield pyobj(pyobj.exc_wrap(i))
            pyobj(i).decref()
        pyobj(it).decref()
        pyobj.exc_check()

    # Encode the wrapped (unicode) object as UTF-8 using error policy `errors`
    # and return it as a native str; returns `empty` if encoding fails.
    def to_str(self, errors: str, empty: str = "") -> str:
        obj = PyUnicode_AsEncodedString(self.p, "utf-8".c_str(), errors.c_str())
        if obj == cobj():
            return empty
        bts = PyBytes_AsString(obj)
        # PyBytes_AsString points into the bytes object's own buffer, so the
        # data must be read while `obj` is still alive; release it afterwards.
        res = str.from_ptr(bts)
        pyobj(obj).decref()
        return res

    # If a Python exception is pending, fetch (and thereby clear) it and
    # re-raise it as PyError(message, type name); otherwise do nothing.
    def exc_check():
        ptype, pvalue, ptraceback = cobj(), cobj(), cobj()
        PyErr_Fetch(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
        if ptype != cobj():
            py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue
            # The exception value may be NULL (and PyObject_Str may fail), so
            # only convert a real object; otherwise use the placeholder text.
            msg = "<empty Python message>"
            if py_msg != cobj():
                msg = pyobj(py_msg).to_str("ignore", "<empty Python message>")
            py_name = PyObject_GetAttrString(ptype, "__name__".c_str())
            typ = pyobj.to_str(pyobj(py_name), "ignore")

            pyobj(ptype).decref()
            pyobj(pvalue).decref()
            pyobj(ptraceback).decref()
            pyobj(py_msg).decref()
            # The attribute lookup returned a new reference; release it too.
            pyobj(py_name).decref()

            raise PyError(msg, typ)

    # Return `_retval` unchanged after checking for a pending Python exception.
    def exc_wrap(_retval):
        pyobj.exc_check()
        return _retval

    # Manual reference-count management for the wrapped object.
    def incref(self):
        Py_IncRef(self.p)

    def decref(self):
        Py_DecRef(self.p)

    # Call the wrapped object: positional arguments are converted through
    # `args.__to_py__()`, keyword arguments through a Dict[str, pyobj] built
    # from the keyword names and each value's `__to_py__()`.
    def __call__(self, *args, **kwargs) -> pyobj:
        names = iter(kwargs.__dict__())
        kws = Dict[str, pyobj]()
        if staticlen(kwargs) > 0:
            kws = {next(names): i.__to_py__() for i in kwargs}
        return pyobj.exc_wrap(pyobj(PyObject_Call(self.p, args.__to_py__().p, kws.__to_py__().p)))

    # Create a Python tuple with `length` slots.
    def _tuple_new(length: int) -> pyobj:
        t = PyTuple_New(length)
        pyobj.exc_check()
        return pyobj(t)

    # Store `val` in slot `idx` of the wrapped tuple.
    def _tuple_set(self, idx: int, val: pyobj):
        PyTuple_SetItem(self.p, idx, val.p)
        pyobj.exc_check()

    # Fetch slot `idx` of the wrapped tuple.
    def _tuple_get(self, idx: int) -> pyobj:
        t = PyTuple_GetItem(self.p, idx)
        pyobj.exc_check()
        return pyobj(t)

    # Import module `name`, initializing the bridge on first use and caching
    # the module object in _PY_MODULE_CACHE.
    def _import(name: str):
        ensure_initialized()
        if name in _PY_MODULE_CACHE:
            return _PY_MODULE_CACHE[name]
        m = pyobj.exc_wrap(pyobj(PyImport_ImportModule(name.c_str())))
        _PY_MODULE_CACHE[name] = m
        return m

    # Run `code` through PyRun_SimpleString; its result is not inspected.
    def _exec(code: str):
        ensure_initialized()
        PyRun_SimpleString(code.c_str())

    # Convert the wrapped object to native type T via `T.__from_py__`.
    def get[T](self) -> T:
        return T.__from_py__(self)

def none():
    # Placeholder: currently always raises NotImplementedError.
    raise NotImplementedError()


# Type conversions

def py(x) -> pyobj:
    # Convert a native value to a Python object by delegating to its
    # `__to_py__` method.
    return x.__to_py__()

def get[T](x: pyobj) -> T:
    # Convert Python object `x` to native type T by delegating to
    # `T.__from_py__`.
    return T.__from_py__(x)

# Conversions between native int and Python int objects.
@extend
class int:
    # Build a Python int from this value; raises PyError on failure.
    def __to_py__(self) -> pyobj:
        raw = PyLong_FromLong(self)
        pyobj.exc_check()
        return pyobj(raw)

    # Read a Python int back as a native int; raises PyError on failure.
    def __from_py__(i: pyobj):
        value = PyLong_AsLong(i.p)
        pyobj.exc_check()
        return value

# Conversions between native float and Python float objects.
@extend
class float:
    # Build a Python float from this value; raises PyError on failure.
    def __to_py__(self) -> pyobj:
        raw = PyFloat_FromDouble(self)
        pyobj.exc_check()
        return pyobj(raw)

    # Read a Python float back as a native float; raises PyError on failure.
    def __from_py__(d: pyobj):
        value = PyFloat_AsDouble(d.p)
        pyobj.exc_check()
        return value

# Conversions between native bool and Python bool objects.
@extend
class bool:
    # Build a Python bool from this value; raises PyError on failure.
    def __to_py__(self) -> pyobj:
        raw = PyBool_FromLong(int(self))
        pyobj.exc_check()
        return pyobj(raw)

    # Evaluate the Python object's truthiness; raises PyError on failure.
    def __from_py__(b: pyobj):
        truth = PyObject_IsTrue(b.p)
        pyobj.exc_check()
        return truth != 0

# Conversions between a native byte and a one-character Python string.
@extend
class byte:
    # Represent the byte as a length-1 str and convert that to Python.
    def __to_py__(self) -> pyobj:
        return str.__to_py__(str(__ptr__(self), 1))

    # Convert the Python string to a native str and return its first byte.
    # The data pointer of a str is its `ptr` field (as used by
    # str.__to_py__); `.p` is the field of pyobj, not of str.
    # NOTE(review): an empty Python string is not guarded against here.
    def __from_py__(c: pyobj):
        return str.__from_py__(c).ptr[0]

# Conversions between native str and Python unicode objects.
@extend
class str:
    # Decode this string's bytes into a Python unicode object using
    # PyUnicode_DecodeFSDefaultAndSize; raises PyError on failure.
    def __to_py__(self) -> pyobj:
        decoded = PyUnicode_DecodeFSDefaultAndSize(self.ptr, self.len)
        pyobj.exc_check()
        return pyobj(decoded)

    # Encode the Python object as UTF-8 with the "strict" error policy and
    # raise PyError if that left a Python exception pending.
    def __from_py__(s: pyobj):
        return pyobj.exc_wrap(s.to_str("strict"))

# Conversions between native List and Python list objects.
@extend
class List:
    # Create a Python list of the same length and fill it with each element
    # converted through py(); raises PyError if any C-API call fails.
    def __to_py__(self) -> pyobj:
        result = PyList_New(len(self))
        pyobj.exc_check()
        for pos, item in enumerate(self):
            PyList_SetItem(result, pos, py(item).p)
            pyobj.exc_check()
        return pyobj(result)

    # Build a native List[T] by converting every element of Python list `v`.
    def __from_py__(v: pyobj):
        size = v.__len__()
        out = List[T](size)
        for pos in range(size):
            raw = PyList_GetItem(v.p, pos)
            wrapped = pyobj(raw)
            pyobj.exc_check()
            out.append(get(wrapped, T))
        return out

# Conversions between native Dict and Python dict objects.
@extend
class Dict:
    # Create a Python dict and insert every key/value pair converted through
    # py(); raises PyError if any C-API call fails.
    def __to_py__(self) -> pyobj:
        result = PyDict_New()
        pyobj.exc_check()
        for key, value in self.items():
            key_obj = py(key)
            value_obj = py(value)
            PyDict_SetItem(result, key_obj.p, value_obj.p)
            pyobj.exc_check()
        return pyobj(result)

    # Walk Python dict `d` with PyDict_Next and build a native Dict[K,V].
    def __from_py__(d: pyobj):
        out = Dict[K,V]()
        cursor = 0
        key_ref = cobj()
        value_ref = cobj()
        # PyDict_Next advances `cursor` and fills both references on each call.
        while PyDict_Next(d.p, __ptr__(cursor), __ptr__(key_ref), __ptr__(value_ref)) != 0:
            pyobj.exc_check()
            # Convert the key first, then the value.
            native_key = get(pyobj(key_ref), K)
            native_value = get(pyobj(value_ref), V)
            out[native_key] = native_value
        return out

@extend
class Set:
    # Create an empty Python set and add every element converted through
    # py(); raises PyError if any C-API call fails.
    def __to_py__(self) -> pyobj:
        result = PySet_New(cobj())
        pyobj.exc_check()
        for member in self:
            member_obj = py(member)
            PySet_Add(result, member_obj.p)
            pyobj.exc_check()
        return pyobj(result)