From 1e47821e4928ee8f80cc3dd5db5fba8a032f9bc4 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 15 Jun 2021 20:44:45 +0800 Subject: [PATCH] add backend plugin build system --- .isort.cfg | 2 +- CMakeLists.txt | 15 ++ backend_ops/CMakeLists.txt | 13 ++ backend_ops/tensorrt/CMakeLists.txt | 61 ++++++ .../tensorrt/common/common_cuda_helper.hpp | 110 ++++++++++ .../tensorrt/common/trt_cuda_helper.cuh | 30 +++ .../tensorrt/common/trt_plugin_helper.hpp | 41 ++++ backend_ops/tensorrt/common/trt_serialize.hpp | 105 +++++++++ .../tensorrt/common_impl/CMakeLists.txt | 4 + .../tensorrt/common_impl/trt_cuda_helper.cu | 66 ++++++ backend_ops/tensorrt/scatternd/CMakeLists.txt | 4 + .../tensorrt/scatternd/trt_scatternd.cpp | 206 ++++++++++++++++++ .../tensorrt/scatternd/trt_scatternd.hpp | 98 +++++++++ .../scatternd/trt_scatternd_kernel.cu | 92 ++++++++ backend_ops/tensorrt/trt_plugin.cpp | 7 + mmdeploy/utils/function_rewriter.py | 10 +- mmdeploy/utils/module_rewriter.py | 19 +- mmdeploy/utils/symbolic_register.py | 10 +- tests/test_utils/test_register.py | 23 +- 19 files changed, 883 insertions(+), 33 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 backend_ops/CMakeLists.txt create mode 100644 backend_ops/tensorrt/CMakeLists.txt create mode 100644 backend_ops/tensorrt/common/common_cuda_helper.hpp create mode 100644 backend_ops/tensorrt/common/trt_cuda_helper.cuh create mode 100644 backend_ops/tensorrt/common/trt_plugin_helper.hpp create mode 100644 backend_ops/tensorrt/common/trt_serialize.hpp create mode 100644 backend_ops/tensorrt/common_impl/CMakeLists.txt create mode 100644 backend_ops/tensorrt/common_impl/trt_cuda_helper.cu create mode 100644 backend_ops/tensorrt/scatternd/CMakeLists.txt create mode 100644 backend_ops/tensorrt/scatternd/trt_scatternd.cpp create mode 100644 backend_ops/tensorrt/scatternd/trt_scatternd.hpp create mode 100644 backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu create mode 100644 backend_ops/tensorrt/trt_plugin.cpp diff --git a/.isort.cfg b/.isort.cfg index 7a34bf769..ef51184d8 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,2 +1,2 @@ [settings] -known_third_party = +known_third_party = mmcv,setuptools,torch diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..f665322dd --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required (VERSION 3.10) +project (mmdeploy_backend_ops) + +# TensorRT config + +# enable tensorrt +option(BUILD_TENSORRT_OPS "enable tensorrt ops" OFF) +# TensorRT search path +if (BUILD_TENSORRT_OPS) + if (NOT DEFINED TENSORRT_DIR) + set(TENSORRT_DIR $ENV{TENSORRT_DIR}) + endif() +endif() + +add_subdirectory (backend_ops) diff --git a/backend_ops/CMakeLists.txt b/backend_ops/CMakeLists.txt new file mode 100644 index 000000000..275711755 --- /dev/null +++ b/backend_ops/CMakeLists.txt @@ -0,0 +1,13 @@ +add_definitions(-std=c++11) +set(CMAKE_CXX_FLAGS_RELEASE "-O3") + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + + +# build TensorRT ops +if (BUILD_TENSORRT_OPS) + message("Build TensorRT custom ops.") + add_subdirectory (tensorrt) +endif() \ No newline at end of file diff --git a/backend_ops/tensorrt/CMakeLists.txt b/backend_ops/tensorrt/CMakeLists.txt new file mode 100644 index 000000000..2f6c7ae62 --- /dev/null +++ b/backend_ops/tensorrt/CMakeLists.txt @@ -0,0 +1,61 @@ +set(TARGET_NAME mmlab_tensorrt_ops) +set(SHARED_TARGET ${TARGET_NAME}) + +# cuda +FIND_PACKAGE(CUDA REQUIRED) +INCLUDE_DIRECTORIES(/usr/local/cuda/include) +enable_language(CUDA) + +# tensorrt +find_path(TENSORRT_INCLUDE_DIR NvInfer.h + HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES include) +if (TENSORRT_INCLUDE_DIR) + MESSAGE(STATUS " Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") +else() + MESSAGE(ERROR " Cannot found TensorRT headers") +endif() + +find_library(TENSORRT_LIBRARY_INFER nvinfer + HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) +find_library(TENSORRT_LIBRARY_PARSERS nvparsers + HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) +find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin + HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} + PATH_SUFFIXES lib lib64 lib/x64) +set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} + ${TENSORRT_LIBRARY_PARSERS} + ${TENSORRT_LIBRARY_INFER_PLUGIN} + ) +if (TENSORRT_LIBRARY_INFER AND TENSORRT_LIBRARY_PARSERS AND TENSORRT_LIBRARY_INFER_PLUGIN) +MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") +else() +MESSAGE(ERROR " Cannot found TensorRT libs") +endif() +find_package_handle_standard_args( + TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY) +if(NOT TENSORRT_FOUND) + message(ERROR " Cannot find TensorRT library.") +endif() +INCLUDE_DIRECTORIES(${TENSORRT_INCLUDE_DIR}) + +# include common +include_directories(common) +add_subdirectory(common_impl) + +# add plugin source +set(PLUGIN_LISTS scatternd) + +foreach(PLUGIN_ITER ${PLUGIN_LISTS}) + add_subdirectory(${PLUGIN_ITER}) +endforeach(PLUGIN_ITER) + +list(APPEND PLUGIN_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/trt_plugin.cpp") + +set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY}) + +cuda_add_library(${SHARED_TARGET} SHARED ${BACKEND_OPS_SRCS}) +target_link_libraries(${SHARED_TARGET} ${INFER_PLUGIN_LIB}) +target_include_directories(${SHARED_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common) diff --git a/backend_ops/tensorrt/common/common_cuda_helper.hpp b/backend_ops/tensorrt/common/common_cuda_helper.hpp new file mode 100644 index 000000000..a9ab6e82f --- /dev/null +++ b/backend_ops/tensorrt/common/common_cuda_helper.hpp @@ -0,0 +1,110 @@ +#ifndef COMMON_CUDA_HELPER +#define COMMON_CUDA_HELPER + +#include + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +#define THREADS_PER_BLOCK 512 + +inline int GET_BLOCKS(const int N) { + int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + int max_block_num = 4096; + return min(optimal_block_num, max_block_num); +} + +template +__device__ T bilinear_interpolate(const T* input, const int height, + const int width, T y, T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__device__ void bilinear_interpolate_gradient( + const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4, + int& x_low, int& x_high, int& y_low, int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} +#endif // COMMON_CUDA_HELPER diff --git a/backend_ops/tensorrt/common/trt_cuda_helper.cuh b/backend_ops/tensorrt/common/trt_cuda_helper.cuh new file mode 100644 index 000000000..a4635dcdd --- /dev/null +++ b/backend_ops/tensorrt/common/trt_cuda_helper.cuh @@ -0,0 +1,30 @@ +#ifndef TRT_CUDA_HELPER_HPP +#define TRT_CUDA_HELPER_HPP + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(0); \ + } \ + } + +/** + * Returns a view of the original tensor with its dimensions permuted. + * + * @param[out] dst pointer to the destination tensor + * @param[in] src pointer to the source tensor + * @param[in] src_size shape of the src tensor + * @param[in] permute The desired ordering of dimensions + * @param[in] src_dim dim of src tensor + * @param[in] stream cuda stream handle + */ +template +void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, + int *permute, int src_dim, cudaStream_t stream = 0); + +#endif // TRT_CUDA_HELPER_HPP diff --git a/backend_ops/tensorrt/common/trt_plugin_helper.hpp b/backend_ops/tensorrt/common/trt_plugin_helper.hpp new file mode 100644 index 000000000..a453efc3c --- /dev/null +++ b/backend_ops/tensorrt/common/trt_plugin_helper.hpp @@ -0,0 +1,41 @@ +#ifndef TRT_PLUGIN_HELPER_HPP +#define TRT_PLUGIN_HELPER_HPP +#include + +#include "NvInferPlugin.h" + +namespace mmlab { + +const int MAXTENSORDIMS = 10; + +struct TensorDesc { + int shape[MAXTENSORDIMS]; + int stride[MAXTENSORDIMS]; + int dim; +}; + +inline unsigned int getElementSize(nvinfer1::DataType t) { + switch (t) { + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + // case nvinfer1::DataType::kBOOL: + case nvinfer1::DataType::kINT8: + return 1; + default: + throw std::runtime_error("Invalid DataType."); + } + throw std::runtime_error("Invalid DataType."); + return 0; +} + +inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) { + return size_t((origin_size + aligned_number - 1) / aligned_number) * + aligned_number; +} + +} // namespace mmlab +#endif // TRT_PLUGIN_HELPER_HPP diff --git a/backend_ops/tensorrt/common/trt_serialize.hpp b/backend_ops/tensorrt/common/trt_serialize.hpp new file mode 100644 index 000000000..1f0899fdf --- /dev/null +++ b/backend_ops/tensorrt/common/trt_serialize.hpp @@ -0,0 +1,105 @@ +// Modified from: +// https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp + +#ifndef TRT_SERIALIZE_HPP +#define TRT_SERIALIZE_HPP +#include +#include +#include +#include +#include +using std::cerr; +using std::cout; +using std::endl; + +template +inline void serialize_value(void** buffer, T const& value); + +template +inline void deserialize_value(void const** buffer, size_t* buffer_size, + T* value); + +namespace { + +template +struct Serializer {}; + +template +struct Serializer::value || + std::is_enum::value || + std::is_pod::value>::type> { + static size_t serialized_size(T const& value) { return sizeof(T); } + static void serialize(void** buffer, T const& value) { + ::memcpy(*buffer, &value, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + } + static void deserialize(void const** buffer, size_t* buffer_size, T* value) { + assert(*buffer_size >= sizeof(T)); + ::memcpy(value, *buffer, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + *buffer_size -= sizeof(T); + } +}; + +template <> +struct Serializer { + static size_t serialized_size(const char* value) { return strlen(value) + 1; } + static void serialize(void** buffer, const char* value) { + ::strcpy(static_cast(*buffer), value); + reinterpret_cast(*buffer) += strlen(value) + 1; + } + static void deserialize(void const** buffer, size_t* buffer_size, + const char** value) { + *value = static_cast(*buffer); + size_t data_size = strnlen(*value, *buffer_size) + 1; + assert(*buffer_size >= data_size); + reinterpret_cast(*buffer) += data_size; + *buffer_size -= data_size; + } +}; + +template +struct Serializer, + typename std::enable_if::value || + std::is_enum::value || + std::is_pod::value>::type> { + static size_t serialized_size(std::vector const& value) { + return sizeof(value.size()) + value.size() * sizeof(T); + } + static void serialize(void** buffer, std::vector const& value) { + serialize_value(buffer, value.size()); + size_t nbyte = value.size() * sizeof(T); + ::memcpy(*buffer, value.data(), nbyte); + reinterpret_cast(*buffer) += nbyte; + } + static void deserialize(void const** buffer, size_t* buffer_size, + std::vector* value) { + size_t size; + deserialize_value(buffer, buffer_size, &size); + value->resize(size); + size_t nbyte = value->size() * sizeof(T); + assert(*buffer_size >= nbyte); + ::memcpy(value->data(), *buffer, nbyte); + reinterpret_cast(*buffer) += nbyte; + *buffer_size -= nbyte; + } +}; + +} // namespace + +template +inline size_t serialized_size(T const& value) { + return Serializer::serialized_size(value); +} + +template +inline void serialize_value(void** buffer, T const& value) { + return Serializer::serialize(buffer, value); +} + +template +inline void deserialize_value(void const** buffer, size_t* buffer_size, + T* value) { + return Serializer::deserialize(buffer, buffer_size, value); +} +#endif // TRT_SERIALIZE_HPP diff --git a/backend_ops/tensorrt/common_impl/CMakeLists.txt b/backend_ops/tensorrt/common_impl/CMakeLists.txt new file mode 100644 index 000000000..d967c6b6b --- /dev/null +++ b/backend_ops/tensorrt/common_impl/CMakeLists.txt @@ -0,0 +1,4 @@ +file(GLOB OPS_SRCS *.cpp *.cu) +file(GLOB OPS_HEADS *.h *.hpp *.cuh) +set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} ${OPS_SRCS} ${OPS_HEADS}) +set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu b/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu new file mode 100644 index 000000000..3aa7014ff --- /dev/null +++ b/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu @@ -0,0 +1,66 @@ +#include "common_cuda_helper.hpp" +#include "trt_cuda_helper.cuh" +#include "trt_plugin_helper.hpp" + +using mmlab::TensorDesc; + +template +__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n, + TensorDesc ts_src_stride, + TensorDesc ts_dst_stride, + TensorDesc ts_permute) { + const int src_dim = ts_src_stride.dim; + int *src_stride = &(ts_src_stride.stride[0]); + int *dst_stride = &(ts_dst_stride.stride[0]); + int *permute = &(ts_permute.shape[0]); + CUDA_1D_KERNEL_LOOP(index, n) { + size_t dst_index = index; + size_t src_index = 0; + for (int i = 0; i < src_dim; ++i) { + int dim_index = dst_index / dst_stride[i]; + dst_index = dst_index % dst_stride[i]; + src_index += dim_index * src_stride[permute[i]]; + } + dst[index] = src[src_index]; + } +} + +template +void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, + int *permute, int src_dim, cudaStream_t stream) { + size_t copy_size = 1; + TensorDesc ts_permute; + memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int)); + + TensorDesc ts_src_stride; + TensorDesc ts_dst_stride; + ts_src_stride.dim = src_dim; + ts_dst_stride.dim = src_dim; + int *src_stride = &(ts_src_stride.stride[0]); + int *dst_stride = &(ts_dst_stride.stride[0]); + int *dst_size = &(ts_dst_stride.shape[0]); + src_stride[src_dim - 1] = 1; + dst_stride[src_dim - 1] = 1; + + for (int i = src_dim - 1; i >= 0; --i) { + dst_size[i] = src_size[permute[i]]; + if (i < src_dim - 1) { + src_stride[i] = src_stride[i + 1] * src_size[i + 1]; + } + } + + for (int i = src_dim - 1; i >= 0; --i) { + copy_size *= dst_size[i]; + if (i < src_dim - 1) { + dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1]; + } + } + + copy_permute_kernel + <<>>( + dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute); +} + +template void memcpyPermute(float *dst, const float *src, int *src_size, + int *permute, int src_dim, + cudaStream_t stream); diff --git a/backend_ops/tensorrt/scatternd/CMakeLists.txt b/backend_ops/tensorrt/scatternd/CMakeLists.txt new file mode 100644 index 000000000..d967c6b6b --- /dev/null +++ b/backend_ops/tensorrt/scatternd/CMakeLists.txt @@ -0,0 +1,4 @@ +file(GLOB OPS_SRCS *.cpp *.cu) +file(GLOB OPS_HEADS *.h *.hpp *.cuh) +set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} ${OPS_SRCS} ${OPS_HEADS}) +set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/backend_ops/tensorrt/scatternd/trt_scatternd.cpp b/backend_ops/tensorrt/scatternd/trt_scatternd.cpp new file mode 100644 index 000000000..8adff2477 --- /dev/null +++ b/backend_ops/tensorrt/scatternd/trt_scatternd.cpp @@ -0,0 +1,206 @@ +#include "trt_scatternd.hpp" + +#include +#include + +#include + +#include "trt_serialize.hpp" + +extern void TRTONNXScatterNDKernelLauncher_float( + const float *data, const int *indices, const float *update, const int *dims, + int nbDims, const int *indices_dims, int indice_nbDims, float *output, + cudaStream_t stream); + +extern void TRTONNXScatterNDKernelLauncher_int32( + const int *data, const int *indices, const int *update, const int *dims, + int nbDims, const int *indices_dims, int indice_nbDims, int *output, + cudaStream_t stream); + +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"ScatterND"}; +} // namespace + +nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{}; +std::vector + ONNXScatterNDDynamicCreator::mPluginAttributes; + +ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name) + : mLayerName(name) {} + +ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name, + const void *data, size_t length) + : mLayerName(name) {} + +nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const { + ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) { + return inputs[0]; +} + +bool ONNXScatterNDDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, + int nbOutputs) { + if (pos < nbInputs) { + switch (pos) { + case 0: + // data + return (inOut[pos].type == nvinfer1::DataType::kFLOAT && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) || + (inOut[pos].type == nvinfer1::DataType::kINT32 && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR); + case 1: + // indices + return inOut[pos].type == nvinfer1::DataType::kINT32 && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + case 2: + // updates + return inOut[pos].type == inOut[0].type && + inOut[pos].format == inOut[0].format; + default: + return true; + } + } else { + switch (pos - nbInputs) { + case 0: + // output + return inOut[pos].type == inOut[0].type && + inOut[pos].format == inOut[0].format; + default: + return true; + } + } + return true; +} + +void ONNXScatterNDDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {} + +size_t ONNXScatterNDDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const { + return 0; +} + +int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, + void *const *outputs, void *workSpace, + cudaStream_t stream) { + const int *dims = &(inputDesc[0].dims.d[0]); + const int *indices_dims = &(inputDesc[1].dims.d[0]); + int nbDims = inputDesc[0].dims.nbDims; + int indice_nbDims = inputDesc[1].dims.nbDims; + + const void *data = inputs[0]; + const void *indices = inputs[1]; + const void *update = inputs[2]; + void *output = outputs[0]; + + auto data_type = inputDesc[0].type; + + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + TRTONNXScatterNDKernelLauncher_float( + (float *)data, (int *)indices, (float *)update, dims, nbDims, + indices_dims, indice_nbDims, (float *)output, stream); + break; + + case nvinfer1::DataType::kINT32: + TRTONNXScatterNDKernelLauncher_int32( + (int *)data, (int *)indices, (int *)update, dims, nbDims, + indices_dims, indice_nbDims, (int *)output, stream); + break; + default: + break; + } + + return 0; +} + +nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; } + +const char *ONNXScatterNDDynamic::getPluginVersion() const { + return PLUGIN_VERSION; +} + +int ONNXScatterNDDynamic::getNbOutputs() const { return 1; } + +int ONNXScatterNDDynamic::initialize() { return 0; } + +void ONNXScatterNDDynamic::terminate() {} + +size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; } + +void ONNXScatterNDDynamic::serialize(void *buffer) const {} + +void ONNXScatterNDDynamic::destroy() { + // This gets called when the network containing plugin is destroyed + delete this; +} + +void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *ONNXScatterNDDynamic::getPluginNamespace() const { + return mNamespace.c_str(); +} + +////////////////////// creator ///////////////////////////// + +ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *ONNXScatterNDDynamicCreator::getPluginName() const { + return PLUGIN_NAME; +} + +const char *ONNXScatterNDDynamicCreator::getPluginVersion() const { + return PLUGIN_VERSION; +} + +const nvinfer1::PluginFieldCollection * +ONNXScatterNDDynamicCreator::getFieldNames() { + return &mFC; +} + +nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) { + ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) { + auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const { + return mNamespace.c_str(); +} diff --git a/backend_ops/tensorrt/scatternd/trt_scatternd.hpp b/backend_ops/tensorrt/scatternd/trt_scatternd.hpp new file mode 100644 index 000000000..6087cbefb --- /dev/null +++ b/backend_ops/tensorrt/scatternd/trt_scatternd.hpp @@ -0,0 +1,98 @@ +#ifndef TRT_SCATTERND_HPP +#define TRT_SCATTERND_HPP +#include + +#include +#include +#include + +#include "trt_plugin_helper.hpp" + +class ONNXScatterNDDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + ONNXScatterNDDynamic(const std::string &name); + + ONNXScatterNDDynamic(const std::string name, const void *data, size_t length); + + ONNXScatterNDDynamic() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc *inOut, + int nbInputs, int nbOutputs) override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, + cudaStream_t stream) override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType *inputTypes, + int nbInputs) const override; + + // IPluginV2 Methods + const char *getPluginType() const override; + const char *getPluginVersion() const override; + int getNbOutputs() const override; + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void *buffer) const override; + void destroy() override; + void setPluginNamespace(const char *pluginNamespace) override; + const char *getPluginNamespace() const override; + + private: + const std::string mLayerName; + std::string mNamespace; + + protected: + // To prevent compiler warnings. + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::configurePlugin; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; +}; + +class ONNXScatterNDDynamicCreator : public nvinfer1::IPluginCreator { + public: + ONNXScatterNDDynamicCreator(); + + const char *getPluginName() const override; + + const char *getPluginVersion() const override; + + const nvinfer1::PluginFieldCollection *getFieldNames() override; + + nvinfer1::IPluginV2 *createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, + const void *serialData, + size_t serialLength) override; + + void setPluginNamespace(const char *pluginNamespace) override; + + const char *getPluginNamespace() const override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; +#endif // TRT_SCATTERND_HPP diff --git a/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu b/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu new file mode 100644 index 000000000..2c7edb99d --- /dev/null +++ b/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu @@ -0,0 +1,92 @@ +#include + +#include + +#include "common_cuda_helper.hpp" +#include "trt_cuda_helper.cuh" +#include "trt_plugin_helper.hpp" + +static int const threadsPerBlock = sizeof(unsigned long long int) * 8; + +using mmlab::TensorDesc; + +template +__global__ void onnx_scatternd_kernel(const int n, const int* indices, + const T* update, T* output, + TensorDesc tensor_desc, + TensorDesc indice_desc) { + const int indice_cols = indice_desc.shape[indice_desc.dim - 1]; + const int copy_stride = tensor_desc.stride[indice_cols - 1]; + const int* stride = &(tensor_desc.stride[0]); + CUDA_1D_KERNEL_LOOP(index, n) { + int output_offset = 0; + const int* indices_current = indices + index * indice_cols; + for (int i = 0; i < indice_cols; ++i) { + output_offset += stride[i] * indices_current[i]; + } + memcpy(output + output_offset, update + index * copy_stride, + copy_stride * sizeof(T)); + } +} + +template +void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, + const T* update, const int* dims, + int nbDims, const int* indices_dims, + int indice_nbDims, T* output, + cudaStream_t stream) { + // fill tensordesc and initial + TensorDesc tensor_desc; + memset((void*)&tensor_desc, 0, sizeof(TensorDesc)); + tensor_desc.dim = nbDims; + tensor_desc.shape[nbDims - 1] = dims[nbDims - 1]; + tensor_desc.stride[nbDims - 1] = 1; + for (int i = nbDims - 2; i >= 0; --i) { + tensor_desc.shape[i] = dims[i]; + tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1]; + } + const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0]; + + TensorDesc indice_desc; + memset((void*)&indice_desc, 0, sizeof(TensorDesc)); + indice_desc.dim = indice_nbDims; + indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1]; + indice_desc.stride[indice_nbDims - 1] = 1; + for (int i = indice_nbDims - 2; i >= 0; --i) { + indice_desc.shape[i] = indices_dims[i]; + indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1]; + } + + // output = np.copy(data) + cudaMemcpyAsync(output, data, data_size * sizeof(T), + cudaMemcpyDeviceToDevice); + + int num_update_indice = 1; + for (int i = 0; i < indice_nbDims - 1; ++i) { + num_update_indice *= indice_desc.shape[i]; + } + // scatter + const int col_block = DIVUP(num_update_indice, threadsPerBlock); + onnx_scatternd_kernel<<>>( + num_update_indice, indices, update, output, tensor_desc, indice_desc); +} + +void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices, + const float* update, const int* dims, + int nbDims, const int* indices_dims, + int indice_nbDims, float* output, + cudaStream_t stream) { + TRTONNXScatterNDKernelLauncher(data, indices, update, dims, nbDims, + indices_dims, indice_nbDims, output, + stream); +} + +void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices, + const int* update, const int* dims, + int nbDims, const int* indices_dims, + int indice_nbDims, int* output, + cudaStream_t stream) { + TRTONNXScatterNDKernelLauncher(data, indices, update, dims, nbDims, + indices_dims, indice_nbDims, output, + stream); +} diff --git a/backend_ops/tensorrt/trt_plugin.cpp b/backend_ops/tensorrt/trt_plugin.cpp new file mode 100644 index 000000000..208b1be4f --- /dev/null +++ b/backend_ops/tensorrt/trt_plugin.cpp @@ -0,0 +1,7 @@ +#include "scatternd/trt_scatternd.hpp" + +REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); + +extern "C" { +bool initLibMMCVInferPlugins() { return true; } +} // extern "C" diff --git a/mmdeploy/utils/function_rewriter.py b/mmdeploy/utils/function_rewriter.py index 0e8f305ef..faa6a492d 100644 --- a/mmdeploy/utils/function_rewriter.py +++ b/mmdeploy/utils/function_rewriter.py @@ -34,7 +34,6 @@ class FuncCaller(object): # caller decorator def register_rewriter(func_name, backend='default', **kwargs): - def wrap(func): func_args = dict(func_name=func_name, backend=backend, func=func) func_args.update(kwargs) @@ -49,7 +48,6 @@ FUNCTION_REWRITERS.register_rewriter = register_rewriter def apply_rewriter(regist_func): - def wrapper(*args, **kwargs): return regist_func(*args, **kwargs) @@ -57,13 +55,14 @@ def apply_rewriter(regist_func): class RewriterHook(object): - def __init__(self, regist_name, cfg, **kwargs): func_name, backend = regist_name.split('@') self.func_name = func_name self.backend = backend - self.regist_func = FUNCTION_REWRITERS.build( - func_name, backend=self.backend, cfg=cfg, **kwargs) + self.regist_func = FUNCTION_REWRITERS.build(func_name, + backend=self.backend, + cfg=cfg, + **kwargs) try: self.origin_func = eval_with_import(self.func_name) except Exception: @@ -93,7 +92,6 @@ class RewriterHook(object): class RewriterContext(object): - def __init__(self, cfg, backend='default', **kwargs): self.cfg = cfg func_backend_dict = {} diff --git a/mmdeploy/utils/module_rewriter.py b/mmdeploy/utils/module_rewriter.py index 72c507a58..ccc585b39 100644 --- a/mmdeploy/utils/module_rewriter.py +++ b/mmdeploy/utils/module_rewriter.py @@ -1,7 +1,9 @@ -from mmcv.utils import Registry -from .register_utils import eval_with_import from copy import deepcopy +from mmcv.utils import Registry + +from .register_utils import eval_with_import + def build_rewrite_module(module, cfg, backend, registry, **kwargs): @@ -22,7 +24,6 @@ def build_rewrite_module(module, cfg, backend, registry, **kwargs): class RewriteModuleRegistry(Registry): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._module_eval_dict = dict() @@ -54,16 +55,18 @@ class RewriteModuleRegistry(Registry): # create register -MODULE_REWRITERS = RewriteModuleRegistry( - 'module_rewriters', build_func=build_rewrite_module, scope='.') +MODULE_REWRITERS = RewriteModuleRegistry('module_rewriters', + build_func=build_rewrite_module, + scope='.') def patch_model(model, cfg, backend='default', **kwargs): - def _patch_impl(model, cfg, **kwargs): for name, module in model.named_children(): model._modules[name] = _patch_impl(module, cfg, **kwargs) - return MODULE_REWRITERS.build( - module=model, cfg=cfg, backend=backend, **kwargs) + return MODULE_REWRITERS.build(module=model, + cfg=cfg, + backend=backend, + **kwargs) return _patch_impl(deepcopy(model), cfg, **kwargs) diff --git a/mmdeploy/utils/symbolic_register.py b/mmdeploy/utils/symbolic_register.py index 11d47828d..ac2a325aa 100644 --- a/mmdeploy/utils/symbolic_register.py +++ b/mmdeploy/utils/symbolic_register.py @@ -64,13 +64,11 @@ def register_symbolic(func_name, is_pytorch=False, arg_descriptors=None, **kwargs): - def wrapper(symbolic_impl): - symbolic_args = dict( - func_name=func_name, - backend=backend, - symbolic=symbolic_impl, - arg_descriptors=arg_descriptors) + symbolic_args = dict(func_name=func_name, + backend=backend, + symbolic=symbolic_impl, + arg_descriptors=arg_descriptors) symbolic_args.update(kwargs) wrapper_name = '@'.join([func_name, backend, str(is_pytorch)]) wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args) diff --git a/tests/test_utils/test_register.py b/tests/test_utils/test_register.py index dda60599e..f394b7d11 100644 --- a/tests/test_utils/test_register.py +++ b/tests/test_utils/test_register.py @@ -1,6 +1,7 @@ -import torch import os +import torch + def test_function_rewriter(): from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext @@ -8,8 +9,8 @@ def test_function_rewriter(): x = torch.tensor([1, 2, 3, 4, 5]) y = torch.tensor([2, 4, 6, 8, 10]) - @FUNCTION_REWRITERS.register_rewriter( - func_name='torch.add', backend='tensorrt') + @FUNCTION_REWRITERS.register_rewriter(func_name='torch.add', + backend='tensorrt') def sub_func(rewriter, x, y): return x - y @@ -28,8 +29,8 @@ def test_function_rewriter(): # replace should not happen with wrong backend torch.testing.assert_allclose(result, x + y) - @FUNCTION_REWRITERS.register_rewriter( - func_name='torch.Tensor.add', backend='default') + @FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.add', + backend='default') def mul_func_class(rewriter, x, y): return x * y @@ -55,7 +56,6 @@ def test_module_rewriter(): @MODULE_REWRITERS.register_rewrite_module( module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt') class BottleneckWrapper(torch.nn.Module): - def __init__(self, module, cfg, **kwargs): super().__init__() self.module = module @@ -91,7 +91,6 @@ def test_symbolic_register(): import onnx class TestFunc(Function): - @staticmethod def symbolic(g, x, val): return g.op('mmcv::symbolic_old', x, val_i=val) @@ -109,18 +108,18 @@ def test_symbolic_register(): def symbolic_testfunc_default(symbolic_wrapper, g, x, val): return g.op('mmcv::symbolic_testfunc_default', x, val_i=val) - @SYMBOLICS_REGISTER.register_symbolic( - 'mmdeploy.TestFunc', backend='tensorrt') + @SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc', + backend='tensorrt') def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val): return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val) - @SYMBOLICS_REGISTER.register_symbolic( - 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) + @SYMBOLICS_REGISTER.register_symbolic('cummax', + is_pytorch=True, + arg_descriptors=['v', 'i']) def symbolic_cummax(symbolic_wrapper, g, input, dim): return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2) class TestModel(torch.nn.Module): - def __init__(self): super().__init__()