# Copyright (C) 2022-2025 Exaloop Inc.

# Bridge between the native `ndarray` type and NumPy arrays living inside an
# embedded CPython interpreter. The structs below are hand-written mirrors of
# the C structs in NumPy's `ndarraytypes.h`, and the NumPy C-API function
# table is looked up at import time through the `_ARRAY_API` capsule.
#
# NOTE(review): the `PyArrayDescr` layout and the `numpy.core` import path
# below look like the NumPy 1.x ABI -- confirm which NumPy versions are
# supported before relying on this with NumPy 2.x.

from .ndarray import ndarray
from .npdatetime import datetime64, timedelta64
import util

# NumPy type numbers (`NPY_TYPES`).
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
NPY_BOOL: Static[int] = 0
NPY_BYTE: Static[int] = 1
NPY_UBYTE: Static[int] = 2
NPY_SHORT: Static[int] = 3
NPY_USHORT: Static[int] = 4
NPY_INT: Static[int] = 5
NPY_UINT: Static[int] = 6
NPY_LONG: Static[int] = 7
NPY_ULONG: Static[int] = 8
NPY_LONGLONG: Static[int] = 9
NPY_ULONGLONG: Static[int] = 10
NPY_FLOAT: Static[int] = 11
NPY_DOUBLE: Static[int] = 12
NPY_LONGDOUBLE: Static[int] = 13
NPY_CFLOAT: Static[int] = 14
NPY_CDOUBLE: Static[int] = 15
NPY_CLONGDOUBLE: Static[int] = 16
NPY_OBJECT: Static[int] = 17
NPY_STRING: Static[int] = 18
NPY_UNICODE: Static[int] = 19
NPY_VOID: Static[int] = 20
NPY_DATETIME: Static[int] = 21
NPY_TIMEDELTA: Static[int] = 22
NPY_HALF: Static[int] = 23
NPY_NTYPES: Static[int] = 24
NPY_NOTYPE: Static[int] = 25
NPY_CHAR: Static[int] = 26
NPY_USERDEF: Static[int] = 256
NPY_NTYPES_ABI_COMPATIBLE: Static[int] = 21

def _type_code(dtype: type):
    """
    Return the NumPy type number (one of the `NPY_*` constants above) that
    corresponds to the element type `dtype`.

    Any type without an explicit mapping falls back to `NPY_OBJECT`, which
    makes the conversion routines below convert elements one by one.
    """
    if dtype is bool:
        return NPY_BOOL

    if dtype is i8:
        return NPY_BYTE

    if dtype is byte or dtype is u8:
        return NPY_UBYTE

    if dtype is i16:
        return NPY_SHORT

    if dtype is u16:
        return NPY_USHORT

    if dtype is i32:
        return NPY_INT

    if dtype is u32:
        return NPY_UINT

    if dtype is int or dtype is i64:
        return NPY_LONG

    if dtype is u64:
        return NPY_ULONG

    if dtype is float16:
        return NPY_HALF

    if dtype is float32:
        return NPY_FLOAT

    if dtype is float:
        return NPY_DOUBLE

    if dtype is complex64:
        return NPY_CFLOAT

    if dtype is complex:
        return NPY_CDOUBLE

    if isinstance(dtype, datetime64):
        return NPY_DATETIME

    if isinstance(dtype, timedelta64):
        return NPY_TIMEDELTA

    # TODO: add other types like string etc.
    #       once we have them.

    return NPY_OBJECT

# Struct mirrors; field order and widths must match the C definitions.
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h

# Common header of every Python object: reference count and type pointer.
@tuple
class PyObject:
    refcnt: int
    typptr: cobj

# Header of NumPy auxiliary data blocks; only used here as the leading member
# of `PyArray_DatetimeDTypeMetaData`.
@tuple
class NpyAuxData:
    free_func: cobj
    clone_func: cobj
    data1: cobj
    data2: cobj

# Datetime/timedelta unit description: unit code (`base`) and multiplier (`num`).
@tuple
class PyArray_DatetimeMetaData:
    base: i32
    num: i32

# Block that a datetime/timedelta descriptor's `c_metadata` is treated as
# pointing to (see `_set_datetime_descr`).
@tuple
class PyArray_DatetimeDTypeMetaData:
    base: NpyAuxData
    meta: PyArray_DatetimeMetaData

# Array descriptor (dtype object). This module reads `type_num` and
# `c_metadata`; the remaining fields exist to keep the offsets right.
@tuple
class PyArrayDescr:
    head: PyObject
    typeobj: cobj
    kind: u8
    type: u8
    byteorder: u8
    flags: u8
    type_num: i32
    elsize: i32
    alignment: i32
    subarray: cobj
    fields: cobj
    names: cobj
    f: cobj
    metadata: cobj
    c_metadata: cobj
    hash: int

# The array object itself: data pointer, rank, shape/stride arrays, descriptor
# and flags.
@tuple
class PyArrayObject:
    head: PyObject
    data: cobj
    nd: i32
    dimensions: Ptr[int]
    strides: Ptr[int]
    base: cobj
    descr: Ptr[PyArrayDescr]
    flags: i32
    weakreflist: cobj

# Bits tested against `PyArrayObject.flags` in `ndarray._from_py`.
NPY_ARRAY_C_CONTIGUOUS: Static[int] = 1
NPY_ARRAY_F_CONTIGUOUS: Static[int] = 2

# Null placeholders; both are overwritten by `_setup_numpy_bridge()` with
# entries taken from NumPy's C-API table.
PyArray_Type = cobj()
PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](cobj())

def _pyobj_type(p: cobj):
    """
    Return the type pointer stored in the object header that `p` points to.
    `p` must not be null.
    """
    return Ptr[PyObject](p)[0].typptr

def _set_datetime_descr(a: Ptr[PyArrayObject], T: type):
    """
    Overwrite the datetime metadata attached to the descriptor of array `a`
    with the unit code (`T._code()`) and multiplier (`T.num`) of `T`, keeping
    the existing auxiliary-data header untouched.
    """
    p = Ptr[PyArray_DatetimeDTypeMetaData](a[0].descr[0].c_metadata)
    meta = PyArray_DatetimeMetaData(i32(T._code()), i32(T.num))
    p[0] = PyArray_DatetimeDTypeMetaData(p[0].base, meta)

def _setup_numpy_bridge():
    """
    Resolve `PyArray_Type` and `PyArray_New` from NumPy's C-API table.

    Imports `numpy.core._multiarray_umath`, fetches its `_ARRAY_API` capsule
    and reads the two table entries this module needs.

    Raises `RuntimeError` if the module cannot be imported, if `_ARRAY_API`
    is missing or is not a capsule, or if the capsule yields a null pointer.
    """
    import python
    from internal.python import PyImport_ImportModule, PyObject_GetAttrString, PyCapsule_Type, PyCapsule_GetPointer
    global PyArray_Type, PyArray_New

    module = PyImport_ImportModule("numpy.core._multiarray_umath".ptr)

    if not module:
        raise RuntimeError("Failed to import 'numpy.core._multiarray_umath'")

    attr = PyObject_GetAttrString(module, "_ARRAY_API".ptr)

    if not attr or _pyobj_type(attr) != PyCapsule_Type:
        raise RuntimeError("NumPy API object not found or did not have type 'capsule'")

    # PyCapsule_GetPointer returns NULL on failure; indexing a null table
    # below would crash instead of reporting the problem.
    table = PyCapsule_GetPointer(attr, cobj())

    if not table:
        raise RuntimeError("NumPy API capsule did not contain a valid pointer")

    api = Ptr[cobj](table)

    # Slot indices come from NumPy's API table definition.
    # https://github.com/numpy/numpy/blob/main/numpy/_core/code_generators/numpy_api.py
    PyArray_Type = api[2]
    PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](api[93])

# Executed at import time so the two globals above are valid before any
# conversion is attempted.
_setup_numpy_bridge()

@extend
class ndarray:
    def __to_py__(self):
        """
        Build a new NumPy array holding a copy of this array's elements and
        return it as a raw Python object pointer.

        Element types that map to `NPY_OBJECT` are converted one at a time
        through their own `__to_py__` (when they define one); all other types
        are copied as raw values.

        Raises `PyError` if NumPy fails to create the array.
        """
        dims = self.shape
        code = _type_code(self.dtype)
        # Null strides/data: NumPy allocates its own buffer for the new array.
        arr = PyArray_New(PyArray_Type, i32(self.ndim), __ptr__(dims).as_byte(),
                          i32(code), cobj(), cobj(), i32(0), i32(0), cobj())

        # PyArray_New returns NULL on failure; dereferencing it below would
        # crash the process rather than raise.
        if not arr:
            raise PyError("NumPy conversion error: failed to create Python array")

        arr_ptr = Ptr[PyArrayObject](arr)
        data = arr_ptr[0].data

        if code == NPY_OBJECT:
            # Destination is an array of Python object pointers.
            p = Ptr[cobj](data)
            k = 0
            for idx in util.multirange(dims):
                e = self._ptr(idx)[0]
                if hasattr(e, "__to_py__"):
                    p[k] = e.__to_py__()
                k += 1
        else:
            cc, _ = self._contig
            if cc:
                # Source is one contiguous run: copy it in a single call.
                str.memcpy(data, self.data.as_byte(), self.nbytes)
            else:
                # Walk the source index by index and pack the elements
                # sequentially into the destination buffer.
                p = Ptr[self.dtype](data)
                k = 0
                for idx in util.multirange(dims):
                    e = self._ptr(idx)[0]
                    p[k] = e
                    k += 1

        # The type number alone does not carry the datetime unit, so it is
        # written into the new array's descriptor separately.
        if isinstance(self.dtype, datetime64) or isinstance(self.dtype, timedelta64):
            _set_datetime_descr(arr_ptr, self.dtype)

        return arr

    def _from_py(a: cobj, copy: bool):
        """
        Convert the NumPy array object `a` into an `ndarray[dtype, ndim]`.

        With `copy=False` and a non-object element type the result points
        directly at the Python array's buffer and reuses its strides; in all
        other cases the elements are copied into a fresh buffer.

        NOTE(review): the non-copying path does not take a reference to `a`
        here -- confirm the caller keeps the Python array alive.

        Raises `PyError` if `a` is not exactly NumPy's array type, or if its
        number of dimensions or type number differs from what
        `ndarray[dtype, ndim]` expects.
        """
        if _pyobj_type(a) != PyArray_Type:
            raise PyError("NumPy conversion error: Python object did not have array type")

        arr = Ptr[PyArrayObject](a)[0]

        if int(arr.nd) != ndim:
            raise PyError("NumPy conversion error: Python array has incorrect dimension")

        code = _type_code(dtype)
        if int(arr.descr[0].type_num) != code:
            raise PyError("NumPy conversion error: Python array has incorrect dtype")

        arr_data = arr.data
        arr_shape = arr.dimensions
        arr_strides = arr.strides
        arr_flags = int(arr.flags)
        shape = tuple(arr_shape[i] for i in staticrange(ndim))
        strides = tuple(arr_strides[i] for i in staticrange(ndim))
        size = util.count(shape)

        if code != NPY_OBJECT and not copy:
            # Zero-copy view over the Python array's memory.
            return ndarray(shape, strides, Ptr[dtype](arr_data))

        data = Ptr[dtype](size)

        if code == NPY_OBJECT:
            # Each slot holds a Python object pointer that has to be
            # converted individually.
            k = 0
            for idx in util.multirange(shape):
                off = 0
                for i in staticrange(ndim):
                    off += idx[i] * strides[i]
                e = Ptr[cobj](arr_data + off)[0]
                if hasattr(dtype, "__from_py__"):
                    data[k] = dtype.__from_py__(e)
                k += 1
            return ndarray(shape, data)
        else:
            cc = (arr_flags & NPY_ARRAY_C_CONTIGUOUS != 0)
            fc = (arr_flags & NPY_ARRAY_F_CONTIGUOUS != 0)

            if cc or fc:
                # Contiguous in either order: a flat copy preserves the
                # layout, so the original strides remain valid.
                str.memcpy(data.as_byte(), arr_data, size * util.sizeof(dtype))
                return ndarray(shape, strides, data)
            else:
                # Arbitrary strides: gather element by element; the result is
                # built from the shape alone, without the source strides.
                k = 0
                for idx in util.multirange(shape):
                    off = 0
                    for i in staticrange(ndim):
                        off += idx[i] * strides[i]
                    e = Ptr[dtype](arr_data + off)[0]
                    data[k] = e
                    k += 1
                return ndarray(shape, data)

    def __from_py__(a: cobj):
        """
        Convert the NumPy array object `a` to an `ndarray[dtype, ndim]`
        without copying where possible (see `_from_py`).
        """
        return ndarray[dtype, ndim]._from_py(a, copy=False)