#include <stdio.h>

#include <vector>

#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

using mmlab::TensorDesc;
template <typename T>
|
|
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
|
|
const T* update, T* output,
|
|
TensorDesc tensor_desc,
|
|
TensorDesc indice_desc) {
|
|
const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
|
|
const int copy_stride = tensor_desc.stride[indice_cols - 1];
|
|
const int* stride = &(tensor_desc.stride[0]);
|
|
CUDA_1D_KERNEL_LOOP(index, n) {
|
|
int output_offset = 0;
|
|
const int* indices_current = indices + index * indice_cols;
|
|
for (int i = 0; i < indice_cols; ++i) {
|
|
output_offset += stride[i] * indices_current[i];
|
|
}
|
|
memcpy(output + output_offset, update + index * copy_stride,
|
|
copy_stride * sizeof(T));
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
|
|
const T* update, const int* dims,
|
|
int nbDims, const int* indices_dims,
|
|
int indice_nbDims, T* output,
|
|
cudaStream_t stream) {
|
|
// fill tensordesc and initial
|
|
TensorDesc tensor_desc;
|
|
memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
|
|
tensor_desc.dim = nbDims;
|
|
tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
|
|
tensor_desc.stride[nbDims - 1] = 1;
|
|
for (int i = nbDims - 2; i >= 0; --i) {
|
|
tensor_desc.shape[i] = dims[i];
|
|
tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
|
|
}
|
|
const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];
|
|
|
|
TensorDesc indice_desc;
|
|
memset((void*)&indice_desc, 0, sizeof(TensorDesc));
|
|
indice_desc.dim = indice_nbDims;
|
|
indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
|
|
indice_desc.stride[indice_nbDims - 1] = 1;
|
|
for (int i = indice_nbDims - 2; i >= 0; --i) {
|
|
indice_desc.shape[i] = indices_dims[i];
|
|
indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
|
|
}
|
|
|
|
// output = np.copy(data)
|
|
cudaMemcpyAsync(output, data, data_size * sizeof(T),
|
|
cudaMemcpyDeviceToDevice);
|
|
|
|
int num_update_indice = 1;
|
|
for (int i = 0; i < indice_nbDims - 1; ++i) {
|
|
num_update_indice *= indice_desc.shape[i];
|
|
}
|
|
// scatter
|
|
const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK);
|
|
onnx_scatternd_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
|
|
num_update_indice, indices, update, output, tensor_desc, indice_desc);
|
|
}
|
|
|
|
template void TRTONNXScatterNDKernelLauncher<float>(
|
|
const float* data, const int* indices, const float* update, const int* dims,
|
|
int nbDims, const int* indices_dims, int indice_nbDims, float* output,
|
|
cudaStream_t stream);
|
|
|
|
template void TRTONNXScatterNDKernelLauncher<int>(
|
|
const int* data, const int* indices, const int* update, const int* dims,
|
|
int nbDims, const int* indices_dims, int indice_nbDims, int* output,
|
|
cudaStream_t stream);
|