# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

import os

if __py_extension__:
    from internal.cpy_static import *
else:
    from internal.cpy_dlopen import *

# Cache of imported Python modules, keyed by the module name passed to
# pyobj._import, so each module is imported at most once.
_PY_MODULE_CACHE = Dict[str, pyobj]()

_PY_INIT = """
import io

clsf = None
clsa = None
plt = None
try:
    import matplotlib.figure
    import matplotlib.pyplot
    plt = matplotlib.pyplot
    clsf = matplotlib.figure.Figure
    clsa = matplotlib.artist.Artist
except ModuleNotFoundError:
    pass

def __codon_repr__(fig):
    if clsf and isinstance(fig, clsf):
        stream = io.StringIO()
        fig.savefig(stream, format="svg")
        return 'image/svg+xml', stream.getvalue()
    elif clsa and isinstance(fig, list) and all(
        isinstance(i, clsa) for i in fig
    ):
        stream = io.StringIO()
        plt.gcf().savefig(stream, format="svg")
        return 'image/svg+xml', stream.getvalue()
    elif hasattr(fig, "_repr_html_"):
        return 'text/html', fig._repr_html_()
    else:
        return 'text/plain', fig.__repr__()
"""

# Set to True by setup_python() after its first successful run; later calls
# to setup_python() return immediately when this is set.
_PY_INITIALIZED = False

def init_error_py_types():
    # Point each exception class's `_pytype` attribute at the matching
    # PyExc_* object. Called from setup_python() right after
    # init_dl_handles().
    # NOTE(review): presumably the PyExc_* values are only valid once the
    # library handles have been initialised -- confirm against cpy_dlopen.
    BaseException._pytype = PyExc_BaseException
    Exception._pytype = PyExc_Exception
    NameError._pytype = PyExc_NameError
    OSError._pytype = PyExc_OSError
    IOError._pytype = PyExc_IOError
    ValueError._pytype = PyExc_ValueError
    LookupError._pytype = PyExc_LookupError
    IndexError._pytype = PyExc_IndexError
    KeyError._pytype = PyExc_KeyError
    TypeError._pytype = PyExc_TypeError
    ArithmeticError._pytype = PyExc_ArithmeticError
    ZeroDivisionError._pytype = PyExc_ZeroDivisionError
    OverflowError._pytype = PyExc_OverflowError
    AttributeError._pytype = PyExc_AttributeError
    RuntimeError._pytype = PyExc_RuntimeError
    NotImplementedError._pytype = PyExc_NotImplementedError
    StopIteration._pytype = PyExc_StopIteration
    AssertionError._pytype = PyExc_AssertionError
    SystemExit._pytype = PyExc_SystemExit

def setup_python(python_loaded: bool):
    # One-time bootstrap of the Python bridge; does nothing after the first
    # completed call.
    #
    # python_loaded: when True, dlopen is given an empty library name and
    #   Py_Initialize() is skipped; when False, the library named by the
    #   CODON_PYTHON environment variable (default: "libpython." + dlext())
    #   is opened and the interpreter is initialised here.
    # NOTE(review): the empty name presumably resolves symbols from the
    # already-running process -- confirm against the dlopen wrapper.
    global _PY_INITIALIZED
    if not _PY_INITIALIZED:
        if python_loaded:
            library = ""
        else:
            library = os.getenv("CODON_PYTHON", default="libpython." + dlext())

        # Resolve the C-API entry points, then wire up the exception types.
        init_dl_handles(dlopen(library, RTLD_LOCAL | RTLD_NOW))
        init_error_py_types()

        if not python_loaded:
            Py_Initialize()

        _PY_INITIALIZED = True

def ensure_initialized(python_loaded: bool = False):
    # Outside of Python-extension builds: run the interpreter setup
    # (setup_python returns immediately after its first call) and then
    # execute the _PY_INIT script. The script is executed on every call,
    # not only the first. In extension builds this function does nothing.
    if not __py_extension__:
        setup_python(python_loaded)
        PyRun_SimpleString(_PY_INIT.c_str())

def setup_decorator():
    # Set up the bridge with python_loaded=True, i.e. without calling
    # Py_Initialize() from setup_python().
    setup_python(python_loaded=True)

# Wrapper around a raw CPython object pointer (`self.p`).
#
# Conventions used throughout this class:
#   - `pyobj(ptr)` takes its own reference (increments the refcount);
#     `pyobj(ptr, steal=True)` adopts a reference the caller already owns.
#   - `pyobj.exc_wrap(value)` runs `pyobj.exc_check()` and then returns
#     `value` unchanged, so a pending Python error becomes a raised PyError.
#   - `__to_py__` returns the wrapped pointer as-is (no refcount change).
@extend
class pyobj:
    @__internal__
    def __new__() -> pyobj:
        pass

    def __raw__(self) -> Ptr[byte]:
        return __internal__.class_raw(self)

    def __init__(self, p: Ptr[byte], steal: bool = False):
        # steal=False: keep our own reference to `p`.
        # steal=True: take over a reference the caller already holds.
        self.p = p
        if not steal:
            self.incref()

    def __del__(self):
        self.decref()

    def _getattr(self, name: str) -> pyobj:
        # Attribute lookup by name; raises PyError if the lookup fails.
        return pyobj(pyobj.exc_wrap(PyObject_GetAttrString(self.p, name.c_str())), steal=True)

    # ---- Binary arithmetic / bitwise operators ----------------------------
    # Each forwards to the matching PyNumber_* call after converting `other`
    # with `__to_py__`; the reflected (`__r*__`) forms swap the operands.

    def __add__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Add(self.p, other.__to_py__())), steal=True)

    def __radd__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Add(other.__to_py__(), self.p)), steal=True)

    def __sub__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Subtract(self.p, other.__to_py__())), steal=True)

    def __rsub__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Subtract(other.__to_py__(), self.p)), steal=True)

    def __mul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Multiply(self.p, other.__to_py__())), steal=True)

    def __rmul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Multiply(other.__to_py__(), self.p)), steal=True)

    def __matmul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_MatrixMultiply(self.p, other.__to_py__())), steal=True)

    def __rmatmul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_MatrixMultiply(other.__to_py__(), self.p)), steal=True)

    def __floordiv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_FloorDivide(self.p, other.__to_py__())), steal=True)

    def __rfloordiv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_FloorDivide(other.__to_py__(), self.p)), steal=True)

    def __truediv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_TrueDivide(self.p, other.__to_py__())), steal=True)

    def __rtruediv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_TrueDivide(other.__to_py__(), self.p)), steal=True)

    def __mod__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Remainder(self.p, other.__to_py__())), steal=True)

    def __rmod__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Remainder(other.__to_py__(), self.p)), steal=True)

    def __divmod__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Divmod(self.p, other.__to_py__())), steal=True)

    def __rdivmod__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Divmod(other.__to_py__(), self.p)), steal=True)

    def __pow__(self, other):
        # Py_None as the third argument means "no modulus".
        return pyobj(pyobj.exc_wrap(PyNumber_Power(self.p, other.__to_py__(), Py_None)), steal=True)

    def __rpow__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Power(other.__to_py__(), self.p, Py_None)), steal=True)

    # ---- Unary operators ---------------------------------------------------

    def __neg__(self):
        return pyobj(pyobj.exc_wrap(PyNumber_Negative(self.p)), steal=True)

    def __pos__(self):
        return pyobj(pyobj.exc_wrap(PyNumber_Positive(self.p)), steal=True)

    def __invert__(self):
        return pyobj(pyobj.exc_wrap(PyNumber_Invert(self.p)), steal=True)

    # ---- Shifts and bitwise operators --------------------------------------

    def __lshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Lshift(self.p, other.__to_py__())), steal=True)

    def __rlshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Lshift(other.__to_py__(), self.p)), steal=True)

    def __rshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Rshift(self.p, other.__to_py__())), steal=True)

    def __rrshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Rshift(other.__to_py__(), self.p)), steal=True)

    def __and__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_And(self.p, other.__to_py__())), steal=True)

    def __rand__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_And(other.__to_py__(), self.p)), steal=True)

    def __xor__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Xor(self.p, other.__to_py__())), steal=True)

    def __rxor__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Xor(other.__to_py__(), self.p)), steal=True)

    def __or__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Or(self.p, other.__to_py__())), steal=True)

    def __ror__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_Or(other.__to_py__(), self.p)), steal=True)

    # ---- In-place operators ------------------------------------------------
    # These return the result of the PyNumber_InPlace* call wrapped in a new
    # pyobj; they do not modify `self.p`.

    def __iadd__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceAdd(self.p, other.__to_py__())), steal=True)

    def __isub__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceSubtract(self.p, other.__to_py__())), steal=True)

    def __imul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceMultiply(self.p, other.__to_py__())), steal=True)

    def __imatmul__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceMatrixMultiply(self.p, other.__to_py__())), steal=True)

    def __ifloordiv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceFloorDivide(self.p, other.__to_py__())), steal=True)

    def __itruediv__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceTrueDivide(self.p, other.__to_py__())), steal=True)

    def __imod__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceRemainder(self.p, other.__to_py__())), steal=True)

    def __ipow__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlacePower(self.p, other.__to_py__(), Py_None)), steal=True)

    def __ilshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceLshift(self.p, other.__to_py__())), steal=True)

    def __irshift__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceRshift(self.p, other.__to_py__())), steal=True)

    def __iand__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceAnd(self.p, other.__to_py__())), steal=True)

    def __ixor__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceXor(self.p, other.__to_py__())), steal=True)

    def __ior__(self, other):
        return pyobj(pyobj.exc_wrap(PyNumber_InPlaceOr(self.p, other.__to_py__())), steal=True)

    # ---- Conversions to native values --------------------------------------
    # Pattern: obtain a temporary Python object, convert it, release it.

    def __int__(self):
        o = pyobj.exc_wrap(PyNumber_Long(self.p))
        x = int.__from_py__(o)
        pyobj.decref(o)
        return x

    def __float__(self):
        o = pyobj.exc_wrap(PyNumber_Float(self.p))
        x = float.__from_py__(o)
        pyobj.decref(o)
        return x

    def __index__(self):
        o = pyobj.exc_wrap(PyNumber_Index(self.p))
        x = int.__from_py__(o)
        pyobj.decref(o)
        return x

    # ---- Container protocol -------------------------------------------------

    def __len__(self) -> int:
        return pyobj.exc_wrap(PyObject_Length(self.p))

    def __length_hint__(self) -> int:
        return pyobj.exc_wrap(PyObject_LengthHint(self.p))

    def __getitem__(self, key):
        return pyobj(pyobj.exc_wrap(PyObject_GetItem(self.p, key.__to_py__())), steal=True)

    def __setitem__(self, key, v):
        pyobj.exc_wrap(PyObject_SetItem(self.p, key.__to_py__(), v.__to_py__()))

    def __delitem__(self, key):
        return pyobj.exc_wrap(PyObject_DelItem(self.p, key.__to_py__()))

    # ---- Rich comparisons ----------------------------------------------------
    # The result is whatever Python object the comparison produced, wrapped in
    # a pyobj (not a native bool).

    def __lt__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_LT))), steal=True)

    def __le__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_LE))), steal=True)

    def __eq__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_EQ))), steal=True)

    def __ne__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_NE))), steal=True)

    def __gt__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_GT))), steal=True)

    def __ge__(self, other):
        return pyobj(pyobj.exc_wrap(PyObject_RichCompare(self.p, other.__to_py__(), i32(Py_GE))), steal=True)

    # ---- Bridge protocol -----------------------------------------------------

    def __to_py__(self) -> cobj:
        # Returns the wrapped pointer without touching its refcount.
        return self.p

    def __from_py__(p: cobj) -> pyobj:
        # Wraps `p`, taking an additional reference to it.
        return pyobj(p)

    def __str__(self) -> str:
        # PyObject_Str hands back a new reference; release it once the text
        # has been copied out (same pattern as __int__/__float__/__index__).
        o = pyobj.exc_wrap(PyObject_Str(self.p))
        s = str.__from_py__(o)
        pyobj.decref(o)
        return s

    def __repr__(self) -> str:
        # See __str__: the temporary repr object is released after conversion.
        o = pyobj.exc_wrap(PyObject_Repr(self.p))
        s = str.__from_py__(o)
        pyobj.decref(o)
        return s

    def __hash__(self) -> int:
        return pyobj.exc_wrap(PyObject_Hash(self.p))

    def __iter__(self) -> Generator[pyobj]:
        # Yields each item of the Python iterable as a pyobj.
        # Raises ValueError when no iterator can be obtained.
        it = PyObject_GetIter(self.p)
        if not it:
            raise ValueError("Python object is not iterable")
        try:
            while i := PyIter_Next(it):
                yield pyobj(pyobj.exc_wrap(i), steal=True)
        finally:
            # Release the iterator even if the consumer stops early.
            pyobj.decref(it)
        # PyIter_Next returns null both at exhaustion and on error; surface
        # the latter.
        pyobj.exc_check()

    def to_str(self, errors: str, empty: str = "") -> str:
        return pyobj.to_str(self.p, errors, empty)

    def to_str(p: cobj, errors: str, empty: str = "") -> str:
        # Encodes the Python string `p` as UTF-8 using the given error
        # handler name and copies the bytes into a native str.
        # Returns `empty` when the encoding call yields null.
        obj = PyUnicode_AsEncodedString(p, "utf-8".c_str(), errors.c_str())
        if obj == cobj():
            return empty
        bts = PyBytes_AsString(obj)
        res = str.from_ptr(bts)
        pyobj.decref(obj)
        return res

    def exc_check():
        # Fetches (and thereby clears) the pending Python error, if any, and
        # re-raises it as a PyError carrying the message and the exception
        # value. Does nothing when no error is pending.
        ptype, pvalue, ptraceback = cobj(), cobj(), cobj()
        PyErr_Fetch(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
        PyErr_NormalizeException(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
        if ptype != cobj():
            py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue
            msg = pyobj.to_str(py_msg, "ignore", "<empty Python message>")

            pyobj.decref(ptype)
            pyobj.decref(ptraceback)
            pyobj.decref(py_msg)

            # pvalue is intentionally not released here; it is handed to the
            # pyobj attached to the raised error.
            # NOTE(review): pyobj(pvalue) also takes its own reference, so
            # this looks like one extra reference per raised error -- confirm
            # before changing.
            raise PyError(msg, pyobj(pvalue))

    def exc_wrap(_retval: T, T: type) -> T:
        # Pass-through that raises if a Python error is pending.
        pyobj.exc_check()
        return _retval

    # ---- Reference counting (instance and raw-pointer overloads) -------------

    def incref(self):
        Py_IncRef(self.p)
        return self

    def incref(ptr: Ptr[byte]):
        Py_IncRef(ptr)

    def decref(self):
        Py_DecRef(self.p)
        return self

    def decref(ptr: Ptr[byte]):
        Py_DecRef(ptr)

    def __call__(self, *args, **kwargs):
        # Calls the wrapped object with converted positional arguments and,
        # when any are given, a dict of converted keyword arguments.
        args_py = args.__to_py__()
        kws_py = cobj()
        if staticlen(kwargs) > 0:
            names = iter(kwargs.__dict__())
            kws = {next(names): pyobj(i.__to_py__(), steal=True) for i in kwargs}
            kws_py = kws.__to_py__()
        # NOTE(review): args_py / kws_py are not released after the call.
        # Because pyobj.__to_py__ returns its pointer without a new
        # reference, releasing the containers here may over-release wrapped
        # arguments -- audit ownership before adding a decref.
        return pyobj(pyobj.exc_wrap(PyObject_Call(self.p, args_py, kws_py)), steal=True)

    # ---- Tuple helpers on raw pointers ----------------------------------------

    def _tuple_new(length: int):
        return pyobj.exc_wrap(PyTuple_New(length))

    def _tuple_size(p: cobj):
        return pyobj.exc_wrap(PyTuple_Size(p))

    def _tuple_set(p: cobj, idx: int, val: cobj):
        PyTuple_SetItem(p, idx, val)
        pyobj.exc_check()

    def _tuple_get(p: cobj, idx: int) -> cobj:
        return pyobj.exc_wrap(PyTuple_GetItem(p, idx))

    # ---- Interpreter-level helpers ---------------------------------------------

    def _import(name: str) -> pyobj:
        # Imports a module by name, reusing the cached module on repeat calls.
        ensure_initialized()
        if name in _PY_MODULE_CACHE:
            return _PY_MODULE_CACHE[name]
        m = pyobj(pyobj.exc_wrap(PyImport_ImportModule(name.c_str())), steal=True)
        _PY_MODULE_CACHE[name] = m
        return m

    def _exec(code: str):
        # Runs `code` in the embedded interpreter.
        ensure_initialized()
        PyRun_SimpleString(code.c_str())

    def _globals() -> pyobj:
        # Returns the current globals object, or Python's None when there is
        # none. The constructor takes the single reference the wrapper needs,
        # so no separate incref of Py_None is performed.
        p = PyEval_GetGlobals()
        if p == cobj():
            return pyobj(Py_None)
        return pyobj(p)

    def _builtins() -> pyobj:
        return pyobj(PyEval_GetBuiltins())

    def _get_module(name: str) -> pyobj:
        p = pyobj(pyobj.exc_wrap(PyImport_AddModule(name.c_str())))
        return p

    def _main_module() -> pyobj:
        return pyobj._get_module("__main__")

    def _repr_mimebundle_(self, bundle=Set[str]()) -> Dict[str, str]:
        # Asks the Python-side __codon_repr__ helper (defined by _PY_INIT in
        # __main__) for a (mime type, text) pair and returns it as a
        # one-entry dict. `bundle` is not used.
        fn = pyobj._main_module()._getattr("__codon_repr__")
        assert fn.p != cobj(), "cannot find python.__codon_repr__"
        mime, txt = Tuple[str, str].__from_py__(fn.__call__(self).p)
        return {mime: txt}

    def __bool__(self):
        return bool(pyobj.exc_wrap(PyObject_IsTrue(self.p) == 1))

def _get_identifier(typ: str) -> pyobj:
    # Resolve a Python name: first in the builtins mapping, then as an
    # attribute of the __main__ module.
    #
    # Subscripting a pyobj raises PyError on a missing key rather than
    # returning a null pointer, so the fallback has to be driven by the
    # exception. The second lookup uses attribute access because a module
    # object is not subscriptable.
    # Raises PyError if the name is found in neither place.
    try:
        return pyobj._builtins()[typ]
    except PyError:
        return pyobj._main_module()._getattr(typ)

def _isinstance(what: pyobj, typ: pyobj) -> bool:
    # Python-level isinstance(what, typ); raises PyError if the check itself
    # fails inside the interpreter.
    status = PyObject_IsInstance(what.p, typ.p)
    pyobj.exc_check()
    return bool(status)

def _extension_bad_args(got: int, expected: int) -> bool:
    if got != expected:
        PyErr_SetString(PyExc_RuntimeError,
                        f"expected {expected} arguments, but got {got}".c_str())
        return True
    return False

# Type conversions

@extend
class NoneType:
    def __to_py__(self) -> cobj:
        # Hands out Python's None with its refcount incremented.
        none_obj = Py_None
        Py_IncRef(none_obj)
        return none_obj

    def __from_py__(i: cobj) -> None:
        # The incoming object is ignored; the result is always None.
        return None

@extend
class int:
    def __to_py__(self) -> cobj:
        # Builds a Python int from this value; raises PyError on failure.
        py_int = PyLong_FromLong(self)
        pyobj.exc_check()
        return py_int

    def __from_py__(i: cobj) -> int:
        # Reads a Python int; raises PyError if the conversion fails.
        value = PyLong_AsLong(i)
        pyobj.exc_check()
        return value

@extend
class float:
    def __to_py__(self) -> cobj:
        # Builds a Python float from this value; raises PyError on failure.
        py_float = PyFloat_FromDouble(self)
        pyobj.exc_check()
        return py_float

    def __from_py__(d: cobj) -> float:
        # Reads a Python float; raises PyError if the conversion fails.
        value = PyFloat_AsDouble(d)
        pyobj.exc_check()
        return value

@extend
class bool:
    def __to_py__(self) -> cobj:
        # Builds a Python bool from this value; raises PyError on failure.
        py_bool = PyBool_FromLong(int(self))
        pyobj.exc_check()
        return py_bool

    def __from_py__(b: cobj) -> bool:
        # Python truthiness of `b`; raises PyError if the test itself fails.
        truth = PyObject_IsTrue(b)
        pyobj.exc_check()
        return truth != 0

@extend
class byte:
    def __to_py__(self) -> cobj:
        # A byte is sent to Python as a one-character string.
        return str.__to_py__(str(__ptr__(self), 1))

    def __from_py__(c: cobj) -> byte:
        # Takes the first byte of the converted Python string.
        # Raises ValueError for an empty string instead of reading past the
        # end of a zero-length buffer.
        s = str.__from_py__(c)
        if s.len == 0:
            raise ValueError("cannot convert empty Python string to byte")
        return s.ptr[0]

@extend
class str:
    def __to_py__(self) -> cobj:
        # Builds a Python str from this string's bytes via
        # PyUnicode_DecodeFSDefaultAndSize; raises PyError on failure.
        decoded = PyUnicode_DecodeFSDefaultAndSize(self.ptr, self.len)
        pyobj.exc_check()
        return decoded

    def __from_py__(s: cobj) -> str:
        # Converts through pyobj.to_str with the "strict" error handler, so
        # an encoding problem surfaces as a PyError.
        text = pyobj.to_str(s, "strict")
        pyobj.exc_check()
        return text

@extend
class complex:
    def __to_py__(self) -> cobj:
        # Builds a Python complex from the real and imaginary parts.
        py_complex = PyComplex_FromDoubles(self.real, self.imag)
        pyobj.exc_check()
        return py_complex

    def __from_py__(c: cobj) -> complex:
        # Reads the two components separately, checking for a Python error
        # after each read.
        real_part = PyComplex_RealAsDouble(c)
        pyobj.exc_check()
        imag_part = PyComplex_ImagAsDouble(c)
        pyobj.exc_check()
        return complex(real_part, imag_part)

@extend
class List:
    def __to_py__(self) -> cobj:
        # Creates a Python list of the same length and stores each converted
        # element at its position.
        out = PyList_New(len(self))
        pyobj.exc_check()
        for pos, item in enumerate(self):
            PyList_SetItem(out, pos, item.__to_py__())
            pyobj.exc_check()
        return out

    def __from_py__(v: cobj) -> List[T]:
        # Converts each element of the Python list `v` with T.__from_py__.
        size = PyObject_Length(v)
        pyobj.exc_check()
        result = List[T](size)
        for pos in range(size):
            item = PyList_GetItem(v, pos)
            pyobj.exc_check()
            result.append(T.__from_py__(item))
        return result

@extend
class Dict:
    def __to_py__(self) -> cobj:
        # Creates a Python dict and inserts every converted key/value pair.
        out = PyDict_New()
        pyobj.exc_check()
        for key, value in self.items():
            py_key = key.__to_py__()
            py_value = value.__to_py__()
            PyDict_SetItem(out, py_key, py_value)
            pyobj.exc_check()
        return out

    def __from_py__(d: cobj) -> Dict[K, V]:
        # Walks the Python dict with PyDict_Next, converting each key with
        # K.__from_py__ and each value with V.__from_py__.
        result = dict[K, V]()
        pos = 0  # iteration cursor, updated by PyDict_Next
        key_ptr = cobj()
        value_ptr = cobj()
        while PyDict_Next(d, __ptr__(pos), __ptr__(key_ptr), __ptr__(value_ptr)):
            pyobj.exc_check()
            key = K.__from_py__(key_ptr)
            value = V.__from_py__(value_ptr)
            result[key] = value
        return result

@extend
class Set:
    def __to_py__(self) -> cobj:
        # Creates an empty Python set and adds every converted element.
        out = PySet_New(cobj())
        pyobj.exc_check()
        for item in self:
            PySet_Add(out, item.__to_py__())
            pyobj.exc_check()
        return out

    def __from_py__(s: cobj) -> Set[K]:
        # Iterates the Python object `s`, converting each item with
        # K.__from_py__; each item and the iterator are released when done.
        result = set[K]()
        it = pyobj.exc_wrap(PyObject_GetIter(s))
        while item_ptr := pyobj.exc_wrap(PyIter_Next(it)):
            item = K.__from_py__(item_ptr)
            pyobj.decref(item_ptr)
            result.add(item)
        pyobj.decref(it)
        return result

@extend
class DynamicTuple:
    def __to_py__(self) -> cobj:
        # Creates a Python tuple of the same length and stores each converted
        # element at its position.
        out = PyTuple_New(len(self))
        for pos, item in enumerate(self):
            PyTuple_SetItem(out, pos, item.__to_py__())
            pyobj.exc_check()
        return out

    def __from_py__(t: cobj) -> DynamicTuple[T]:
        # Converts each item of the Python tuple `t` into a freshly allocated
        # buffer of T and wraps that buffer.
        size = PyTuple_Size(t)
        pyobj.exc_check()
        buf = Ptr[T](size)
        for pos in range(size):
            buf[pos] = T.__from_py__(PyTuple_GetItem(t, pos))
        return DynamicTuple(buf, size)

@extend
class Slice:
    def __to_py__(self) -> cobj:
        # Builds a Python slice. A component that is None is passed to
        # PySlice_New as a null pointer.
        # NOTE(review): the converted components are not released after
        # PySlice_New -- possibly a leak; confirm reference ownership before
        # changing.
        start = self.start
        stop = self.stop
        step = self.step

        start_py = cobj()
        if start is not None:
            start_py = start.__to_py__()

        stop_py = cobj()
        if stop is not None:
            stop_py = stop.__to_py__()

        step_py = cobj()
        if step is not None:
            step_py = step.__to_py__()

        return PySlice_New(start_py, stop_py, step_py)

    def __from_py__(s: cobj) -> Slice:
        # Unpacks the Python slice into three integers; all three components
        # of the result are always set.
        start = 0
        stop = 0
        step = 0
        PySlice_Unpack(s, __ptr__(start), __ptr__(stop), __ptr__(step))
        pyobj.exc_check()
        return Slice(Optional(start), Optional(stop), Optional(step))

@extend
class Optional:
    def __to_py__(self) -> cobj:
        # An empty Optional becomes Python's None; otherwise the contained
        # value is converted.
        if self is None:
            # Hand out None with its own reference, as NoneType.__to_py__
            # does, so a consumer that takes ownership of the result cannot
            # drive None's refcount down.
            Py_IncRef(Py_None)
            return Py_None
        else:
            return self.__val__().__to_py__()

    def __from_py__(o: cobj) -> Optional[T]:
        # Python's None maps to an empty Optional; anything else is converted
        # with T.__from_py__ and wrapped.
        if o == Py_None:
            return Optional[T]()
        else:
            return Optional[T](T.__from_py__(o))


# Optional pyobj that starts out unset; nothing in this file assigns it.
# NOTE(review): presumably populated and read by the compiler/runtime --
# confirm where it is set.
__pyenv__: Optional[pyobj] = None
# The dummy function below only references __pyenv__; per the original note
# its purpose is to make the variable a global.
def _____(): __pyenv__  # make it global!