# Copyright (C) 2022-2025 Exaloop Inc.

import gpu
import util
from .ndarray import ndarray


# GPU transfer hooks for ndarray: host -> device (__to_gpu__), device -> an
# existing host array (__from_gpu__), and device -> a fresh host array
# (__from_gpu_new__).
@extend
class ndarray:
    def __to_gpu__(self, cache: gpu.AllocCache):
        """
        Build an ndarray whose data pointer comes from `gpu._ptr_to_gpu`.

        When `self._is_contig` is true the existing buffer is handed over
        as-is and the original strides are kept. Otherwise the elements are
        first packed, in `util.multirange(self.shape)` order, into a temporary
        host buffer, and the result is constructed from shape and data only
        (no explicit strides).
        """
        if not self._is_contig:
            count = self.size
            staging = Ptr[dtype](count)
            # Gather the strided elements into a dense temporary buffer.
            for k, idx in enumerate(util.multirange(self.shape)):
                staging[k] = self._ptr(idx)[0]
            device_data = gpu._ptr_to_gpu(staging, count, cache)
            # The staging buffer is only needed for the transfer itself.
            util.free(staging)
            return ndarray(self.shape, device_data)

        device_data = gpu._ptr_to_gpu(self.data, self.size, cache)
        return ndarray(self.shape, self.strides, device_data)

    def __from_gpu__(self, other: ndarray[dtype, ndim]):
        """
        Copy the contents of `other` back into this array in place.

        A single bulk `gpu._ptr_from_gpu` call is used when `self._is_contig`
        is true; otherwise elements are transferred one at a time, each side
        being addressed through its own `_ptr(idx)`.
        """
        if not self._is_contig:
            for idx in util.multirange(self.shape):
                gpu._ptr_from_gpu(self._ptr(idx), other._ptr(idx), 1)
            return

        gpu._ptr_from_gpu(self.data, other.data, self.size)

    def __from_gpu_new__(other: ndarray[dtype, ndim]):
        """
        Create a new array holding the contents of `other`.

        A fresh buffer of `other.size` elements is allocated. If
        `other._is_contig` is true the data is copied in one call and the
        strides of `other` are reused; otherwise elements are copied one by
        one in `util.multirange(other.shape)` order and the result is built
        from shape and data only.
        """
        host_data = Ptr[dtype](other.size)

        if not other._is_contig:
            # Element-wise transfer into consecutive slots of the new buffer.
            for k, idx in enumerate(util.multirange(other.shape)):
                gpu._ptr_from_gpu(host_data + k, other._ptr(idx), 1)
            return ndarray(other.shape, host_data)

        gpu._ptr_from_gpu(host_data, other.data, other.size)
        return ndarray(other.shape, other.strides, host_data)