mirror of https://github.com/exaloop/codon.git
275 lines
7.4 KiB
Python
275 lines
7.4 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from .ndarray import ndarray
|
|
from .npdatetime import datetime64, timedelta64
|
|
import util
|
|
|
|
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
# NumPy type numbers, mirroring the NPY_TYPES enum in the header above.
# They are passed to PyArray_New as the element type and compared against
# PyArray_Descr.type_num when converting arrays back from Python.
NPY_BOOL: Static[int] = 0
NPY_BYTE: Static[int] = 1
NPY_UBYTE: Static[int] = 2
NPY_SHORT: Static[int] = 3
NPY_USHORT: Static[int] = 4
NPY_INT: Static[int] = 5
NPY_UINT: Static[int] = 6
NPY_LONG: Static[int] = 7
NPY_ULONG: Static[int] = 8
NPY_LONGLONG: Static[int] = 9
NPY_ULONGLONG: Static[int] = 10
NPY_FLOAT: Static[int] = 11
NPY_DOUBLE: Static[int] = 12
NPY_LONGDOUBLE: Static[int] = 13
NPY_CFLOAT: Static[int] = 14
NPY_CDOUBLE: Static[int] = 15
NPY_CLONGDOUBLE: Static[int] = 16
NPY_OBJECT: Static[int] = 17
NPY_STRING: Static[int] = 18
NPY_UNICODE: Static[int] = 19
NPY_VOID: Static[int] = 20
NPY_DATETIME: Static[int] = 21
NPY_TIMEDELTA: Static[int] = 22
NPY_HALF: Static[int] = 23

# Sentinel / bookkeeping values that follow the builtin type numbers.
NPY_NTYPES: Static[int] = 24
NPY_NOTYPE: Static[int] = 25
NPY_CHAR: Static[int] = 26
NPY_USERDEF: Static[int] = 256
NPY_NTYPES_ABI_COMPATIBLE: Static[int] = 21
|
|
|
|
def _type_code(dtype: type):
    """
    Map a Codon element type to the matching NumPy type number.

    Args:
        dtype: the element type of an ndarray.

    Returns:
        One of the NPY_* constants. Types with no dedicated mapping
        (strings, arbitrary classes, ...) fall back to NPY_OBJECT.
    """
    if dtype is bool:
        return NPY_BOOL
    # signed / unsigned integers
    elif dtype is i8:
        return NPY_BYTE
    elif dtype is u8 or dtype is byte:
        return NPY_UBYTE
    elif dtype is i16:
        return NPY_SHORT
    elif dtype is u16:
        return NPY_USHORT
    elif dtype is i32:
        return NPY_INT
    elif dtype is u32:
        return NPY_UINT
    elif dtype is i64 or dtype is int:
        return NPY_LONG
    elif dtype is u64:
        return NPY_ULONG
    # floating point
    elif dtype is float16:
        return NPY_HALF
    elif dtype is float32:
        return NPY_FLOAT
    elif dtype is float:
        return NPY_DOUBLE
    # complex
    elif dtype is complex64:
        return NPY_CFLOAT
    elif dtype is complex:
        return NPY_CDOUBLE
    # date / time
    elif isinstance(dtype, datetime64):
        return NPY_DATETIME
    elif isinstance(dtype, timedelta64):
        return NPY_TIMEDELTA

    # TODO: map further types (e.g. strings) once they are supported.
    return NPY_OBJECT
|
|
|
|
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
# Mirror of the CPython object header. Raw C memory is read through casts
# such as Ptr[PyObject](p) (see _pyobj_type), so field order and widths must
# match the C struct exactly.
# NOTE(review): assumes a regular (non-debug) CPython build whose header is
# just refcount + type pointer — confirm for other builds.
@tuple
class PyObject:
    refcnt: int   # object reference count
    typptr: cobj  # pointer to the object's type object
|
|
|
|
# Mirror of NumPy's NpyAuxData header: two function pointers followed by two
# pointer-sized slots (presumably NumPy's `reserved[2]` — confirm). It is the
# leading member of PyArray_DatetimeDTypeMetaData; field order is layout-critical.
@tuple
class NpyAuxData:
    free_func: cobj   # function pointer that frees the aux data
    clone_func: cobj  # function pointer that clones the aux data
    data1: cobj
    data2: cobj
|
|
|
|
# Datetime/timedelta unit metadata as stored by NumPy. In this file,
# _set_datetime_descr fills `base` from T._code() and `num` from T.num.
@tuple
class PyArray_DatetimeMetaData:
    base: i32  # unit code
    num: i32   # unit count (multiplier)
|
|
|
|
# Layout of the block a datetime/timedelta descriptor's `c_metadata` points to
# (see _set_datetime_descr): an NpyAuxData header followed by the unit metadata.
@tuple
class PyArray_DatetimeDTypeMetaData:
    base: NpyAuxData                # aux-data header; preserved on rewrite
    meta: PyArray_DatetimeMetaData  # unit + multiplier
|
|
|
|
# Mirror of NumPy's PyArray_Descr (dtype object). This file reads it through
# PyArrayObject.descr to check `type_num` and to reach `c_metadata` for
# datetime units, so field order and widths must match the C struct.
# NOTE(review): this matches the NumPy 1.x layout; NumPy 2 reorganized this
# struct (e.g. flags/elsize widths, c_metadata moved) — confirm before
# supporting NumPy 2.
@tuple
class PyArrayDescr:
    head: PyObject     # PyObject_HEAD
    typeobj: cobj      # scalar type object
    kind: u8           # kind character
    type: u8           # type character
    byteorder: u8      # byte-order character
    flags: u8
    type_num: i32      # NPY_* type number (compared against _type_code)
    elsize: i32        # element size in bytes
    alignment: i32
    subarray: cobj
    fields: cobj
    names: cobj
    f: cobj            # per-type function table
    metadata: cobj
    c_metadata: cobj   # for datetime types: PyArray_DatetimeDTypeMetaData*
    hash: int
|
|
|
|
# Mirror of the leading fields of NumPy's array object. Pointers returned by
# PyArray_New, or received from Python, are cast to Ptr[PyArrayObject] to
# read these fields.
# NOTE(review): newer NumPy versions append further private fields after
# weakreflist; only this prefix is read here — confirm it stays stable.
@tuple
class PyArrayObject:
    head: PyObject            # PyObject_HEAD
    data: cobj                # pointer to the first element
    nd: i32                   # number of dimensions
    dimensions: Ptr[int]      # nd extents
    strides: Ptr[int]         # nd strides, in bytes (used as byte offsets in _from_py)
    base: cobj                # object owning the memory, if any
    descr: Ptr[PyArrayDescr]  # dtype descriptor
    flags: i32                # NPY_ARRAY_* bit flags
    weakreflist: cobj
|
|
|
|
# Bits of PyArrayObject.flags tested by ndarray._from_py.
NPY_ARRAY_C_CONTIGUOUS: Static[int] = 1
NPY_ARRAY_F_CONTIGUOUS: Static[int] = 2

# Placeholders for NumPy C-API entries. _setup_numpy_bridge() overwrites both
# globals with the real pointers from NumPy's _ARRAY_API table; their types
# here must match the types assigned there.
PyArray_Type = cobj()
# Argument order follows NumPy's C-API PyArray_New:
# (subtype, nd, dims, type_num, strides, data, itemsize, flags, obj).
PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](cobj())
|
|
|
|
def _pyobj_type(p: cobj):
    """
    Return the type-object pointer stored in the header of Python object `p`.

    Callers compare the result against a known type pointer (e.g. PyArray_Type)
    to perform an exact type check.
    """
    header = Ptr[PyObject](p)
    return header[0].typptr
|
|
|
|
def _set_datetime_descr(a: Ptr[PyArrayObject], T: type):
    """
    Write the unit of datetime/timedelta type `T` into the descriptor of `a`.

    The descriptor's `c_metadata` is treated as a PyArray_DatetimeDTypeMetaData.
    Its `meta` part is replaced with (T._code(), T.num) and its `base`
    aux-data header is kept as-is.

    NOTE(review): if PyArray_New returns NumPy's shared builtin datetime
    descriptor, this write changes the metadata for every array that shares
    that descriptor. Confirm, and consider creating a fresh descriptor instead.
    """
    descr = a[0].descr[0]
    dt_meta = Ptr[PyArray_DatetimeDTypeMetaData](descr.c_metadata)
    unit = i32(T._code())
    count = i32(T.num)
    dt_meta[0] = PyArray_DatetimeDTypeMetaData(dt_meta[0].base,
                                               PyArray_DatetimeMetaData(unit, count))
|
|
|
|
def _setup_numpy_bridge():
    """
    Load NumPy's C-API function table and bind the entries used here.

    Imports `numpy.core._multiarray_umath` and reads its `_ARRAY_API` capsule.
    The module-level globals PyArray_Type and PyArray_New are then set from
    slots 2 and 93 of that table.

    Raises:
        RuntimeError: if the module cannot be imported, or if `_ARRAY_API` is
            missing or is not a capsule.

    NOTE(review): under NumPy 2.x `numpy.core` is a deprecated shim, and
    fetching `_ARRAY_API` through it may fail. The struct mirrors in this
    file also appear to follow the NumPy 1.x layout. Confirm before
    supporting NumPy 2.
    """
    import python
    from internal.python import PyImport_ImportModule, PyObject_GetAttrString, PyCapsule_Type, PyCapsule_GetPointer
    global PyArray_Type, PyArray_New

    umath_mod = PyImport_ImportModule("numpy.core._multiarray_umath".ptr)
    if not umath_mod:
        raise RuntimeError("Failed to import 'numpy.core._multiarray_umath'")

    capsule = PyObject_GetAttrString(umath_mod, "_ARRAY_API".ptr)
    if not capsule or _pyobj_type(capsule) != PyCapsule_Type:
        raise RuntimeError("NumPy API object not found or did not have type 'capsule'")

    # The capsule wraps a table of pointers; slot numbers are taken from
    # https://github.com/numpy/numpy/blob/main/numpy/_core/code_generators/numpy_api.py
    table = Ptr[cobj](PyCapsule_GetPointer(capsule, cobj()))
    type_slot, new_slot = 2, 93
    PyArray_Type = table[type_slot]
    PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](table[new_slot])
|
|
|
|
# Resolve the NumPy C-API entries once, when this module is loaded.
_setup_numpy_bridge()
|
|
|
|
@extend
class ndarray:
    def __to_py__(self):
        """
        Convert this array into a newly allocated NumPy array.

        The NumPy array is created with PyArray_New using this array's shape
        and the type number from _type_code(self.dtype). Elements are then
        copied in:

        - object dtype (NPY_OBJECT): each element is converted with its
          `__to_py__` and the resulting pointer is stored. If the element type
          has no `__to_py__`, the slot is left untouched.
        - other dtypes: a single memcpy when this array is C-contiguous.
          Otherwise elements are copied one by one in util.multirange order,
          which is assumed to be C order to match the new buffer.

        For datetime64/timedelta64 element types, the unit is also written
        into the new array's descriptor.

        Returns:
            The new NumPy array as a `cobj`.

        Raises:
            PyError: if PyArray_New fails to create the array.
        """
        dims = self.shape
        code = _type_code(self.dtype)
        # strides=NULL, data=NULL, flags=0: NumPy allocates its own buffer
        # (see the C-API docs for PyArray_New).
        arr = PyArray_New(PyArray_Type, i32(self.ndim), __ptr__(dims).as_byte(),
                          i32(code), cobj(), cobj(), i32(0), i32(0), cobj())
        if not arr:
            # PyArray_New signals failure with NULL; never dereference it.
            raise PyError("NumPy conversion error: failed to create Python array")
        arr_ptr = Ptr[PyArrayObject](arr)
        data = arr_ptr[0].data

        if code == NPY_OBJECT:
            # Object arrays store one object pointer per element.
            p = Ptr[cobj](data)
            k = 0
            for idx in util.multirange(dims):
                e = self._ptr(idx)[0]
                if hasattr(e, "__to_py__"):
                    p[k] = e.__to_py__()
                k += 1
        else:
            cc, _ = self._contig
            if cc:
                # Both sides are C-ordered and dense: one bulk copy.
                str.memcpy(data, self.data.as_byte(), self.nbytes)
            else:
                # Gather strided elements into the dense destination buffer.
                p = Ptr[self.dtype](data)
                k = 0
                for idx in util.multirange(dims):
                    p[k] = self._ptr(idx)[0]
                    k += 1

        if isinstance(self.dtype, datetime64) or isinstance(self.dtype, timedelta64):
            _set_datetime_descr(arr_ptr, self.dtype)

        return arr

    def _from_py(a: cobj, copy: bool):
        """
        Build an `ndarray[dtype, ndim]` from the NumPy array object `a`.

        `a` must:
        - have exactly type PyArray_Type;
        - have `ndim` dimensions;
        - have type number _type_code(dtype);
        - not be tagged with the non-native byte order;
        - for datetime64/timedelta64, have the same unit code and count as
          `dtype`.

        Result by case:
        - non-object dtypes with `copy=False`: a view over NumPy's buffer
          that keeps NumPy's strides.
        - C- or F-contiguous arrays otherwise: a bulk copy that keeps the
          same strides.
        - everything else: an element-by-element copy into a new
          C-contiguous buffer. Object elements are converted with
          `dtype.__from_py__` when available.

        NOTE(review): the no-copy view holds no reference to `a`. Presumably
        the caller must keep the Python array alive while the view is in use —
        confirm.

        Raises:
            PyError: if any of the requirements above is not met.
        """
        if _pyobj_type(a) != PyArray_Type:
            raise PyError("NumPy conversion error: Python object did not have array type")

        arr = Ptr[PyArrayObject](a)[0]

        if int(arr.nd) != ndim:
            raise PyError("NumPy conversion error: Python array has incorrect dimension")

        code = _type_code(dtype)
        if int(arr.descr[0].type_num) != code:
            raise PyError("NumPy conversion error: Python array has incorrect dtype")

        # Byte order is a separate descriptor field from type_num. Without
        # this check, a byte-swapped array would be read as native data.
        # NumPy tags non-native data with the opposite byte-order character
        # ('>' on little-endian hosts, '<' on big-endian ones). Native ('=')
        # and not-applicable ('|') never match.
        probe = i16(1)
        host_is_little = Ptr[u8](__ptr__(probe).as_byte())[0] == u8(1)
        swapped = u8(ord('>')) if host_is_little else u8(ord('<'))
        if arr.descr[0].byteorder == swapped:
            raise PyError("NumPy conversion error: Python array has non-native byte order")

        # _type_code yields the same type number for every datetime64 (resp.
        # timedelta64) unit, so the unit must be compared explicitly. Otherwise
        # the values would be silently reinterpreted.
        if isinstance(dtype, datetime64) or isinstance(dtype, timedelta64):
            c_meta = arr.descr[0].c_metadata
            if not c_meta:
                raise PyError("NumPy conversion error: Python array has incorrect datetime unit")
            meta = Ptr[PyArray_DatetimeDTypeMetaData](c_meta)[0].meta
            if int(meta.base) != int(dtype._code()) or int(meta.num) != int(dtype.num):
                raise PyError("NumPy conversion error: Python array has incorrect datetime unit")

        arr_data = arr.data
        arr_shape = arr.dimensions
        arr_strides = arr.strides
        arr_flags = int(arr.flags)
        shape = tuple(arr_shape[i] for i in staticrange(ndim))
        strides = tuple(arr_strides[i] for i in staticrange(ndim))
        size = util.count(shape)

        if code != NPY_OBJECT and not copy:
            # Zero-copy view over NumPy's buffer, keeping its strides.
            return ndarray(shape, strides, Ptr[dtype](arr_data))

        data = Ptr[dtype](size)

        if code == NPY_OBJECT:
            k = 0
            for idx in util.multirange(shape):
                # strides are in bytes, so offset the raw byte pointer.
                off = 0
                for i in range(ndim):
                    off += idx[i] * strides[i]
                e = Ptr[cobj](arr_data + off)[0]
                if hasattr(dtype, "__from_py__"):
                    data[k] = dtype.__from_py__(e)
                k += 1
            return ndarray(shape, data)
        else:
            cc = (arr_flags & NPY_ARRAY_C_CONTIGUOUS) != 0
            fc = (arr_flags & NPY_ARRAY_F_CONTIGUOUS) != 0

            if cc or fc:
                # Dense in either order: copy the bytes and keep the layout.
                str.memcpy(data.as_byte(), arr_data, size * util.sizeof(dtype))
                return ndarray(shape, strides, data)
            else:
                # Arbitrary strides: gather into a new C-contiguous buffer.
                k = 0
                for idx in util.multirange(shape):
                    off = 0
                    for i in range(ndim):
                        off += idx[i] * strides[i]
                    data[k] = Ptr[dtype](arr_data + off)[0]
                    k += 1
                return ndarray(shape, data)

    def __from_py__(a: cobj):
        """Convert the NumPy array `a`, viewing its buffer when possible (`copy=False`)."""
        return ndarray[dtype, ndim]._from_py(a, copy=False)
|