# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

from .ndarray import ndarray
from .npdatetime import datetime64, timedelta64
import util

# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
# Values of NumPy's `enum NPY_TYPES`, i.e. the type numbers found in
# `PyArray_Descr.type_num` and passed as `type_num` to `PyArray_New`.
# NOTE(review): NPY_NTYPES / NPY_NOTYPE / NPY_CHAR follow the NumPy 1.x
# enum ordering; the NumPy 2.x header renamed/reordered some of these
# entries — confirm against the NumPy version being targeted.
NPY_BOOL: Static[int] = 0
NPY_BYTE: Static[int] = 1
NPY_UBYTE: Static[int] = 2
NPY_SHORT: Static[int] = 3
NPY_USHORT: Static[int] = 4
NPY_INT: Static[int] = 5
NPY_UINT: Static[int] = 6
NPY_LONG: Static[int] = 7
NPY_ULONG: Static[int] = 8
NPY_LONGLONG: Static[int] = 9
NPY_ULONGLONG: Static[int] = 10
NPY_FLOAT: Static[int] = 11
NPY_DOUBLE: Static[int] = 12
NPY_LONGDOUBLE: Static[int] = 13
NPY_CFLOAT: Static[int] = 14
NPY_CDOUBLE: Static[int] = 15
NPY_CLONGDOUBLE: Static[int] = 16
NPY_OBJECT: Static[int] = 17
NPY_STRING: Static[int] = 18
NPY_UNICODE: Static[int] = 19
NPY_VOID: Static[int] = 20
NPY_DATETIME: Static[int] = 21
NPY_TIMEDELTA: Static[int] = 22
NPY_HALF: Static[int] = 23
NPY_NTYPES: Static[int] = 24
NPY_NOTYPE: Static[int] = 25
NPY_CHAR: Static[int] = 26
NPY_USERDEF: Static[int] = 256
NPY_NTYPES_ABI_COMPATIBLE: Static[int] = 21

def _type_code(dtype: type):
    """
    Map a Codon element type to NumPy's ``NPY_TYPES`` type number.

    Each check below compares types, so only the matching ``return`` is
    taken for a given ``dtype``. Types without a dedicated mapping fall back
    to ``NPY_OBJECT``, meaning elements travel as Python object pointers.
    """
    if dtype is bool:
        return NPY_BOOL

    # Signed 8-bit integer.
    if dtype is i8:
        return NPY_BYTE

    # Both `byte` and `u8` are exported as unsigned 8-bit.
    if dtype is byte or dtype is u8:
        return NPY_UBYTE

    if dtype is i16:
        return NPY_SHORT

    if dtype is u16:
        return NPY_USHORT

    if dtype is i32:
        return NPY_INT

    if dtype is u32:
        return NPY_UINT

    # NOTE(review): NPY_LONG is C `long`, which is 32-bit on LLP64 platforms
    # (e.g. Windows); mapping 64-bit `int`/`i64` here assumes an LP64 target
    # — confirm supported platforms.
    if dtype is int or dtype is i64:
        return NPY_LONG

    if dtype is u64:
        return NPY_ULONG

    if dtype is float16:
        return NPY_HALF

    if dtype is float32:
        return NPY_FLOAT

    if dtype is float:
        return NPY_DOUBLE

    if dtype is complex64:
        return NPY_CFLOAT

    if dtype is complex:
        return NPY_CDOUBLE

    # Any unit of datetime64/timedelta64 maps to the same type number; the
    # unit itself lives in the descriptor metadata (see _set_datetime_descr).
    if isinstance(dtype, datetime64):
        return NPY_DATETIME

    if isinstance(dtype, timedelta64):
        return NPY_TIMEDELTA

    # TODO: add other types like string etc.
    #       once we have them.

    return NPY_OBJECT

# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
@tuple
class PyObject:
    # Mirror of CPython's object header: reference count followed by the
    # pointer to the object's type object (read by _pyobj_type).
    refcnt: int
    typptr: cobj

@tuple
class NpyAuxData:
    # Mirror of NumPy's NpyAuxData header: free and clone callbacks followed
    # by two further pointer-sized fields. Field order must match the C layout.
    free_func: cobj
    clone_func: cobj
    data1: cobj
    data2: cobj

@tuple
class PyArray_DatetimeMetaData:
    # Datetime/timedelta unit description: `base` is the unit code and `num`
    # the unit count (filled from T._code() and T.num in _set_datetime_descr).
    base: i32
    num: i32

@tuple
class PyArray_DatetimeDTypeMetaData:
    # Layout of the object a datetime/timedelta descriptor's `c_metadata`
    # points to: an NpyAuxData header followed by the unit metadata.
    base: NpyAuxData
    meta: PyArray_DatetimeMetaData

@tuple
class PyArrayDescr:
    # Mirror of NumPy's PyArray_Descr (dtype object). Field order and widths
    # must match the C struct exactly since it is read through a raw pointer.
    # NOTE(review): this appears to follow the NumPy 1.x layout; NumPy 2.x
    # changed this struct (e.g. widened flags/elsize/alignment and moved
    # c_metadata out of the public part). Fields before `type_num` look
    # unchanged, but later fields may be misread on 2.x — confirm the
    # supported NumPy versions.
    head: PyObject
    typeobj: cobj
    kind: u8
    type: u8
    byteorder: u8
    flags: u8
    type_num: i32    # NPY_TYPES value, compared against _type_code()
    elsize: i32
    alignment: i32
    subarray: cobj
    fields: cobj
    names: cobj
    f: cobj
    metadata: cobj
    c_metadata: cobj  # for datetime/timedelta: PyArray_DatetimeDTypeMetaData*
    hash: int

@tuple
class PyArrayObject:
    # Mirror of the leading fields of NumPy's array object struct
    # (PyArrayObject_fields). Presumably any trailing C fields are omitted
    # because this file only reads these ones through a pointer — confirm
    # before ever allocating or copying this struct by value into NumPy.
    head: PyObject
    data: cobj              # start of the element buffer
    nd: i32                 # number of dimensions
    dimensions: Ptr[int]    # shape, `nd` entries
    strides: Ptr[int]       # per-dimension strides in bytes, `nd` entries
    base: cobj
    descr: Ptr[PyArrayDescr]
    flags: i32              # NPY_ARRAY_* bit flags
    weakreflist: cobj

# Bits of PyArrayObject.flags used by ndarray._from_py to detect contiguous
# buffers that can be copied with a single memcpy.
NPY_ARRAY_C_CONTIGUOUS: Static[int] = 1
NPY_ARRAY_F_CONTIGUOUS: Static[int] = 2

# Null placeholders for NumPy C-API entries; _setup_numpy_bridge() replaces
# both with pointers taken from NumPy's `_ARRAY_API` capsule.
PyArray_Type = cobj()
# PyArray_New(subtype, nd, dims, type_num, strides, data, itemsize, flags, obj)
PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](cobj())

def _pyobj_type(p: cobj):
    """Return the type-object pointer stored in the header of the Python object at ``p``."""
    header = Ptr[PyObject](p)[0]
    return header.typptr

def _set_datetime_descr(a: Ptr[PyArrayObject], T: type):
    """
    Overwrite the datetime/timedelta unit metadata of array ``a``'s dtype
    with the unit code (``T._code()``) and count (``T.num``) of ``T``.

    The descriptor's ``c_metadata`` is cast to
    ``PyArray_DatetimeDTypeMetaData*`` and dereferenced without a NULL check,
    so ``a`` must have a datetime or timedelta dtype.
    """
    p = Ptr[PyArray_DatetimeDTypeMetaData](a[0].descr[0].c_metadata)
    meta = PyArray_DatetimeMetaData(i32(T._code()), i32(T.num))
    # Keep the existing NpyAuxData header; replace only the unit metadata.
    # NOTE(review): this mutates the descriptor in place. If the descriptor
    # attached by PyArray_New is shared (e.g. a cached builtin datetime
    # descriptor), other arrays using it would also see the new unit —
    # confirm a private descriptor is guaranteed here.
    p[0] = PyArray_DatetimeDTypeMetaData(p[0].base, meta)

def _setup_numpy_bridge():
    """
    Resolve the NumPy C-API entry points used by this module.

    Imports NumPy's ``_multiarray_umath`` extension module and unwraps its
    ``_ARRAY_API`` capsule, a table of pointers. It then stores the two
    entries this file needs into the ``PyArray_Type`` and ``PyArray_New``
    globals.

    Raises:
        RuntimeError: if the module cannot be imported, or if ``_ARRAY_API``
                      is missing or is not a capsule.
    """
    import python
    from internal.python import PyImport_ImportModule, PyObject_GetAttrString, PyCapsule_Type, PyCapsule_GetPointer
    global PyArray_Type, PyArray_New

    # NOTE(review): this is the NumPy 1.x module path; NumPy 2.x moved it to
    # `numpy._core`, and the struct mirrors above follow the 1.x layout —
    # confirm which NumPy versions this bridge is meant to support.
    module = PyImport_ImportModule("numpy.core._multiarray_umath".ptr)

    if not module:
        raise RuntimeError("Failed to import 'numpy.core._multiarray_umath'")

    attr = PyObject_GetAttrString(module, "_ARRAY_API".ptr)

    # The API table is published as a capsule; require that exact type before
    # extracting the raw pointer from it.
    if not attr or _pyobj_type(attr) != PyCapsule_Type:
        raise RuntimeError("NumPy API object not found or did not have type 'capsule'")

    # NOTE(review): the references obtained for `module` and `attr` are never
    # released here; presumably acceptable for a one-time setup — confirm.
    api = Ptr[cobj](PyCapsule_GetPointer(attr, cobj()))
    # https://github.com/numpy/numpy/blob/main/numpy/_core/code_generators/numpy_api.py
    # Slot 2 is PyArray_Type and slot 93 is PyArray_New in that table.
    PyArray_Type = api[2]
    PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](api[93])

# Runs once when this module is loaded; raises RuntimeError if NumPy's API
# cannot be located.
_setup_numpy_bridge()

@extend
class ndarray:
    def __to_py__(self):
        """
        Convert this array to a new NumPy ``ndarray`` object.

        The result is freshly allocated by ``PyArray_New`` with a C-ordered
        buffer. Element data is always copied into it and never shared.

        Raises:
            PyError: if NumPy fails to create the result array.
        """
        dims = self.shape
        code = _type_code(self.dtype)
        # PyArray_New(subtype, nd, dims, type_num, strides=NULL, data=NULL,
        #             itemsize=0, flags=0, obj=NULL): NumPy allocates a
        #             C-ordered buffer for `dims`.
        arr = PyArray_New(PyArray_Type, i32(self.ndim), __ptr__(dims).as_byte(),
                          i32(code), cobj(), cobj(), i32(0), i32(0), cobj())
        # PyArray_New returns NULL on failure (e.g. out of memory); the field
        # accesses below would dereference it.
        if not arr:
            raise PyError("NumPy conversion error: failed to create Python array")
        arr_ptr = Ptr[PyArrayObject](arr)
        data = arr_ptr[0].data

        if code == NPY_OBJECT:
            # Object arrays hold one Python object pointer per element; fill
            # the slots linearly. Assumes util.multirange yields indices in
            # C (row-major) order to match the buffer layout.
            p = Ptr[cobj](data)
            k = 0
            for idx in util.multirange(dims):
                e = self._ptr(idx)[0]
                if hasattr(e, "__to_py__"):
                    p[k] = e.__to_py__()
                # NOTE(review): elements without __to_py__ leave their slot as
                # NumPy allocated it — confirm that is intended.
                k += 1
        else:
            cc, _ = self._contig
            if cc:
                # Source is C-contiguous like the destination: one bulk copy.
                str.memcpy(data, self.data.as_byte(), self.nbytes)
            else:
                # Gather elements one by one into the contiguous buffer
                # (same C-order assumption on util.multirange as above).
                p = Ptr[self.dtype](data)
                k = 0
                for idx in util.multirange(dims):
                    e = self._ptr(idx)[0]
                    p[k] = e
                    k += 1

        if isinstance(self.dtype, datetime64) or isinstance(self.dtype, timedelta64):
            # PyArray_New only received the type number; record the unit too.
            _set_datetime_descr(arr_ptr, self.dtype)

        return arr

    def _from_py(a: cobj, copy: bool):
        """
        Build an ``ndarray[dtype, ndim]`` from the NumPy array object ``a``.

        Args:
            a: Pointer to a Python object whose type must be exactly NumPy's
               array type (subclasses fail the identity check below).
            copy: If false and ``dtype`` is not an object type, the result
                  views ``a``'s buffer instead of copying it. Object-typed
                  arrays are always converted element by element.

        Raises:
            PyError: if ``a`` is not a NumPy array, or if its dimension, dtype,
                     or datetime/timedelta unit does not match.
        """
        if _pyobj_type(a) != PyArray_Type:
            raise PyError("NumPy conversion error: Python object did not have array type")

        arr = Ptr[PyArrayObject](a)[0]

        if int(arr.nd) != ndim:
            raise PyError("NumPy conversion error: Python array has incorrect dimension")

        code = _type_code(dtype)
        if int(arr.descr[0].type_num) != code:
            raise PyError("NumPy conversion error: Python array has incorrect dtype")

        if isinstance(dtype, datetime64) or isinstance(dtype, timedelta64):
            # type_num does not encode the unit: without this check an M8[D]
            # array would be silently reinterpreted as, e.g., M8[s]. This
            # reads the same metadata that _set_datetime_descr writes.
            md = arr.descr[0].c_metadata
            if not md:
                raise PyError("NumPy conversion error: Python array has no datetime metadata")
            meta = Ptr[PyArray_DatetimeDTypeMetaData](md)[0].meta
            if int(meta.base) != int(dtype._code()) or int(meta.num) != int(dtype.num):
                raise PyError("NumPy conversion error: Python array has incorrect datetime unit")

        # NOTE(review): byte order (descr.byteorder) and alignment are not
        # checked, so a non-native-endian or unaligned array would be read as
        # native values — confirm callers never pass such arrays.

        arr_data = arr.data
        arr_shape = arr.dimensions
        arr_strides = arr.strides
        arr_flags = int(arr.flags)
        shape = tuple(arr_shape[i] for i in staticrange(ndim))
        strides = tuple(arr_strides[i] for i in staticrange(ndim))
        size = util.count(shape)

        if code != NPY_OBJECT and not copy:
            # Zero-copy view over NumPy's buffer, reusing its byte strides.
            # NOTE(review): no reference to `a` is taken here, so `a` must
            # outlive the returned array.
            return ndarray(shape, strides, Ptr[dtype](arr_data))

        data = Ptr[dtype](size)

        if code == NPY_OBJECT:
            # Convert each Python element; strides are in bytes.
            k = 0
            for idx in util.multirange(shape):
                off = 0
                for i in range(ndim):
                    off += idx[i] * strides[i]
                e = Ptr[cobj](arr_data + off)[0]
                if hasattr(dtype, "__from_py__"):
                    data[k] = dtype.__from_py__(e)
                k += 1
            return ndarray(shape, data)
        else:
            cc = (arr_flags & NPY_ARRAY_C_CONTIGUOUS != 0)
            fc = (arr_flags & NPY_ARRAY_F_CONTIGUOUS != 0)

            if cc or fc:
                # A C- or F-contiguous array occupies exactly `size` elements
                # starting at arr_data: copy verbatim and keep its strides.
                str.memcpy(data.as_byte(), arr_data, size * util.sizeof(dtype))
                return ndarray(shape, strides, data)
            else:
                # Arbitrary strides: gather element by element into a fresh
                # buffer and let ndarray(shape, data) pick the layout.
                k = 0
                for idx in util.multirange(shape):
                    off = 0
                    for i in range(ndim):
                        off += idx[i] * strides[i]
                    e = Ptr[dtype](arr_data + off)[0]
                    data[k] = e
                    k += 1
                return ndarray(shape, data)

    def __from_py__(a: cobj):
        """Convert NumPy array ``a`` without copying its buffer when possible."""
        return ndarray[dtype, ndim]._from_py(a, copy=False)