mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add backend plugin build system
This commit is contained in:
parent
f90ebf8c2c
commit
1e47821e49
@ -1,2 +1,2 @@
|
||||
[settings]
|
||||
known_third_party =
|
||||
known_third_party = mmcv,setuptools,torch
|
||||
|
15
CMakeLists.txt
Normal file
15
CMakeLists.txt
Normal file
@ -0,0 +1,15 @@
|
||||
# Top-level build script for the custom backend operator libraries.
cmake_minimum_required(VERSION 3.10)
project(mmdeploy_backend_ops)

# ---- TensorRT configuration ----

# Opt-in switch for the TensorRT custom ops; disabled unless requested.
option(BUILD_TENSORRT_OPS "enable tensorrt ops" OFF)

# TensorRT search path: when the ops are requested and TENSORRT_DIR was not
# defined for this configure run, fall back to the TENSORRT_DIR environment
# variable (read once, at configure time).
if(BUILD_TENSORRT_OPS AND NOT DEFINED TENSORRT_DIR)
  set(TENSORRT_DIR $ENV{TENSORRT_DIR})
endif()

add_subdirectory(backend_ops)
|
13
backend_ops/CMakeLists.txt
Normal file
13
backend_ops/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
# Build settings shared by every backend op library below this directory.

# Select C++11 through CMake's language-standard variables instead of
# add_definitions(), which is intended for preprocessor definitions only.
# Extensions are disabled so the host compiler is still given -std=c++11
# rather than the GNU dialect.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# Release builds are compiled with -O3 only (this replaces CMake's default
# release flags for this directory tree).
set(CMAKE_CXX_FLAGS_RELEASE "-O3")

# Collect all build artifacts under <build>/lib and <build>/bin.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

# build TensorRT ops
if(BUILD_TENSORRT_OPS)
  message("Build TensorRT custom ops.")
  add_subdirectory(tensorrt)
endif()
|
61
backend_ops/tensorrt/CMakeLists.txt
Normal file
61
backend_ops/tensorrt/CMakeLists.txt
Normal file
@ -0,0 +1,61 @@
|
||||
# Name of the shared library that bundles the TensorRT custom ops.
set(TARGET_NAME mmlab_tensorrt_ops)
set(SHARED_TARGET ${TARGET_NAME})

# cuda
find_package(CUDA REQUIRED)
# Use the headers of the toolkit that find_package(CUDA) actually located
# instead of assuming an installation under /usr/local/cuda.
include_directories(${CUDA_INCLUDE_DIRS})
enable_language(CUDA)
|
||||
|
||||
# tensorrt
|
||||
# Locate the TensorRT headers. TENSORRT_DIR is tried first, then the CUDA
# toolkit root.
find_path(TENSORRT_INCLUDE_DIR NvInfer.h
  HINTS ${TENSORRT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
  PATH_SUFFIXES include)
if(TENSORRT_INCLUDE_DIR)
  message(STATUS " Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}")
else()
  # FATAL_ERROR is the mode keyword that stops configuration. A bare ERROR is
  # not a mode: it is printed as message text and configuration carries on.
  message(FATAL_ERROR "Cannot find TensorRT headers")
endif()

# Locate the three TensorRT libraries the plugin library links against.
find_library(TENSORRT_LIBRARY_INFER nvinfer
  HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
  PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_PARSERS nvparsers
  HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
  PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin
  HINTS ${TENSORRT_DIR} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
  PATH_SUFFIXES lib lib64 lib/x64)
set(TENSORRT_LIBRARY
  ${TENSORRT_LIBRARY_INFER}
  ${TENSORRT_LIBRARY_PARSERS}
  ${TENSORRT_LIBRARY_INFER_PLUGIN})
if(TENSORRT_LIBRARY_INFER AND TENSORRT_LIBRARY_PARSERS AND TENSORRT_LIBRARY_INFER_PLUGIN)
  message(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
else()
  message(FATAL_ERROR "Cannot find TensorRT libs")
endif()

# Included explicitly so this section does not depend on another module
# having loaded find_package_handle_standard_args() first.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
  TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY)
if(NOT TENSORRT_FOUND)
  message(FATAL_ERROR "Cannot find TensorRT library.")
endif()
include_directories(${TENSORRT_INCLUDE_DIR})
|
||||
|
||||
# include common
|
||||
# Headers shared by all plugins live in common/.
include_directories(common)
add_subdirectory(common_impl)

# add plugin source
# Every directory listed here is added as a subdirectory; BACKEND_OPS_SRCS is
# the list that is finally handed to cuda_add_library() below.
set(PLUGIN_LISTS scatternd)

foreach(PLUGIN_ITER ${PLUGIN_LISTS})
  add_subdirectory(${PLUGIN_ITER})
endforeach()

# trt_plugin.cpp holds the plugin registration and must be compiled into the
# library, so it is appended to BACKEND_OPS_SRCS (the list actually passed to
# cuda_add_library) rather than to a separate, unused list.
list(APPEND BACKEND_OPS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/trt_plugin.cpp")

set(INFER_PLUGIN_LIB ${TENSORRT_LIBRARY})

cuda_add_library(${SHARED_TARGET} SHARED ${BACKEND_OPS_SRCS})
# NOTE(review): kept keyword-less on purpose. FindCUDA's cuda_add_library() is
# understood to link with the plain signature unless
# CUDA_LINK_LIBRARIES_KEYWORD is set, and CMake rejects mixing plain and
# keyword signatures on one target - confirm before adding PRIVATE here.
target_link_libraries(${SHARED_TARGET} ${INFER_PLUGIN_LIB})
target_include_directories(${SHARED_TARGET} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/common)
|
110
backend_ops/tensorrt/common/common_cuda_helper.hpp
Normal file
110
backend_ops/tensorrt/common/common_cuda_helper.hpp
Normal file
@ -0,0 +1,110 @@
|
||||
#ifndef COMMON_CUDA_HELPER
|
||||
#define COMMON_CUDA_HELPER
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
// Loops `i` over [0, n): each thread starts at
// blockIdx.x * blockDim.x + threadIdx.x and advances by
// blockDim.x * gridDim.x per iteration.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Number of threads per block assumed by GET_BLOCKS().
#define THREADS_PER_BLOCK 512

/**
 * Returns the number of blocks to launch for N work items when each block has
 * THREADS_PER_BLOCK threads.
 *
 * The result is ceil(N / THREADS_PER_BLOCK) clamped to the range [1, 4096].
 * At least one block is always returned so that the value can be used as a
 * grid size even when N is zero or negative; a kernel written with
 * CUDA_1D_KERNEL_LOOP simply performs no iterations in that case.
 *
 * @param N number of work items
 * @return grid size in blocks, between 1 and 4096
 */
inline int GET_BLOCKS(const int N) {
  const int max_block_num = 4096;
  if (N <= 0) return 1;
  // (N - 1) / T + 1 is the ceiling division for N > 0 and, unlike
  // (N + T - 1) / T, cannot overflow for N close to INT_MAX.
  const int optimal_block_num = (N - 1) / THREADS_PER_BLOCK + 1;
  return optimal_block_num < max_block_num ? optimal_block_num : max_block_num;
}
|
||||
|
||||
/**
 * Bilinearly samples a `height` x `width` array, stored row by row, at the
 * fractional location (y, x).
 *
 * Locations with y < -1, y > height, x < -1 or x > width yield 0. Coordinates
 * that are <= 0 are moved to 0, and a coordinate whose integer part reaches
 * the last row/column is snapped onto that row/column.
 *
 * @param input  pointer to the first element of the array
 * @param height number of rows
 * @param width  number of columns
 * @param y      row coordinate of the sample
 * @param x      column coordinate of the sample
 * @param index  unused here; kept for debugging
 * @return the interpolated value
 */
template <typename T>
__device__ T bilinear_interpolate(const T* input, const int height,
                                  const int width, T y, T x,
                                  const int index /* index for debug only*/) {
  // Samples too far outside the array contribute nothing.
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) return 0;

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  int row_lo = (int)y;
  int col_lo = (int)x;
  int row_hi;
  int col_hi;

  // Pick the neighbouring row, collapsing onto the last row at the border.
  if (row_lo < height - 1) {
    row_hi = row_lo + 1;
  } else {
    row_hi = row_lo = height - 1;
    y = (T)row_lo;
  }

  // Same for the neighbouring column.
  if (col_lo < width - 1) {
    col_hi = col_lo + 1;
  } else {
    col_hi = col_lo = width - 1;
    x = (T)col_lo;
  }

  // Fractional distances to the low corner and their complements.
  const T frac_y = y - row_lo;
  const T frac_x = x - col_lo;
  const T rest_y = 1. - frac_y;
  const T rest_x = 1. - frac_x;

  // Values at the four surrounding grid points.
  const T top_left = input[row_lo * width + col_lo];
  const T top_right = input[row_lo * width + col_hi];
  const T bottom_left = input[row_hi * width + col_lo];
  const T bottom_right = input[row_hi * width + col_hi];

  const T w1 = rest_y * rest_x;
  const T w2 = rest_y * frac_x;
  const T w3 = frac_y * rest_x;
  const T w4 = frac_y * frac_x;

  return (w1 * top_left + w2 * top_right + w3 * bottom_left +
          w4 * bottom_right);
}
|
||||
|
||||
/**
 * Computes the four bilinear weights and the four neighbouring grid indices
 * for the fractional location (y, x) in a `height` x `width` array.
 *
 * The boundary handling mirrors the forward interpolation: locations with
 * y < -1, y > height, x < -1 or x > width produce zero weights and indices of
 * -1; coordinates <= 0 are moved to 0; the last row/column is used for both
 * neighbours once the integer part reaches it.
 *
 * @param height number of rows
 * @param width  number of columns
 * @param y      row coordinate
 * @param x      column coordinate
 * @param w1     [out] weight of (y_low, x_low)
 * @param w2     [out] weight of (y_low, x_high)
 * @param w3     [out] weight of (y_high, x_low)
 * @param w4     [out] weight of (y_high, x_high)
 * @param x_low  [out] lower column index
 * @param x_high [out] upper column index
 * @param y_low  [out] lower row index
 * @param y_high [out] upper row index
 * @param index  unused here; kept for debugging
 */
template <typename T>
__device__ void bilinear_interpolate_gradient(
    const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
    int& x_low, int& x_high, int& y_low, int& y_high,
    const int index /* index for debug only*/) {
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    // Nothing to accumulate: zero weights and invalid indices.
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low < height - 1) {
    y_high = y_low + 1;
  } else {
    y_high = y_low = height - 1;
    y = (T)y_low;
  }

  if (x_low < width - 1) {
    x_high = x_low + 1;
  } else {
    x_high = x_low = width - 1;
    x = (T)x_low;
  }

  const T frac_y = y - y_low;
  const T frac_x = x - x_low;
  const T rest_y = 1. - frac_y;
  const T rest_x = 1. - frac_x;

  // In the forward pass these weights multiply, in order, the values at
  // (y_low, x_low), (y_low, x_high), (y_high, x_low) and (y_high, x_high).
  w1 = rest_y * rest_x;
  w2 = rest_y * frac_x;
  w3 = frac_y * rest_x;
  w4 = frac_y * frac_x;
}
|
||||
#endif // COMMON_CUDA_HELPER
|
30
backend_ops/tensorrt/common/trt_cuda_helper.cuh
Normal file
30
backend_ops/tensorrt/common/trt_cuda_helper.cuh
Normal file
@ -0,0 +1,30 @@
|
||||
#ifndef TRT_CUDA_HELPER_HPP
|
||||
#define TRT_CUDA_HELPER_HPP
|
||||
|
||||
// Integer division of m by n, rounded up when there is a positive remainder.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Queries the last CUDA runtime error. On failure it prints the source
// location together with the CUDA error string and terminates the process
// with a non-zero exit status, so that the failure is visible to the caller
// of the program.
#define cudaCheckError()                                       \
  {                                                            \
    cudaError_t e = cudaGetLastError();                        \
    if (e != cudaSuccess) {                                    \
      printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                           \
      exit(1);                                                 \
    }                                                          \
  }
|
||||
|
||||
/**
 * Copies the source tensor into the destination buffer with its dimensions
 * reordered according to `permute`.
 *
 * @param[out] dst pointer to the destination tensor
 * @param[in] src pointer to the source tensor
 * @param[in] src_size shape of the src tensor
 * @param[in] permute the desired ordering of dimensions
 * @param[in] src_dim number of dimensions of the src tensor
 * @param[in] stream cuda stream handle
 */
|
||||
template <class scalar_t>
|
||||
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
|
||||
int *permute, int src_dim, cudaStream_t stream = 0);
|
||||
|
||||
#endif // TRT_CUDA_HELPER_HPP
|
41
backend_ops/tensorrt/common/trt_plugin_helper.hpp
Normal file
41
backend_ops/tensorrt/common/trt_plugin_helper.hpp
Normal file
@ -0,0 +1,41 @@
|
||||
#ifndef TRT_PLUGIN_HELPER_HPP
|
||||
#define TRT_PLUGIN_HELPER_HPP
|
||||
#include <stdexcept>
|
||||
|
||||
#include "NvInferPlugin.h"
|
||||
|
||||
namespace mmlab {
|
||||
|
||||
// Capacity of the fixed-size arrays inside TensorDesc.
const int MAXTENSORDIMS = 10;

// Plain, fixed-capacity tensor descriptor. Only the first `dim` entries of
// each array are meaningful to the code that fills it in.
struct TensorDesc {
  int shape[MAXTENSORDIMS];   // per-dimension extents
  int stride[MAXTENSORDIMS];  // per-dimension strides, counted in elements
  int dim;                    // number of dimensions in use
};
|
||||
|
||||
/**
 * Returns the size in bytes of one element of the given TensorRT data type.
 *
 * Handles kINT32, kFLOAT, kHALF and kINT8; any other value (kBOOL is not
 * handled) is rejected.
 *
 * @param t TensorRT data type
 * @return element size in bytes
 * @throws std::runtime_error if `t` is not one of the handled types
 */
inline unsigned int getElementSize(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kINT32:
    case nvinfer1::DataType::kFLOAT:
      return 4;
    case nvinfer1::DataType::kHALF:
      return 2;
    case nvinfer1::DataType::kINT8:
      return 1;
    default:
      // Every path out of the switch returns or throws, so nothing is needed
      // after it.
      throw std::runtime_error("Invalid DataType.");
  }
}
|
||||
|
||||
/**
 * Rounds `origin_size` up to the next multiple of `aligned_number`.
 *
 * @param origin_size    size to round up
 * @param aligned_number alignment unit (defaults to 16)
 * @return the smallest multiple of `aligned_number` that is >= `origin_size`
 */
inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) {
  const size_t chunks = (origin_size + aligned_number - 1) / aligned_number;
  return chunks * aligned_number;
}
|
||||
|
||||
} // namespace mmlab
|
||||
#endif // TRT_PLUGIN_HELPER_HPP
|
105
backend_ops/tensorrt/common/trt_serialize.hpp
Normal file
105
backend_ops/tensorrt/common/trt_serialize.hpp
Normal file
@ -0,0 +1,105 @@
|
||||
// Modified from:
|
||||
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp
|
||||
|
||||
#ifndef TRT_SERIALIZE_HPP
|
||||
#define TRT_SERIALIZE_HPP
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
using std::cerr;
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
// Declared up front because the std::vector serializer below calls these
// entry points before their definitions appear.
template <typename T>
inline void serialize_value(void** buffer, T const& value);

template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value);

namespace {

// Primary template: has no members, so using it with a type that matches none
// of the specializations below does not compile.
template <typename T, class Enable = void>
struct Serializer {};

// Arithmetic, enum and POD types are copied byte for byte.
template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                             std::is_enum<T>::value ||
                                             std::is_pod<T>::value>::type> {
  static size_t serialized_size(T const& value) { return sizeof(T); }

  // Writes `value` at *buffer and advances *buffer past it.
  static void serialize(void** buffer, T const& value) {
    char* cursor = static_cast<char*>(*buffer);
    ::memcpy(cursor, &value, sizeof(T));
    *buffer = cursor + sizeof(T);
  }

  // Reads `value` from *buffer, advances *buffer and shrinks *buffer_size.
  static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
    assert(*buffer_size >= sizeof(T));
    char const* cursor = static_cast<char const*>(*buffer);
    ::memcpy(value, cursor, sizeof(T));
    *buffer = cursor + sizeof(T);
    *buffer_size -= sizeof(T);
  }
};

// C strings are stored including their terminating NUL.
template <>
struct Serializer<const char*> {
  static size_t serialized_size(const char* value) { return strlen(value) + 1; }

  static void serialize(void** buffer, const char* value) {
    char* cursor = static_cast<char*>(*buffer);
    ::strcpy(cursor, value);
    *buffer = cursor + strlen(value) + 1;
  }

  // The returned pointer refers into the buffer itself; no copy is made.
  static void deserialize(void const** buffer, size_t* buffer_size,
                          const char** value) {
    char const* cursor = static_cast<char const*>(*buffer);
    *value = cursor;
    const size_t data_size = strnlen(cursor, *buffer_size) + 1;
    assert(*buffer_size >= data_size);
    *buffer = cursor + data_size;
    *buffer_size -= data_size;
  }
};

// Vectors of arithmetic/enum/POD elements are stored as the element count
// followed by the raw element bytes.
template <typename T>
struct Serializer<std::vector<T>,
                  typename std::enable_if<std::is_arithmetic<T>::value ||
                                          std::is_enum<T>::value ||
                                          std::is_pod<T>::value>::type> {
  static size_t serialized_size(std::vector<T> const& value) {
    return sizeof(value.size()) + value.size() * sizeof(T);
  }

  static void serialize(void** buffer, std::vector<T> const& value) {
    serialize_value(buffer, value.size());
    const size_t nbyte = value.size() * sizeof(T);
    char* cursor = static_cast<char*>(*buffer);
    ::memcpy(cursor, value.data(), nbyte);
    *buffer = cursor + nbyte;
  }

  static void deserialize(void const** buffer, size_t* buffer_size,
                          std::vector<T>* value) {
    size_t count;
    deserialize_value(buffer, buffer_size, &count);
    value->resize(count);
    const size_t nbyte = value->size() * sizeof(T);
    assert(*buffer_size >= nbyte);
    char const* cursor = static_cast<char const*>(*buffer);
    ::memcpy(value->data(), cursor, nbyte);
    *buffer = cursor + nbyte;
    *buffer_size -= nbyte;
  }
};

}  // namespace

// Number of bytes serialize_value() writes for `value`.
template <typename T>
inline size_t serialized_size(T const& value) {
  return Serializer<T>::serialized_size(value);
}

// Writes `value` at *buffer and moves *buffer to the first byte after it.
template <typename T>
inline void serialize_value(void** buffer, T const& value) {
  Serializer<T>::serialize(buffer, value);
}

// Reads one value from *buffer into *value, moves *buffer past it and
// subtracts the consumed byte count from *buffer_size.
template <typename T>
inline void deserialize_value(void const** buffer, size_t* buffer_size,
                              T* value) {
  Serializer<T>::deserialize(buffer, buffer_size, value);
}
|
||||
#endif // TRT_SERIALIZE_HPP
|
4
backend_ops/tensorrt/common_impl/CMakeLists.txt
Normal file
4
backend_ops/tensorrt/common_impl/CMakeLists.txt
Normal file
@ -0,0 +1,4 @@
|
||||
# Collect this directory's sources and headers by glob and add them to
# BACKEND_OPS_SRCS.
file(GLOB OPS_SRCS *.cpp *.cu)
file(GLOB OPS_HEADS *.h *.hpp *.cuh)
list(APPEND BACKEND_OPS_SRCS ${OPS_SRCS} ${OPS_HEADS})

# Hand the extended list back to the directory that added this one.
set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} PARENT_SCOPE)
|
66
backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
Normal file
66
backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
Normal file
@ -0,0 +1,66 @@
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_cuda_helper.cuh"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
using mmlab::TensorDesc;
|
||||
|
||||
/**
 * For every destination element index in [0, n), decomposes the index into
 * per-dimension coordinates using the destination strides, maps each
 * coordinate onto the source stride selected by `permute`, and copies the
 * corresponding source element.
 *
 * @param dst           destination buffer
 * @param src           source buffer
 * @param n             number of elements to copy
 * @param ts_src_stride supplies `dim` and the source strides (`stride`)
 * @param ts_dst_stride supplies the destination strides (`stride`)
 * @param ts_permute    carries the permutation in its `shape` array
 */
template <class scalar_t>
__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n,
                                    TensorDesc ts_src_stride,
                                    TensorDesc ts_dst_stride,
                                    TensorDesc ts_permute) {
  const int rank = ts_src_stride.dim;
  int *in_strides = ts_src_stride.stride;
  int *out_strides = ts_dst_stride.stride;
  int *order = ts_permute.shape;

  CUDA_1D_KERNEL_LOOP(index, n) {
    size_t remainder = index;
    size_t src_offset = 0;
    for (int axis = 0; axis < rank; ++axis) {
      // Coordinate along `axis` in the destination layout.
      int coord = remainder / out_strides[axis];
      remainder = remainder % out_strides[axis];
      src_offset += coord * in_strides[order[axis]];
    }
    dst[index] = src[src_offset];
  }
}
|
||||
|
||||
/**
 * Copies `src` into `dst` with the dimensions reordered by `permute`:
 * destination dimension i has the extent of source dimension permute[i].
 * The copy is performed by copy_permute_kernel launched on `stream`.
 *
 * @param dst      destination buffer
 * @param src      source buffer
 * @param src_size extents of the source tensor, `src_dim` entries
 * @param permute  source dimension used for each destination dimension,
 *                 `src_dim` entries
 * @param src_dim  number of dimensions; must be in [1, MAXTENSORDIMS]
 * @param stream   CUDA stream the kernel is launched on
 * @throws std::runtime_error if `src_dim` is outside [1, MAXTENSORDIMS]
 */
template <class scalar_t>
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
                   int *permute, int src_dim, cudaStream_t stream) {
  // The descriptors hold fixed-size arrays; reject ranks that would index
  // outside of them instead of writing past their ends.
  if (src_dim <= 0 || src_dim > mmlab::MAXTENSORDIMS) {
    throw std::runtime_error(
        "memcpyPermute: unsupported number of dimensions.");
  }

  // Zero-initialised so that no indeterminate values are copied to the device.
  TensorDesc ts_permute = {};
  memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));

  TensorDesc ts_src_stride = {};
  TensorDesc ts_dst_stride = {};
  ts_src_stride.dim = src_dim;
  ts_dst_stride.dim = src_dim;
  int *src_stride = &(ts_src_stride.stride[0]);
  int *dst_stride = &(ts_dst_stride.stride[0]);
  int *dst_size = &(ts_dst_stride.shape[0]);
  src_stride[src_dim - 1] = 1;
  dst_stride[src_dim - 1] = 1;

  // Destination extents and source strides, innermost dimension first.
  for (int i = src_dim - 1; i >= 0; --i) {
    dst_size[i] = src_size[permute[i]];
    if (i < src_dim - 1) {
      src_stride[i] = src_stride[i + 1] * src_size[i + 1];
    }
  }

  // Destination strides and total element count.
  size_t copy_size = 1;
  for (int i = src_dim - 1; i >= 0; --i) {
    copy_size *= dst_size[i];
    if (i < src_dim - 1) {
      dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
    }
  }

  // Nothing to copy for an empty tensor; skipping the launch also avoids
  // requesting a grid of zero blocks.
  if (copy_size == 0) return;

  copy_permute_kernel<scalar_t>
      <<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
          dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute);
}
|
||||
|
||||
template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
|
||||
int *permute, int src_dim,
|
||||
cudaStream_t stream);
|
4
backend_ops/tensorrt/scatternd/CMakeLists.txt
Normal file
4
backend_ops/tensorrt/scatternd/CMakeLists.txt
Normal file
@ -0,0 +1,4 @@
|
||||
# Gather the plugin's sources and headers from this directory and append them
# to BACKEND_OPS_SRCS.
file(GLOB OPS_SRCS *.cpp *.cu)
file(GLOB OPS_HEADS *.h *.hpp *.cuh)
list(APPEND BACKEND_OPS_SRCS ${OPS_SRCS} ${OPS_HEADS})

# Publish the updated list to the parent directory scope.
set(BACKEND_OPS_SRCS ${BACKEND_OPS_SRCS} PARENT_SCOPE)
|
206
backend_ops/tensorrt/scatternd/trt_scatternd.cpp
Normal file
206
backend_ops/tensorrt/scatternd/trt_scatternd.cpp
Normal file
@ -0,0 +1,206 @@
|
||||
#include "trt_scatternd.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
extern void TRTONNXScatterNDKernelLauncher_float(
|
||||
const float *data, const int *indices, const float *update, const int *dims,
|
||||
int nbDims, const int *indices_dims, int indice_nbDims, float *output,
|
||||
cudaStream_t stream);
|
||||
|
||||
extern void TRTONNXScatterNDKernelLauncher_int32(
|
||||
const int *data, const int *indices, const int *update, const int *dims,
|
||||
int nbDims, const int *indices_dims, int indice_nbDims, int *output,
|
||||
cudaStream_t stream);
|
||||
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"ScatterND"};
|
||||
} // namespace
|
||||
|
||||
nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField>
|
||||
ONNXScatterNDDynamicCreator::mPluginAttributes;
|
||||
|
||||
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name)
|
||||
: mLayerName(name) {}
|
||||
|
||||
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name,
|
||||
const void *data, size_t length)
|
||||
: mLayerName(name) {}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const {
|
||||
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(mLayerName);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
// The output has exactly the dimensions of the first input.
nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  const nvinfer1::DimsExprs &data_dims = inputs[0];
  return data_dims;
}
|
||||
|
||||
bool ONNXScatterNDDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
if (pos < nbInputs) {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
// data
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
|
||||
(inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
case 1:
|
||||
// indices
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
case 2:
|
||||
// updates
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
switch (pos - nbInputs) {
|
||||
case 0:
|
||||
// output
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamic::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
|
||||
size_t ONNXScatterNDDynamic::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
 * Runs the scatter on `stream`.
 *
 * inputs[0] is the data tensor, inputs[1] the int32 indices and inputs[2]
 * the updates; outputs[0] receives the result. The launcher matching the
 * element type of the data tensor (float or int32) is called.
 *
 * @return 0 on success, non-zero if the data tensor has an element type for
 *         which no launcher exists (nothing is written in that case)
 */
int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                  const nvinfer1::PluginTensorDesc *outputDesc,
                                  const void *const *inputs,
                                  void *const *outputs, void *workSpace,
                                  cudaStream_t stream) {
  const int *dims = &(inputDesc[0].dims.d[0]);
  const int *indices_dims = &(inputDesc[1].dims.d[0]);
  const int nbDims = inputDesc[0].dims.nbDims;
  const int indice_nbDims = inputDesc[1].dims.nbDims;

  const void *data = inputs[0];
  const void *indices = inputs[1];
  const void *update = inputs[2];
  void *output = outputs[0];

  switch (inputDesc[0].type) {
    case nvinfer1::DataType::kFLOAT:
      TRTONNXScatterNDKernelLauncher_float(
          static_cast<const float *>(data), static_cast<const int *>(indices),
          static_cast<const float *>(update), dims, nbDims, indices_dims,
          indice_nbDims, static_cast<float *>(output), stream);
      break;

    case nvinfer1::DataType::kINT32:
      TRTONNXScatterNDKernelLauncher_int32(
          static_cast<const int *>(data), static_cast<const int *>(indices),
          static_cast<const int *>(update), dims, nbDims, indices_dims,
          indice_nbDims, static_cast<int *>(output), stream);
      break;

    default:
      // No kernel exists for this type. Report the failure rather than
      // returning success with the output left unwritten.
      return 1;
  }

  return 0;
}
|
||||
|
||||
nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; }
|
||||
|
||||
const char *ONNXScatterNDDynamic::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int ONNXScatterNDDynamic::getNbOutputs() const { return 1; }
|
||||
|
||||
int ONNXScatterNDDynamic::initialize() { return 0; }
|
||||
|
||||
void ONNXScatterNDDynamic::terminate() {}
|
||||
|
||||
size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; }
|
||||
|
||||
void ONNXScatterNDDynamic::serialize(void *buffer) const {}
|
||||
|
||||
void ONNXScatterNDDynamic::destroy() {
|
||||
// This gets called when the network containing plugin is destroyed
|
||||
delete this;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamic::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
////////////////////// creator /////////////////////////////
|
||||
|
||||
ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() {
|
||||
mPluginAttributes.clear();
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginName() const {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
const nvinfer1::PluginFieldCollection *
|
||||
ONNXScatterNDDynamicCreator::getFieldNames() {
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
98
backend_ops/tensorrt/scatternd/trt_scatternd.hpp
Normal file
98
backend_ops/tensorrt/scatternd/trt_scatternd.hpp
Normal file
@ -0,0 +1,98 @@
|
||||
#ifndef TRT_SCATTERND_HPP
|
||||
#define TRT_SCATTERND_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
class ONNXScatterNDDynamic : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
ONNXScatterNDDynamic(const std::string &name);
|
||||
|
||||
ONNXScatterNDDynamic(const std::string name, const void *data, size_t length);
|
||||
|
||||
ONNXScatterNDDynamic() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
void destroy() override;
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
|
||||
using nvinfer1::IPluginV2DynamicExt::enqueue;
|
||||
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
};
|
||||
|
||||
class ONNXScatterNDDynamicCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
ONNXScatterNDDynamicCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
std::string mNamespace;
|
||||
};
|
||||
#endif // TRT_SCATTERND_HPP
|
92
backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu
Normal file
92
backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu
Normal file
@ -0,0 +1,92 @@
|
||||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_cuda_helper.cuh"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
static int const threadsPerBlock = sizeof(unsigned long long int) * 8;
|
||||
|
||||
using mmlab::TensorDesc;
|
||||
|
||||
/**
 * Each loop iteration handles one index tuple: it turns the tuple into an
 * element offset in `output` using the data tensor's strides and copies one
 * contiguous slice of `update` to that offset.
 *
 * @param n           number of index tuples
 * @param indices     index tuples, stored one after another
 * @param update      update slices, stored one after another
 * @param output      tensor that receives the slices
 * @param tensor_desc supplies the strides of the data/output tensor
 * @param indice_desc supplies the shape of the indices tensor; its last
 *                    extent is the length of one index tuple
 */
template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
                                      const T* update, T* output,
                                      TensorDesc tensor_desc,
                                      TensorDesc indice_desc) {
  const int tuple_len = indice_desc.shape[indice_desc.dim - 1];
  // Number of elements in the slice addressed by one index tuple.
  const int slice_len = tensor_desc.stride[tuple_len - 1];
  const int* strides = tensor_desc.stride;

  CUDA_1D_KERNEL_LOOP(index, n) {
    const int* tuple = indices + index * tuple_len;
    int dst_offset = 0;
    for (int axis = 0; axis < tuple_len; ++axis) {
      dst_offset += strides[axis] * tuple[axis];
    }
    memcpy(output + dst_offset, update + index * slice_len,
           slice_len * sizeof(T));
  }
}
|
||||
|
||||
/**
 * Computes output = copy of data with the slices selected by `indices`
 * overwritten by `update`. All device work is issued on `stream`.
 *
 * @param data          input tensor
 * @param indices       index tuples; the last dimension of the indices
 *                      tensor is the tuple length
 * @param update        update slices
 * @param dims          extents of `data`, `nbDims` entries
 * @param nbDims        number of dimensions of `data`
 * @param indices_dims  extents of `indices`, `indice_nbDims` entries
 * @param indice_nbDims number of dimensions of `indices`
 * @param output        result tensor, same size as `data`
 * @param stream        CUDA stream for the copy and the kernel
 */
template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
                                    const T* update, const int* dims,
                                    int nbDims, const int* indices_dims,
                                    int indice_nbDims, T* output,
                                    cudaStream_t stream) {
  // Describe the data tensor: extents plus contiguous strides.
  TensorDesc tensor_desc;
  memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
  tensor_desc.dim = nbDims;
  tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
  tensor_desc.stride[nbDims - 1] = 1;
  for (int i = nbDims - 2; i >= 0; --i) {
    tensor_desc.shape[i] = dims[i];
    tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
  }
  const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];

  // Describe the indices tensor the same way.
  TensorDesc indice_desc;
  memset((void*)&indice_desc, 0, sizeof(TensorDesc));
  indice_desc.dim = indice_nbDims;
  indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
  indice_desc.stride[indice_nbDims - 1] = 1;
  for (int i = indice_nbDims - 2; i >= 0; --i) {
    indice_desc.shape[i] = indices_dims[i];
    indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
  }

  // output = np.copy(data)
  // Issued on `stream` so that it is ordered before the scatter kernel, which
  // is launched on the same stream and overwrites parts of `output`.
  cudaMemcpyAsync(output, data, data_size * sizeof(T),
                  cudaMemcpyDeviceToDevice, stream);

  // One update per index tuple: the product of all but the last extent.
  int num_update_indice = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) {
    num_update_indice *= indice_desc.shape[i];
  }

  // With no updates the copy above is the whole result; returning here also
  // avoids launching a grid of zero blocks.
  if (num_update_indice <= 0) return;

  // scatter
  const int col_block = DIVUP(num_update_indice, threadsPerBlock);
  onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
      num_update_indice, indices, update, output, tensor_desc, indice_desc);
}
|
||||
|
||||
// Non-template entry point for float tensors; forwards every argument
// unchanged to the templated launcher.
void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices,
                                          const float* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, float* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<float>(data, indices, update, dims, nbDims,
                                        indices_dims, indice_nbDims, output,
                                        stream);
}

// Non-template entry point for int32 tensors; forwards every argument
// unchanged to the templated launcher.
void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices,
                                          const int* update, const int* dims,
                                          int nbDims, const int* indices_dims,
                                          int indice_nbDims, int* output,
                                          cudaStream_t stream) {
  TRTONNXScatterNDKernelLauncher<int>(data, indices, update, dims, nbDims,
                                      indices_dims, indice_nbDims, output,
                                      stream);
}
|
7
backend_ops/tensorrt/trt_plugin.cpp
Normal file
7
backend_ops/tensorrt/trt_plugin.cpp
Normal file
@ -0,0 +1,7 @@
|
||||
#include "scatternd/trt_scatternd.hpp"
|
||||
|
||||
// Invoke TensorRT's plugin-registration macro for the ScatterND plugin
// creator declared in scatternd/trt_scatternd.hpp.
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);

// Exported with C linkage so the symbol name is not C++-mangled.
// The function performs no work and always returns true.
// NOTE(review): presumably callers use it only to confirm the plugin
// library was loaded — confirm against the loader code.
extern "C" bool initLibMMCVInferPlugins() { return true; }
|
@ -34,7 +34,6 @@ class FuncCaller(object):
|
||||
|
||||
# caller decorator
|
||||
def register_rewriter(func_name, backend='default', **kwargs):
|
||||
|
||||
def wrap(func):
|
||||
func_args = dict(func_name=func_name, backend=backend, func=func)
|
||||
func_args.update(kwargs)
|
||||
@ -49,7 +48,6 @@ FUNCTION_REWRITERS.register_rewriter = register_rewriter
|
||||
|
||||
|
||||
def apply_rewriter(regist_func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return regist_func(*args, **kwargs)
|
||||
|
||||
@ -57,13 +55,14 @@ def apply_rewriter(regist_func):
|
||||
|
||||
|
||||
class RewriterHook(object):
|
||||
|
||||
def __init__(self, regist_name, cfg, **kwargs):
|
||||
func_name, backend = regist_name.split('@')
|
||||
self.func_name = func_name
|
||||
self.backend = backend
|
||||
self.regist_func = FUNCTION_REWRITERS.build(
|
||||
func_name, backend=self.backend, cfg=cfg, **kwargs)
|
||||
self.regist_func = FUNCTION_REWRITERS.build(func_name,
|
||||
backend=self.backend,
|
||||
cfg=cfg,
|
||||
**kwargs)
|
||||
try:
|
||||
self.origin_func = eval_with_import(self.func_name)
|
||||
except Exception:
|
||||
@ -93,7 +92,6 @@ class RewriterHook(object):
|
||||
|
||||
|
||||
class RewriterContext(object):
|
||||
|
||||
def __init__(self, cfg, backend='default', **kwargs):
|
||||
self.cfg = cfg
|
||||
func_backend_dict = {}
|
||||
|
@ -1,7 +1,9 @@
|
||||
from mmcv.utils import Registry
|
||||
from .register_utils import eval_with_import
|
||||
from copy import deepcopy
|
||||
|
||||
from mmcv.utils import Registry
|
||||
|
||||
from .register_utils import eval_with_import
|
||||
|
||||
|
||||
def build_rewrite_module(module, cfg, backend, registry, **kwargs):
|
||||
|
||||
@ -22,7 +24,6 @@ def build_rewrite_module(module, cfg, backend, registry, **kwargs):
|
||||
|
||||
|
||||
class RewriteModuleRegistry(Registry):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._module_eval_dict = dict()
|
||||
@ -54,16 +55,18 @@ class RewriteModuleRegistry(Registry):
|
||||
|
||||
|
||||
# create register
|
||||
MODULE_REWRITERS = RewriteModuleRegistry(
|
||||
'module_rewriters', build_func=build_rewrite_module, scope='.')
|
||||
MODULE_REWRITERS = RewriteModuleRegistry('module_rewriters',
|
||||
build_func=build_rewrite_module,
|
||||
scope='.')
|
||||
|
||||
|
||||
def patch_model(model, cfg, backend='default', **kwargs):
|
||||
|
||||
def _patch_impl(model, cfg, **kwargs):
|
||||
for name, module in model.named_children():
|
||||
model._modules[name] = _patch_impl(module, cfg, **kwargs)
|
||||
return MODULE_REWRITERS.build(
|
||||
module=model, cfg=cfg, backend=backend, **kwargs)
|
||||
return MODULE_REWRITERS.build(module=model,
|
||||
cfg=cfg,
|
||||
backend=backend,
|
||||
**kwargs)
|
||||
|
||||
return _patch_impl(deepcopy(model), cfg, **kwargs)
|
||||
|
@ -64,13 +64,11 @@ def register_symbolic(func_name,
|
||||
is_pytorch=False,
|
||||
arg_descriptors=None,
|
||||
**kwargs):
|
||||
|
||||
def wrapper(symbolic_impl):
|
||||
symbolic_args = dict(
|
||||
func_name=func_name,
|
||||
backend=backend,
|
||||
symbolic=symbolic_impl,
|
||||
arg_descriptors=arg_descriptors)
|
||||
symbolic_args = dict(func_name=func_name,
|
||||
backend=backend,
|
||||
symbolic=symbolic_impl,
|
||||
arg_descriptors=arg_descriptors)
|
||||
symbolic_args.update(kwargs)
|
||||
wrapper_name = '@'.join([func_name, backend, str(is_pytorch)])
|
||||
wrapper = type(wrapper_name, (SymbolicWrapper, ), symbolic_args)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def test_function_rewriter():
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, RewriterContext
|
||||
@ -8,8 +9,8 @@ def test_function_rewriter():
|
||||
x = torch.tensor([1, 2, 3, 4, 5])
|
||||
y = torch.tensor([2, 4, 6, 8, 10])
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.add', backend='tensorrt')
|
||||
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.add',
|
||||
backend='tensorrt')
|
||||
def sub_func(rewriter, x, y):
|
||||
return x - y
|
||||
|
||||
@ -28,8 +29,8 @@ def test_function_rewriter():
|
||||
# replace should not happen with wrong backend
|
||||
torch.testing.assert_allclose(result, x + y)
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='torch.Tensor.add', backend='default')
|
||||
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.add',
|
||||
backend='default')
|
||||
def mul_func_class(rewriter, x, y):
|
||||
return x * y
|
||||
|
||||
@ -55,7 +56,6 @@ def test_module_rewriter():
|
||||
@MODULE_REWRITERS.register_rewrite_module(
|
||||
module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt')
|
||||
class BottleneckWrapper(torch.nn.Module):
|
||||
|
||||
def __init__(self, module, cfg, **kwargs):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
@ -91,7 +91,6 @@ def test_symbolic_register():
|
||||
import onnx
|
||||
|
||||
class TestFunc(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, x, val):
|
||||
return g.op('mmcv::symbolic_old', x, val_i=val)
|
||||
@ -109,18 +108,18 @@ def test_symbolic_register():
|
||||
def symbolic_testfunc_default(symbolic_wrapper, g, x, val):
|
||||
return g.op('mmcv::symbolic_testfunc_default', x, val_i=val)
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic(
|
||||
'mmdeploy.TestFunc', backend='tensorrt')
|
||||
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.TestFunc',
|
||||
backend='tensorrt')
|
||||
def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val):
|
||||
return g.op('mmcv::symbolic_testfunc_tensorrt', x, val_i=val)
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic(
|
||||
'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
|
||||
@SYMBOLICS_REGISTER.register_symbolic('cummax',
|
||||
is_pytorch=True,
|
||||
arg_descriptors=['v', 'i'])
|
||||
def symbolic_cummax(symbolic_wrapper, g, input, dim):
|
||||
return g.op('mmcv::cummax_default', input, dim_i=dim, outputs=2)
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user