# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>

import gpu
import util
from .ndarray import ndarray

# GPU transfer hooks added to ndarray: one host-to-device conversion
# (__to_gpu__) and two device-to-host conversions (__from_gpu__ writes into
# an existing array, __from_gpu_new__ builds a fresh one).
@extend
class ndarray:
    def __to_gpu__(self, cache: gpu.AllocCache):
        """
        Build the device-side counterpart of this array.

        A contiguous array has its buffer handed to ``gpu._ptr_to_gpu``
        directly and the result keeps this array's strides. Otherwise the
        elements are first packed into a temporary buffer in
        ``util.multirange(self.shape)`` order, that buffer is transferred,
        and the result is built from the shape alone (presumably yielding
        default strides -- confirm against the ndarray constructor).
        """
        if self._is_contig:
            device_data = gpu._ptr_to_gpu(self.data, self.size, cache)
            return ndarray(self.shape, self.strides, device_data)

        # Strided layout: gather the elements into a dense staging buffer
        # so the transfer can be issued as a single call.
        count = self.size
        staging = Ptr[dtype](count)
        for pos, idx in enumerate(util.multirange(self.shape)):
            staging[pos] = self._ptr(idx)[0]
        device_data = gpu._ptr_to_gpu(staging, count, cache)
        # The staging buffer is no longer needed once the transfer call
        # has returned.
        util.free(staging)
        return ndarray(self.shape, device_data)

    def __from_gpu__(self, other: ndarray[dtype, ndim]):
        """
        Copy the contents of the device-side array ``other`` into this array.

        When this array is contiguous, ``self.size`` elements are moved with
        one ``gpu._ptr_from_gpu`` call. Otherwise each element is copied
        individually, addressing both arrays by the same multi-index.
        """
        if not self._is_contig:
            # One single-element transfer per index of the shape.
            for idx in util.multirange(self.shape):
                gpu._ptr_from_gpu(self._ptr(idx), other._ptr(idx), 1)
            return

        gpu._ptr_from_gpu(self.data, other.data, self.size)

    def __from_gpu_new__(other: ndarray[dtype, ndim]):
        """
        Create a new array holding the contents of the device-side array
        ``other``.

        A contiguous ``other`` is copied with one ``gpu._ptr_from_gpu`` call
        and the result reuses its strides. Otherwise the elements are copied
        one at a time in ``util.multirange(other.shape)`` order into a dense
        buffer, and the result is built from the shape alone.
        """
        count = other.size
        host_data = Ptr[dtype](count)

        if other._is_contig:
            gpu._ptr_from_gpu(host_data, other.data, count)
            return ndarray(other.shape, other.strides, host_data)

        # Strided source: fill the dense buffer slot by slot.
        for pos, idx in enumerate(util.multirange(other.shape)):
            gpu._ptr_from_gpu(host_data + pos, other._ptr(idx), 1)
        return ndarray(other.shape, host_data)