mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
import gpu
|
|
import util
|
|
from .ndarray import ndarray
|
|
|
|
@extend
class ndarray:
    # Transfer hooks that move an ndarray's element buffer through the
    # gpu module's pointer-copy helpers (gpu._ptr_to_gpu / gpu._ptr_from_gpu).
    # NOTE(review): presumably these are the hooks the gpu module invokes
    # when an array is passed to / returned from a kernel -- confirm against
    # the gpu module itself.

    def __to_gpu__(self, cache: gpu.AllocCache):
        # Builds the array handed to the device side. `cache` is forwarded
        # untouched to gpu._ptr_to_gpu.

        # Contiguous storage: upload the existing buffer directly and keep
        # the current strides.
        if self._is_contig:
            return ndarray(
                self.shape,
                self.strides,
                gpu._ptr_to_gpu(self.data, self.size, cache),
            )

        # Non-contiguous storage: gather every element, in multirange
        # order, into a temporary dense buffer and upload that instead.
        count = self.size
        staging = Ptr[dtype](count)
        pos = 0
        for index in util.multirange(self.shape):
            staging[pos] = self._ptr(index)[0]
            pos += 1
        device_data = gpu._ptr_to_gpu(staging, count, cache)

        # The staging buffer is only needed for the upload itself.
        util.free(staging)

        # No strides are passed here.
        # NOTE(review): presumably the two-argument constructor assumes a
        # dense layout matching the gather order above -- confirm.
        return ndarray(self.shape, device_data)

    def __from_gpu__(self, other: ndarray[dtype, ndim]):
        # Copies the contents of `other` back into this array's existing
        # storage; nothing is returned.

        # Non-contiguous destination: transfer one element per call, using
        # each array's own _ptr(index) to locate the element.
        if not self._is_contig:
            for index in util.multirange(self.shape):
                gpu._ptr_from_gpu(self._ptr(index), other._ptr(index), 1)
            return

        # Contiguous destination: a single bulk copy of self.size elements.
        gpu._ptr_from_gpu(self.data, other.data, self.size)

    def __from_gpu_new__(other: ndarray[dtype, ndim]):
        # Declared without `self`. Allocates a fresh buffer of other.size
        # elements, fills it from `other`, and returns a new ndarray that
        # wraps the buffer.
        host = Ptr[dtype](other.size)

        # Contiguous source: one bulk copy, and the result reuses the
        # source's strides.
        if other._is_contig:
            gpu._ptr_from_gpu(host, other.data, other.size)
            return ndarray(other.shape, other.strides, host)

        # Non-contiguous source: pull elements one at a time in multirange
        # order, packing them densely into the new buffer.
        offset = 0
        for index in util.multirange(other.shape):
            gpu._ptr_from_gpu(host + offset, other._ptr(index), 1)
            offset += 1
        return ndarray(other.shape, host)
|