codon/stdlib/numpy/pybridge.codon

275 lines
7.4 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .ndarray import ndarray
from .npdatetime import datetime64, timedelta64
import util
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
# NumPy type numbers used by this bridge. The values are meant to mirror the
# NPY_TYPES enum of the NumPy header linked above; `_type_code` returns them
# and `ndarray._from_py` compares them against an array descriptor's type_num.
NPY_BOOL: Static[int] = 0
NPY_BYTE: Static[int] = 1
NPY_UBYTE: Static[int] = 2
NPY_SHORT: Static[int] = 3
NPY_USHORT: Static[int] = 4
NPY_INT: Static[int] = 5
NPY_UINT: Static[int] = 6
NPY_LONG: Static[int] = 7
NPY_ULONG: Static[int] = 8
NPY_LONGLONG: Static[int] = 9
NPY_ULONGLONG: Static[int] = 10
NPY_FLOAT: Static[int] = 11
NPY_DOUBLE: Static[int] = 12
NPY_LONGDOUBLE: Static[int] = 13
NPY_CFLOAT: Static[int] = 14
NPY_CDOUBLE: Static[int] = 15
NPY_CLONGDOUBLE: Static[int] = 16
NPY_OBJECT: Static[int] = 17
NPY_STRING: Static[int] = 18
NPY_UNICODE: Static[int] = 19
NPY_VOID: Static[int] = 20
NPY_DATETIME: Static[int] = 21
NPY_TIMEDELTA: Static[int] = 22
NPY_HALF: Static[int] = 23
NPY_NTYPES: Static[int] = 24
NPY_NOTYPE: Static[int] = 25
NPY_CHAR: Static[int] = 26
# NPY_STRING (= 18) is defined once, above; a second, identical assignment
# that used to follow NPY_CHAR was redundant and has been dropped.
NPY_USERDEF: Static[int] = 256
NPY_NTYPES_ABI_COMPATIBLE: Static[int] = 21
def _type_code(dtype: type):
    # Translate a Codon element type into the matching NPY_* type number.
    # Any type without a dedicated entry below is reported as NPY_OBJECT.

    # Boolean and fixed-width integers
    if dtype is bool:
        return NPY_BOOL
    elif dtype is i8:
        return NPY_BYTE
    elif dtype is byte or dtype is u8:
        return NPY_UBYTE
    elif dtype is i16:
        return NPY_SHORT
    elif dtype is u16:
        return NPY_USHORT
    elif dtype is i32:
        return NPY_INT
    elif dtype is u32:
        return NPY_UINT
    elif dtype is int or dtype is i64:
        return NPY_LONG
    elif dtype is u64:
        return NPY_ULONG
    # Floating point
    elif dtype is float16:
        return NPY_HALF
    elif dtype is float32:
        return NPY_FLOAT
    elif dtype is float:
        return NPY_DOUBLE
    # Complex
    elif dtype is complex64:
        return NPY_CFLOAT
    elif dtype is complex:
        return NPY_CDOUBLE
    # Date/time
    elif isinstance(dtype, datetime64):
        return NPY_DATETIME
    elif isinstance(dtype, timedelta64):
        return NPY_TIMEDELTA

    # TODO: add other types like string etc.
    #       once we have them.
    return NPY_OBJECT
# https://github.com/numpy/numpy/blob/main/numpy/_core/include/numpy/ndarraytypes.h
# Two-field view of the start of a Python object. `_pyobj_type` casts a raw
# object pointer to this type and reads `typptr` from it.
# NOTE(review): field order/widths are presumably chosen to match CPython's
# object header on the supported platforms — confirm.
@tuple
class PyObject:
    refcnt: int   # reference count (going by the field name)
    typptr: cobj  # pointer identifying the object's type
# Four-pointer header that sits in front of the datetime metadata inside
# PyArray_DatetimeDTypeMetaData (its `base` field).
# NOTE(review): presumably mirrors NumPy's NpyAuxData struct — confirm the
# layout against the header linked above.
@tuple
class NpyAuxData:
    free_func: cobj   # presumably the "free" callback pointer
    clone_func: cobj  # presumably the "clone" callback pointer
    data1: cobj
    data2: cobj
# Datetime/timedelta unit description. `_set_datetime_descr` fills it with
# `T._code()` (as `base`) and `T.num` (as `num`), both converted to i32.
@tuple
class PyArray_DatetimeMetaData:
    base: i32  # unit code
    num: i32   # unit multiplier
# Record that a descriptor's `c_metadata` pointer is reinterpreted as in
# `_set_datetime_descr`: an NpyAuxData header followed by the datetime unit.
@tuple
class PyArray_DatetimeDTypeMetaData:
    base: NpyAuxData                # header; preserved when the unit is rewritten
    meta: PyArray_DatetimeMetaData  # unit code and multiplier
# View of a NumPy dtype descriptor object. Within this file only `type_num`
# (checked in `ndarray._from_py`) and `c_metadata` (rewritten by
# `_set_datetime_descr`) are accessed; the remaining fields fix the offsets.
# NOTE(review): the field order must match the descriptor struct of the
# NumPy build loaded at runtime — verify for each NumPy version targeted.
@tuple
class PyArrayDescr:
    head: PyObject   # object header
    typeobj: cobj
    kind: u8
    type: u8
    byteorder: u8
    flags: u8
    type_num: i32    # NPY_* type number of the element type
    elsize: i32
    alignment: i32
    subarray: cobj
    fields: cobj
    names: cobj
    f: cobj
    metadata: cobj
    c_metadata: cobj  # reinterpreted as PyArray_DatetimeDTypeMetaData for datetime types
    hash: int
# View of a NumPy array object. `ndarray.__to_py__` and `ndarray._from_py`
# cast raw object pointers to this type to reach the array's buffer,
# dimensions, strides, descriptor and flags.
# NOTE(review): the field order must match NumPy's array object struct —
# confirm against the header linked above.
@tuple
class PyArrayObject:
    head: PyObject            # object header
    data: cobj                # start of the element buffer
    nd: i32                   # number of dimensions
    dimensions: Ptr[int]      # extent of each dimension (`nd` entries are read)
    strides: Ptr[int]         # per-dimension strides; used as byte offsets from `data`
    base: cobj
    descr: Ptr[PyArrayDescr]  # dtype descriptor
    flags: i32                # bit flags, tested against NPY_ARRAY_*_CONTIGUOUS
    weakreflist: cobj
# Bit masks tested against PyArrayObject.flags in `ndarray._from_py`.
NPY_ARRAY_C_CONTIGUOUS: Static[int] = 1
NPY_ARRAY_F_CONTIGUOUS: Static[int] = 2
# Null placeholders; `_setup_numpy_bridge` overwrites both with entries taken
# from NumPy's C-API table.
PyArray_Type = cobj()
PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](cobj())
def _pyobj_type(p: cobj):
    # Reinterpret the raw pointer as an object header and hand back the
    # type pointer stored in it.
    header = Ptr[PyObject](p)
    return header[0].typptr
def _set_datetime_descr(a: Ptr[PyArrayObject], T: type):
    # Overwrite the datetime unit stored behind the array descriptor's
    # `c_metadata` pointer with the unit of Codon type `T`, leaving the
    # leading NpyAuxData header as it was.
    descr = a[0].descr
    slot = Ptr[PyArray_DatetimeDTypeMetaData](descr[0].c_metadata)
    unit = PyArray_DatetimeMetaData(i32(T._code()), i32(T.num))
    header = slot[0].base
    slot[0] = PyArray_DatetimeDTypeMetaData(header, unit)
def _setup_numpy_bridge():
    # Look up NumPy's C-API table through the `_ARRAY_API` capsule and store
    # the two entries this module needs in the globals PyArray_Type and
    # PyArray_New.
    #
    # Raises RuntimeError if the module cannot be imported, or if the
    # `_ARRAY_API` attribute is missing or is not a capsule.
    import python
    from internal.python import PyImport_ImportModule, PyObject_GetAttrString, PyCapsule_Type, PyCapsule_GetPointer
    global PyArray_Type, PyArray_New

    umath = PyImport_ImportModule("numpy.core._multiarray_umath".ptr)
    if not umath:
        raise RuntimeError("Failed to import 'numpy.core._multiarray_umath'")

    capsule = PyObject_GetAttrString(umath, "_ARRAY_API".ptr)
    bad_capsule = "NumPy API object not found or did not have type 'capsule'"
    if not capsule:
        raise RuntimeError(bad_capsule)
    if _pyobj_type(capsule) != PyCapsule_Type:
        raise RuntimeError(bad_capsule)

    # The capsule's payload is read as a table of raw pointers.
    table = Ptr[cobj](PyCapsule_GetPointer(capsule, cobj()))

    # Slot numbers follow:
    # https://github.com/numpy/numpy/blob/main/numpy/_core/code_generators/numpy_api.py
    PyArray_Type = table[2]
    PyArray_New = Function[[cobj, i32, cobj, i32, cobj, cobj, i32, i32, cobj], cobj](table[93])
# Runs at module load so PyArray_Type and PyArray_New are populated before
# the ndarray conversions below use them.
_setup_numpy_bridge()
@extend
class ndarray:
    # Conversions between Codon ndarrays and NumPy array objects, working on
    # raw object pointers (cobj).

    def __to_py__(self):
        # Create a new NumPy array object with this array's shape and element
        # type, copy the elements into it, and return the raw object pointer.
        #
        # Raises:
        #   PyError: if NumPy hands back a null pointer instead of an array.
        dims = self.shape
        code = _type_code(self.dtype)
        # Null strides/data and zero itemsize/flags: NumPy is asked to
        # allocate the buffer itself (see the NumPy C-API docs for PyArray_New).
        arr = PyArray_New(PyArray_Type, i32(self.ndim), __ptr__(dims).as_byte(),
                          i32(code), cobj(), cobj(), i32(0), i32(0), cobj())
        # A null result must not be dereferenced below; report it instead.
        if not arr:
            raise PyError("NumPy conversion error: failed to create Python array")
        arr_ptr = Ptr[PyArrayObject](arr)
        data = arr_ptr[0].data

        if code == NPY_OBJECT:
            # Object arrays hold one object pointer per element; convert each
            # element individually, in multirange order.
            p = Ptr[cobj](data)
            k = 0
            for idx in util.multirange(dims):
                e = self._ptr(idx)[0]
                # Elements whose type has no __to_py__ leave their slot untouched.
                if hasattr(e, "__to_py__"):
                    p[k] = e.__to_py__()
                k += 1
        else:
            cc, _ = self._contig
            if cc:
                # First `_contig` flag set: a single bulk copy suffices.
                str.memcpy(data, self.data.as_byte(), self.nbytes)
            else:
                # Otherwise gather the elements one by one, in multirange order.
                p = Ptr[self.dtype](data)
                k = 0
                for idx in util.multirange(dims):
                    e = self._ptr(idx)[0]
                    p[k] = e
                    k += 1

        # Datetime-like arrays additionally need their unit written into the
        # new array's descriptor.
        if isinstance(self.dtype, datetime64) or isinstance(self.dtype, timedelta64):
            _set_datetime_descr(arr_ptr, self.dtype)

        return arr

    def _from_py(a: cobj, copy: bool):
        # Build an ndarray[dtype, ndim] from the NumPy array object at `a`.
        #
        # Args:
        #   a: raw pointer to a Python object; must be a NumPy array.
        #   copy: when False and the element type is not NPY_OBJECT, the
        #         result is constructed directly over the NumPy array's data
        #         pointer and strides (no element copy). Otherwise the
        #         elements are copied into a freshly allocated buffer.
        #
        # Raises:
        #   PyError: if `a` is not an array object, or its number of
        #            dimensions or type number does not match ndim / dtype.
        if _pyobj_type(a) != PyArray_Type:
            raise PyError("NumPy conversion error: Python object did not have array type")

        arr = Ptr[PyArrayObject](a)[0]

        if int(arr.nd) != ndim:
            raise PyError("NumPy conversion error: Python array has incorrect dimension")

        code = _type_code(dtype)
        if int(arr.descr[0].type_num) != code:
            raise PyError("NumPy conversion error: Python array has incorrect dtype")

        arr_data = arr.data
        arr_shape = arr.dimensions
        arr_strides = arr.strides
        arr_flags = int(arr.flags)
        shape = tuple(arr_shape[i] for i in staticrange(ndim))
        strides = tuple(arr_strides[i] for i in staticrange(ndim))
        size = util.count(shape)

        # No-copy path: reuse the NumPy buffer as-is.
        if code != NPY_OBJECT and not copy:
            return ndarray(shape, strides, Ptr[dtype](arr_data))

        data = Ptr[dtype](size)
        if code == NPY_OBJECT:
            # Object arrays are always converted element by element.
            k = 0
            for idx in util.multirange(shape):
                # Byte offset of element `idx` within the NumPy buffer.
                off = 0
                for i in staticrange(ndim):
                    off += idx[i] * strides[i]
                e = Ptr[cobj](arr_data + off)[0]
                # Types without __from_py__ leave data[k] unassigned.
                if hasattr(dtype, "__from_py__"):
                    data[k] = dtype.__from_py__(e)
                k += 1
            return ndarray(shape, data)
        else:
            cc = (arr_flags & NPY_ARRAY_C_CONTIGUOUS != 0)
            fc = (arr_flags & NPY_ARRAY_F_CONTIGUOUS != 0)
            if cc or fc:
                # Contiguous in either order: bulk-copy the bytes and keep
                # the source strides so the layout is preserved.
                str.memcpy(data.as_byte(), arr_data, size * util.sizeof(dtype))
                return ndarray(shape, strides, data)
            else:
                # Non-contiguous: gather the elements in multirange order.
                k = 0
                for idx in util.multirange(shape):
                    off = 0
                    for i in staticrange(ndim):
                        off += idx[i] * strides[i]
                    e = Ptr[dtype](arr_data + off)[0]
                    data[k] = e
                    k += 1
                return ndarray(shape, data)

    def __from_py__(a: cobj):
        # Convert the NumPy array object at `a` without requesting a copy
        # (delegates to `_from_py` with copy=False).
        return ndarray[dtype, ndim]._from_py(a, copy=False)