# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>

from internal.gc import sizeof as _sizeof

@tuple
class Device:
    """Handle to a CUDA device, identified by its driver ordinal."""
    _device: i32

    def __new__(device: int):
        from C import seq_nvptx_device(int) -> i32
        return Device(seq_nvptx_device(device))

    @staticmethod
    def count():
        from C import seq_nvptx_device_count() -> int
        return seq_nvptx_device_count()

    def __str__(self):
        from C import seq_nvptx_device_name(i32) -> str
        return seq_nvptx_device_name(self._device)

    def __index__(self):
        return int(self._device)

    def __bool__(self):
        return True

    @property
    def compute_capability(self):
        from C import seq_nvptx_device_capability(i32) -> int
        c = seq_nvptx_device_capability(self._device)
        return (c >> 32, c & 0xffffffff)
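
# Illustrative sketch of device enumeration (printed values are examples only):
#
#   for i in range(Device.count()):
#       dev = Device(i)
#       print(dev, dev.compute_capability)  # e.g. device name and (major, minor)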

@tuple
class Memory[T]:
    """Raw device allocation holding elements of type `T`."""
    _ptr: Ptr[byte]

    def _alloc(n: int, T: type):
        from C import seq_nvptx_device_alloc(int) -> Ptr[byte]
        return Memory[T](seq_nvptx_device_alloc(n * _sizeof(T)))

    def _read(self, p: Ptr[T], n: int):
        from C import seq_nvptx_memcpy_d2h(Ptr[byte], Ptr[byte], int)
        seq_nvptx_memcpy_d2h(p.as_byte(), self._ptr, n * _sizeof(T))

    def _write(self, p: Ptr[T], n: int):
        from C import seq_nvptx_memcpy_h2d(Ptr[byte], Ptr[byte], int)
        seq_nvptx_memcpy_h2d(self._ptr, p.as_byte(), n * _sizeof(T))

    def _free(self):
        from C import seq_nvptx_device_free(Ptr[byte])
        seq_nvptx_device_free(self._ptr)
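
# Minimal sketch of the host/device round trip these helpers implement
# (internal API; sizes and values are illustrative):
#
#   buf = Memory._alloc(4, int)   # device buffer for 4 ints
#   host = Ptr[int](4)
#   buf._write(host, 4)           # host -> device
#   buf._read(host, 4)            # device -> host
#   buf._free()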

# Block-wide barrier: equivalent to CUDA's __syncthreads().
@llvm
def syncthreads() -> None:
    declare void @llvm.nvvm.barrier0()
    call void @llvm.nvvm.barrier0()
    ret {} {}
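
# Hedged usage sketch: every thread in a block must reach the barrier before
# any thread continues past it (kernel body below is illustrative):
#
#   @kernel
#   def step(xs):
#       i = thread.x
#       xs[i] += 1
#       syncthreads()   # writes above are now visible across the block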

@tuple
class Dim3:
    """Launch geometry along three axes, mirroring CUDA's dim3."""
    _x: u32
    _y: u32
    _z: u32

    def __new__(x: int, y: int, z: int):
        return Dim3(u32(x), u32(y), u32(z))

    @property
    def x(self):
        return int(self._x)

    @property
    def y(self):
        return int(self._y)

    @property
    def z(self):
        return int(self._z)

@tuple
class Thread:
    """Accessor for the current thread's index within its block."""
    @property
    def x(self):
        @pure
        @llvm
        def get_x() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
            ret i32 %res

        return int(get_x())

    @property
    def y(self):
        @pure
        @llvm
        def get_y() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
            ret i32 %res

        return int(get_y())

    @property
    def z(self):
        @pure
        @llvm
        def get_z() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
            ret i32 %res

        return int(get_z())

@tuple
class Block:
    """Accessor for the current block's index within the grid, and its dimensions."""
    @property
    def x(self):
        @pure
        @llvm
        def get_x() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
            ret i32 %res

        return int(get_x())

    @property
    def y(self):
        @pure
        @llvm
        def get_y() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
            ret i32 %res

        return int(get_y())

    @property
    def z(self):
        @pure
        @llvm
        def get_z() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
            ret i32 %res

        return int(get_z())

    @property
    def dim(self):
        @pure
        @llvm
        def get_x() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
            ret i32 %res

        @pure
        @llvm
        def get_y() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
            ret i32 %res

        @pure
        @llvm
        def get_z() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
            ret i32 %res

        return Dim3(get_x(), get_y(), get_z())

@tuple
class Grid:
    """Accessor for the dimensions of the launch grid."""
    @property
    def dim(self):
        @pure
        @llvm
        def get_x() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
            ret i32 %res

        @pure
        @llvm
        def get_y() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
            ret i32 %res

        @pure
        @llvm
        def get_z() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
            ret i32 %res

        return Dim3(get_x(), get_y(), get_z())

@tuple
class Warp:
    """Accessor for the warp size (32 on current NVIDIA hardware)."""
    def __len__(self):
        @pure
        @llvm
        def get_warpsize() -> u32:
            declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
            %res = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
            ret i32 %res

        return int(get_warpsize())

thread = Thread()
block = Block()
grid = Grid()
warp = Warp()

# Reference the accessor singletons once so the compiler realizes them.
def _catch():
    return (thread, block, grid, warp)

_catch()
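
# Typical global-index computation inside a kernel, using the accessors above
# (illustrative; assumes a 1-D launch):
#
#   idx = block.dim.x * block.x + thread.x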

@tuple
class AllocCache:
    """Device allocations made while launching a kernel, freed afterwards."""
    v: List[Ptr[byte]]

    def add(self, p: Ptr[byte]):
        self.v.append(p)

    def free(self):
        for p in self.v:
            Memory[byte](p)._free()

def _tuple_from_gpu(args, gpu_args):
    # Recursively copy each GPU argument back into its host counterpart.
    if staticlen(args) > 0:
        a = args[0]
        g = gpu_args[0]
        a.__from_gpu__(g)
        _tuple_from_gpu(args[1:], gpu_args[1:])

def kernel(fn):
    """Decorator that compiles `fn` for the GPU and returns a launcher taking
    `grid` and `block` keyword arguments."""
    from C import seq_nvptx_function(str) -> cobj
    from C import seq_nvptx_invoke(cobj, u32, u32, u32, u32, u32, u32, u32, cobj)

    def canonical_dim(dim):
        # Normalize a launch dimension (None, int, tuple, or Dim3) to (x, y, z).
        if isinstance(dim, NoneType):
            return (1, 1, 1)
        elif isinstance(dim, int):
            return (dim, 1, 1)
        elif isinstance(dim, Tuple[int,int]):
            return (dim[0], dim[1], 1)
        elif isinstance(dim, Tuple[int,int,int]):
            return dim
        elif isinstance(dim, Dim3):
            return (dim.x, dim.y, dim.z)
        else:
            compile_error("bad dimension argument")
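
    # For example, these illustrative launches are all equivalent:
    #   kern(..., grid=8, block=64)
    #   kern(..., grid=(8, 1), block=(64, 1, 1))
    #   kern(..., grid=Dim3(8, 1, 1), block=Dim3(64, 1, 1))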

    def offsets(t):
        # Byte offset of each field of tuple `t`, via LLVM getelementptr.
        @pure
        @llvm
        def offsetof(t: T, i: Static[int], T: type, S: type) -> int:
            %p = getelementptr {=T}, ptr null, i64 0, i32 {=i}
            %s = ptrtoint ptr %p to i64
            ret i64 %s

        if staticlen(t) == 0:
            return ()
        else:
            T = type(t)
            S = type(t[-1])
            return (*offsets(t[:-1]), offsetof(t, staticlen(t) - 1, T, S))

    def wrapper(*args, grid, block):
        grid = canonical_dim(grid)
        block = canonical_dim(block)
        cache = AllocCache([])
        shared_mem = 0
        # Convert each argument to its device representation, then look up the
        # compiled kernel realized for those argument types.
        gpu_args = tuple(arg.__to_gpu__(cache) for arg in args)
        kernel_ptr = seq_nvptx_function(__realized__(fn, gpu_args).__llvm_name__)
        # Build the argument-pointer array expected by the CUDA driver API.
        p = __ptr__(gpu_args).as_byte()
        arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args))
        seq_nvptx_invoke(kernel_ptr, u32(grid[0]), u32(grid[1]), u32(grid[2]), u32(block[0]),
                         u32(block[1]), u32(block[2]), u32(shared_mem), __ptr__(arg_ptrs).as_byte())
        # Copy results back to the host and release device allocations.
        _tuple_from_gpu(args, gpu_args)
        cache.free()

    return wrapper
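
# Hedged usage sketch of the decorator above (assumes this module is imported
# as `gpu`; values are illustrative):
#
#   import gpu
#
#   @gpu.kernel
#   def vadd(a, b, c):
#       i = gpu.thread.x
#       c[i] = a[i] + b[i]
#
#   a = [float(i) for i in range(16)]
#   b = [float(2 * i) for i in range(16)]
#   c = [0.0 for _ in range(16)]
#   vadd(a, b, c, grid=1, block=16)   # c is copied back after the launch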

def _ptr_to_gpu(p: Ptr[T], n: int, cache: AllocCache, index_filter = lambda i: True, T: type):
    # Copy `n` elements at `p` into a fresh device buffer. Element types that
    # contain pointers (non-atomic) are converted element-wise first.
    from internal.gc import atomic

    if not atomic(T):
        tmp = Ptr[T](n)
        for i in range(n):
            if index_filter(i):
                tmp[i] = p[i].__to_gpu__(cache)
        p = tmp

    mem = Memory._alloc(n, T)
    cache.add(mem._ptr)
    mem._write(p, n)
    return Ptr[T](mem._ptr)

def _ptr_from_gpu(p: Ptr[T], q: Ptr[T], n: int, index_filter = lambda i: True, T: type):
    # Copy `n` elements from device buffer `q` back into host buffer `p`,
    # converting pointer-containing (non-atomic) element types element-wise.
    from internal.gc import atomic

    mem = Memory[T](q.as_byte())
    if not atomic(T):
        tmp = Ptr[T](n)
        mem._read(tmp, n)
        for i in range(n):
            if index_filter(i):
                p[i] = T.__from_gpu_new__(tmp[i])
    else:
        mem._read(p, n)

# Reinterpret a raw device pointer as a by-reference object of type `T`.
@pure
@llvm
def _ptr_to_type(p: cobj, T: type) -> T:
    ret ptr %p

def _object_to_gpu(obj: T, cache: AllocCache, T: type):
    # Copy a by-reference object's underlying tuple to the device.
    s = tuple(obj)
    gpu_mem = Memory._alloc(1, type(s))
    cache.add(gpu_mem._ptr)
    gpu_mem._write(__ptr__(s), 1)
    return _ptr_to_type(gpu_mem._ptr, T)

def _object_from_gpu(obj):
    # Copy a by-reference object's underlying tuple back from the device.
    T = type(obj)
    S = type(tuple(obj))

    tmp = T.__new__()
    p = Ptr[S](tmp.__raw__())
    q = Ptr[S](obj.__raw__())

    mem = Memory[S](q.as_byte())
    mem._read(p, 1)
    return tmp

@tuple
class Pointer[T]:
    """Pointer/length pair transferred to the GPU as a raw buffer, bypassing
    object conversion."""
    _ptr: Ptr[T]
    _len: int

    def __to_gpu__(self, cache: AllocCache):
        return _ptr_to_gpu(self._ptr, self._len, cache)

    def __from_gpu__(self, other: Ptr[T]):
        _ptr_from_gpu(self._ptr, other, self._len)

    def __from_gpu_new__(other: Ptr[T]):
        return other

def raw(v: List[T], T: type):
    """Expose a list's backing buffer to a kernel without copying the list object."""
    return Pointer(v.arr.ptr, len(v))
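
# Sketch (illustrative kernel and sizes): `raw` passes only the element buffer,
# avoiding the per-launch list-object copy for large inputs:
#
#   v = [0.0 for _ in range(1024 * 1024)]
#
#   @kernel
#   def fill(p):
#       p[block.dim.x * block.x + thread.x] = 1.0
#
#   fill(raw(v), grid=1024, block=1024)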

# Scalars and raw pointers are passed by value, so GPU conversion is the
# identity for each of the extensions below.
@extend
class Ptr:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: Ptr[T]):
        pass

    def __from_gpu_new__(other: Ptr[T]):
        return other

@extend
class NoneType:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: NoneType):
        pass

    def __from_gpu_new__(other: NoneType):
        return other

@extend
class int:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: int):
        pass

    def __from_gpu_new__(other: int):
        return other

@extend
class float:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: float):
        pass

    def __from_gpu_new__(other: float):
        return other

@extend
class float32:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: float32):
        pass

    def __from_gpu_new__(other: float32):
        return other

@extend
class bool:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: bool):
        pass

    def __from_gpu_new__(other: bool):
        return other

@extend
class byte:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: byte):
        pass

    def __from_gpu_new__(other: byte):
        return other

@extend
class Int:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: Int[N]):
        pass

    def __from_gpu_new__(other: Int[N]):
        return other

@extend
class UInt:
    def __to_gpu__(self, cache: AllocCache):
        return self

    def __from_gpu__(self, other: UInt[N]):
        pass

    def __from_gpu_new__(other: UInt[N]):
        return other

@extend
class str:
    def __to_gpu__(self, cache: AllocCache):
        n = self.len
        return str(_ptr_to_gpu(self.ptr, n, cache), n)

    def __from_gpu__(self, other: str):
        pass

    def __from_gpu_new__(other: str):
        n = other.len
        p = Ptr[byte](n)
        _ptr_from_gpu(p, other.ptr, n)
        return str(p, n)

@extend
class List:
    @inline
    def __to_gpu__(self, cache: AllocCache):
        mem = List[T].__new__()
        n = self.len
        gpu_ptr = _ptr_to_gpu(self.arr.ptr, n, cache)
        mem.arr = Array[T](gpu_ptr, n)
        mem.len = n
        return _object_to_gpu(mem, cache)

    @inline
    def __from_gpu__(self, other: List[T]):
        mem = _object_from_gpu(other)
        my_cap = self.arr.len
        other_cap = mem.arr.len

        if other_cap > my_cap:
            self._resize(other_cap)

        _ptr_from_gpu(self.arr.ptr, mem.arr.ptr, mem.len)
        self.len = mem.len

    @inline
    def __from_gpu_new__(other: List[T]):
        mem = _object_from_gpu(other)
        arr = Array[T](mem.arr.len)
        _ptr_from_gpu(arr.ptr, mem.arr.ptr, arr.len)
        mem.arr = arr
        return mem
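
    # Illustrative semantics of the methods above: a list argument is copied
    # to the device at launch, and its contents and length are copied back
    # afterwards, so host-side lists reflect device mutations:
    #
    #   xs = [1, 2, 3]
    #   kern(xs, grid=1, block=3)   # illustrative launch; xs is updated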

@extend
class DynamicTuple:
    @inline
    def __to_gpu__(self, cache: AllocCache):
        n = self._len
        gpu_ptr = _ptr_to_gpu(self._ptr, n, cache)
        return DynamicTuple(gpu_ptr, n)

    @inline
    def __from_gpu__(self, other: DynamicTuple[T]):
        _ptr_from_gpu(self._ptr, other._ptr, self._len)

    @inline
    def __from_gpu_new__(other: DynamicTuple[T]):
        n = other._len
        p = Ptr[T](n)
        _ptr_from_gpu(p, other._ptr, n)
        return DynamicTuple(p, n)

@extend
class Dict:
    def __to_gpu__(self, cache: AllocCache):
        from internal.khash import __ac_fsize
        mem = Dict[K,V].__new__()
        n = self._n_buckets
        f = __ac_fsize(n) if n else 0

        mem._n_buckets = n
        mem._size = self._size
        mem._n_occupied = self._n_occupied
        mem._upper_bound = self._upper_bound
        mem._flags = _ptr_to_gpu(self._flags, f, cache)
        mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))
        mem._vals = _ptr_to_gpu(self._vals, n, cache, lambda i: self._kh_exist(i))

        return _object_to_gpu(mem, cache)

    def __from_gpu__(self, other: Dict[K,V]):
        from internal.khash import __ac_fsize
        mem = _object_from_gpu(other)
        my_n = self._n_buckets
        n = mem._n_buckets
        f = __ac_fsize(n) if n else 0

        if my_n != n:
            self._flags = Ptr[u32](f)
            self._keys = Ptr[K](n)
            self._vals = Ptr[V](n)

        _ptr_from_gpu(self._flags, mem._flags, f)
        _ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))
        _ptr_from_gpu(self._vals, mem._vals, n, lambda i: self._kh_exist(i))

        self._n_buckets = n
        self._size = mem._size
        self._n_occupied = mem._n_occupied
        self._upper_bound = mem._upper_bound

    def __from_gpu_new__(other: Dict[K,V]):
        from internal.khash import __ac_fsize
        mem = _object_from_gpu(other)

        n = mem._n_buckets
        f = __ac_fsize(n) if n else 0
        flags = Ptr[u32](f)
        keys = Ptr[K](n)
        vals = Ptr[V](n)

        _ptr_from_gpu(flags, mem._flags, f)
        mem._flags = flags
        _ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
        mem._keys = keys
        _ptr_from_gpu(vals, mem._vals, n, lambda i: mem._kh_exist(i))
        mem._vals = vals
        return mem

@extend
class Set:
    def __to_gpu__(self, cache: AllocCache):
        from internal.khash import __ac_fsize
        mem = Set[K].__new__()
        n = self._n_buckets
        f = __ac_fsize(n) if n else 0

        mem._n_buckets = n
        mem._size = self._size
        mem._n_occupied = self._n_occupied
        mem._upper_bound = self._upper_bound
        mem._flags = _ptr_to_gpu(self._flags, f, cache)
        mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))

        return _object_to_gpu(mem, cache)

    def __from_gpu__(self, other: Set[K]):
        from internal.khash import __ac_fsize
        mem = _object_from_gpu(other)

        my_n = self._n_buckets
        n = mem._n_buckets
        f = __ac_fsize(n) if n else 0

        if my_n != n:
            self._flags = Ptr[u32](f)
            self._keys = Ptr[K](n)

        _ptr_from_gpu(self._flags, mem._flags, f)
        _ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))

        self._n_buckets = n
        self._size = mem._size
        self._n_occupied = mem._n_occupied
        self._upper_bound = mem._upper_bound

    def __from_gpu_new__(other: Set[K]):
        from internal.khash import __ac_fsize
        mem = _object_from_gpu(other)

        n = mem._n_buckets
        f = __ac_fsize(n) if n else 0
        flags = Ptr[u32](f)
        keys = Ptr[K](n)

        _ptr_from_gpu(flags, mem._flags, f)
        mem._flags = flags
        _ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
        mem._keys = keys
        return mem

@extend
class Optional:
    def __to_gpu__(self, cache: AllocCache):
        if self is None:
            return self
        else:
            return Optional[T](self.__val__().__to_gpu__(cache))

    def __from_gpu__(self, other: Optional[T]):
        if self is not None and other is not None:
            self.__val__().__from_gpu__(other.__val__())

    def __from_gpu_new__(other: Optional[T]):
        if other is None:
            return Optional[T]()
        else:
            return Optional[T](T.__from_gpu_new__(other.__val__()))

@extend
class __internal__:
    def class_to_gpu(obj, cache: AllocCache):
        # Default conversion for user-defined classes: plain tuples convert
        # element-wise, by-value (@tuple) classes field-wise, and by-reference
        # classes through a device allocation of their underlying tuple.
        if isinstance(obj, Tuple):
            return tuple(a.__to_gpu__(cache) for a in obj)
        elif isinstance(obj, ByVal):
            T = type(obj)
            return T(*tuple(a.__to_gpu__(cache) for a in tuple(obj)))
        else:
            T = type(obj)
            S = type(tuple(obj))
            mem = T.__new__()
            Ptr[S](mem.__raw__())[0] = tuple(obj).__to_gpu__(cache)
            return _object_to_gpu(mem, cache)

    def class_from_gpu(obj, other):
        if isinstance(obj, Tuple):
            _tuple_from_gpu(obj, other)
        elif isinstance(obj, ByVal):
            _tuple_from_gpu(tuple(obj), tuple(other))
        else:
            S = type(tuple(obj))
            Ptr[S](obj.__raw__())[0] = S.__from_gpu_new__(tuple(_object_from_gpu(other)))

    def class_from_gpu_new(other):
        if isinstance(other, Tuple):
            return tuple(type(a).__from_gpu_new__(a) for a in other)
        elif isinstance(other, ByVal):
            T = type(other)
            return T(*tuple(type(a).__from_gpu_new__(a) for a in tuple(other)))
        else:
            S = type(tuple(other))
            mem = _object_from_gpu(other)
            Ptr[S](mem.__raw__())[0] = S.__from_gpu_new__(tuple(mem))
            return mem
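
# For example (illustrative), a by-value class converts field-by-field:
# pointer-free fields pass through unchanged, while a field holding a list
# is rebuilt around a device buffer:
#
#   @tuple
#   class Vec:
#       x: float            # atomic: passed by value
#       y: float
#
#   @tuple
#   class Rows:
#       data: List[float]   # converted through List.__to_gpu__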

# @par(gpu=True) support

# Thread/block intrinsics re-declared here for use by the loop template below.
@pure
@llvm
def _gpu_thread_x() -> u32:
    declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
    %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
    ret i32 %res

@pure
@llvm
def _gpu_block_x() -> u32:
    declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
    %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
    ret i32 %res

@pure
@llvm
def _gpu_block_dim_x() -> u32:
    declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
    %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
    ret i32 %res

def _gpu_loop_outline_template(start, stop, args, instance: Static[int]):
    # Outline template instantiated by the compiler for each @par(gpu=True)
    # loop; the stubs below are replaced with the loop's actual step and body.
    @nonpure
    def _loop_step():
        return 1

    @kernel
    def _kernel_stub(start: int, count: int, args):
        @nonpure
        def _gpu_loop_body_stub(idx, args):
            pass

        @nonpure
        def _dummy_use(n):
            pass

        _dummy_use(instance)
        idx = (int(_gpu_block_dim_x()) * int(_gpu_block_x())) + int(_gpu_thread_x())
        step = _loop_step()
        if idx < count:
            _gpu_loop_body_stub(start + (idx * step), args)

    step = _loop_step()
    loop = range(start, stop, step)

    MAX_BLOCK = 1024        # maximum threads per block
    MAX_GRID = 2147483647   # maximum grid x-dimension (2^31 - 1)
    G = MAX_BLOCK * MAX_GRID
    n = len(loop)

    if n == 0:
        return
    elif n > G:
        raise ValueError(f'loop exceeds GPU iteration limit of {G}')

    # One thread per iteration: a single block if it fits, otherwise
    # ceil(n / MAX_BLOCK) blocks of MAX_BLOCK threads each.
    block = n
    grid = 1
    if n > MAX_BLOCK:
        block = MAX_BLOCK
        grid = (n // MAX_BLOCK) + (0 if n % MAX_BLOCK == 0 else 1)

    _kernel_stub(start, n, args, grid=grid, block=block)
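
# Hedged sketch of the loop form the template above supports, per Codon's
# @par annotation (sizes are illustrative):
#
#   a = [0 for _ in range(1000)]
#
#   @par(gpu=True)
#   for i in range(1000):
#       a[i] = i * i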