From 66d5cddbdc8c85bd46a8fefa5733007d8573f4c4 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 1 Dec 2021 16:31:10 +0800 Subject: [PATCH] [Enhancement] Add bicubic resize plugin for tensorrt (#238) * save codes * enable export fake bicubic interpolate op to onnx * save codes * enable bicubic interpolate trt plugin * static export * enable visualize but need align acc * use torch bicubic upsample * add unit tests for bicubic interpolate * fix unit tests * change mmedit config * remove useless comments * remove useless comments * resolve comments * fix lint * clang-format Co-authored-by: grimoire --- .../trt_bicubic_interpolate.cpp | 203 ++++++++++++++++++ .../trt_bicubic_interpolate.hpp | 76 +++++++ .../trt_bicubic_interpolate_kernel.cu | 181 ++++++++++++++++ .../trt_bicubic_interpolate_kernel.hpp | 12 ++ ...solution_tensorrt_dynamic-32x32-512x512.py | 18 +- ...ion_tensorrt_fp16_dynamic-32x32-512x512.py | 18 +- ...resolution_tensorrt_fp16_static-256x256.py | 18 +- ...ion_tensorrt_int8_dynamic-32x32-512x512.py | 18 +- ...resolution_tensorrt_int8_static-256x256.py | 18 +- ...uper-resolution_tensorrt_static-256x256.py | 18 +- mmdeploy/codebase/mmedit/__init__.py | 1 - mmdeploy/codebase/mmedit/models/__init__.py | 2 - .../mmedit/models/backbones/__init__.py | 4 - .../codebase/mmedit/models/backbones/srcnn.py | 50 ----- mmdeploy/pytorch/functions/__init__.py | 6 +- mmdeploy/pytorch/functions/interpolate.py | 65 ++++++ tests/test_ops/test_ops.py | 55 +++++ 17 files changed, 655 insertions(+), 108 deletions(-) create mode 100644 backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp create mode 100644 backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp create mode 100644 backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu create mode 100644 backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp delete mode 100644 mmdeploy/codebase/mmedit/models/__init__.py delete mode 100644 mmdeploy/codebase/mmedit/models/backbones/__init__.py delete mode 100644 mmdeploy/codebase/mmedit/models/backbones/srcnn.py diff --git a/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp new file mode 100644 index 000000000..08264720f --- /dev/null +++ b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp @@ -0,0 +1,203 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "trt_bicubic_interpolate.hpp" + +#include + +#include + +#include "trt_bicubic_interpolate_kernel.hpp" +#include "trt_plugin_helper.hpp" +#include "trt_serialize.hpp" +using namespace nvinfer1; + +namespace mmdeploy { +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"TRTBicubicInterpolate"}; +} // namespace + +TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name, + std::vector scale_factor, + bool align_corners) + : TRTPluginBase(name), + mScaleFactor(scale_factor), + mAlignCorners(align_corners) {} + +TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name, + const void *data, size_t length) + : TRTPluginBase(name) { + deserialize_value(&data, &length, &mScaleFactor); + deserialize_value(&data, &length, &mAlignCorners); +} + +nvinfer1::IPluginV2DynamicExt *TRTBicubicInterpolate::clone() const + TRT_NOEXCEPT { + TRTBicubicInterpolate *plugin = + new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + auto height = exprBuilder.constant(mScaleFactor[0]); + auto width = exprBuilder.constant(mScaleFactor[1]); + auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], + *height); + auto d3 = + exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width); + ret.d[2] = d2; + ret.d[3] = d3; + + return ret; +} + +bool TRTBicubicInterpolate::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + if (pos == 0) { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + + } else { + return ioDesc[pos].type == ioDesc[0].type && + ioDesc[pos].format == ioDesc[0].format; + } +} + +void TRTBicubicInterpolate::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, + int nbOutputs) TRT_NOEXCEPT {} + +size_t TRTBicubicInterpolate::getWorkspaceSize( + const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT { + return 0; +} + +int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, + void *const *outputs, void *workSpace, + cudaStream_t stream) TRT_NOEXCEPT { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + + int height_out = outputDesc[0].dims.d[2]; + int width_out = outputDesc[0].dims.d[3]; + const void *x = inputs[0]; + void *output = outputs[0]; + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + bicubic_interpolate((float *)x, (float *)output, batch, channels, + height, width, height_out, width_out, + mAlignCorners, stream); + break; + default: + return 1; + break; + } + + return 0; +} + +nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT { + 
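+  // The single output inherits the data type of input 0; supportsFormatCombination
+  // above restricts that to kFLOAT, and enqueue still carries a TODO for fp16.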
return inputTypes[0]; +} + +// IPluginV2 Methods +const char *TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT { + return PLUGIN_NAME; +} + +const char *TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT { return 1; } + +size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT { + return serialized_size(mScaleFactor) + serialized_size(mAlignCorners); +} + +void TRTBicubicInterpolate::serialize(void *buffer) const TRT_NOEXCEPT { + serialize_value(&buffer, mScaleFactor); + serialize_value(&buffer, mAlignCorners); +} + +////////////////////// creator ///////////////////////////// + +TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT { + return PLUGIN_NAME; +} + +const char *TRTBicubicInterpolateCreator::getPluginVersion() const + TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { + nvinfer1::Dims size{2, {1, 1}}; + std::vector scale_factor; + bool align_corners = 1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("scale_factor") == 0) { + int data_size = (fc->fields[i].length); + if (data_size != 2) { + data_size = data_size / sizeof(float); + } + ASSERT(data_size == 2) + const float *data_start = static_cast(fc->fields[i].data); + scale_factor = std::vector(data_start, data_start + data_size); + } + + if (field_name.compare("align_corners") == 0) { + align_corners = static_cast(fc->fields[i].data)[0]; + } + } + + TRTBicubicInterpolate *plugin = + new TRTBicubicInterpolate(name, scale_factor, align_corners); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::deserializePlugin( + const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT { + auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} +REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator); +} // namespace mmdeploy diff --git a/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp new file mode 100644 index 000000000..f78560485 --- /dev/null +++ b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp @@ -0,0 +1,76 @@ +#ifndef TRT_BICUBIC_INTERPOLATE_HPP +#define TRT_BICUBIC_INTERPOLATE_HPP +#include + +#include +#include +#include + +#include "trt_plugin_base.hpp" +namespace mmdeploy { +class TRTBicubicInterpolate : public TRTPluginBase { + public: + TRTBicubicInterpolate(const std::string &name, + std::vector scale_factor, bool align_corners); + + TRTBicubicInterpolate(const std::string name, const void *data, + size_t length); + + TRTBicubicInterpolate() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions( + int 
outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc *ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char *getPluginType() const TRT_NOEXCEPT override; + const char *getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void *buffer) const TRT_NOEXCEPT override; + + private: + std::vector mScaleFactor; + bool mAlignCorners; +}; + +class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase { + public: + TRTBicubicInterpolateCreator(); + + const char *getPluginName() const TRT_NOEXCEPT override; + + const char *getPluginVersion() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2 *createPlugin(const char *name, + const nvinfer1::PluginFieldCollection *fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *deserializePlugin( + const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT override; +}; +} // namespace mmdeploy +#endif // TRT_BICUBIC_INTERPOLATE_HPP diff --git a/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu new file mode 100644 index 000000000..6fe6f3b2c --- /dev/null +++ b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu @@ -0,0 +1,181 @@ +// Modified from +// https://github.com/pytorch/pytorch/blob/6adbe044e39c8e8db158d91e151aa6dead6e9aa4/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu +#include +#include + +#include +#include +#include + +#include "common_cuda_helper.hpp" +#include "trt_bicubic_interpolate_kernel.hpp" + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x, + scalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x, + scalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +__device__ __forceinline__ static void get_cubic_upsample_coefficients( + scalar_t coeffs[4], scalar_t t) { + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +__device__ __forceinline__ static scalar_t cubic_interp1d( + scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) { + scalar_t coeffs[4]; + 
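+  // Fill the four Keys cubic-convolution weights (A = -0.75) for the fractional
+  // offset t, then blend the four neighbouring samples x0..x3 below.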
get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ static scalar_t upsample_get_value_bounded( + const scalar_t *data, int batch, int channel, int batchsize, int channels, + int height, int width, int y, int x) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch * channels * height * width + channel * height * width + + access_y * width + access_x]; +} + +template +__device__ __forceinline__ scalar_t area_pixel_compute_source_index( + scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; + // [Note] Follow Opencv resize logic: + // We allow negative src_idx here and later will use + // dx = src_idx - floorf(src_idx) + // to compute the "distance"(which affects weights). + // For linear modes, weight distribution doesn't matter + // for negative indices as they use 2 pixels to interpolate. + // For example, [-1, 0], they both use pixel 0 value so it + // doesn't affect if we bound the src_idx to 0 or not. + // TODO: Our current linear mode impls use unbound indices + // where we should and then remove this cubic flag. + // This matters in cubic mode, as we might need [-1, 0, 1, 2] + // to interpolate and the weights can be affected. + return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx; + } +} + +// cubic interpolation pytorch +template +__global__ void resize_cubic_kernel_torch( + const int num_elements, const scalar_t *src, const int batchsize, + const int channels, int srcWidth, int srcHeight, scalar_t *dst, + int dstWidth, int dstHeight, bool align_corners, float height_scale, + float width_scale) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index >= num_elements) { + return; + } + // Special case: input and output are the same size, just copy + const int output_x = index % dstWidth; + const int output_y = index / dstWidth; + + if (srcHeight == dstHeight && srcWidth == dstWidth) { + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; c++) { + const scalar_t val = + src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + output_x]; + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + output_x] = val; + } + } + return; + } + // Interpolation kernel + scalar_t real_x = area_pixel_compute_source_index( + width_scale, output_x, align_corners, /*cubic=*/true); + int in_x = floorf(real_x); + scalar_t t_x = real_x - in_x; + + scalar_t real_y = area_pixel_compute_source_index( + height_scale, output_y, align_corners, /*cubic=*/true); + int in_y = floorf(real_y); + scalar_t t_y = real_y - in_y; + + for (int n = 0; n < batchsize; n++) { + for (int c = 0; c < channels; c++) { + scalar_t coefficients[4]; + + for (int k = 0; k < 4; k++) { + coefficients[k] = cubic_interp1d( + upsample_get_value_bounded(src, n, c, batchsize, channels, + srcHeight, srcWidth, in_y - 1 + k, + in_x - 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, + srcHeight, srcWidth, in_y - 1 + k, + in_x + 0), + upsample_get_value_bounded(src, n, c, batchsize, channels, + srcHeight, srcWidth, in_y - 1 + k, + in_x + 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, + srcHeight, srcWidth, in_y - 1 + k, + in_x + 2), + t_x); + } + + dst[n * channels * dstHeight 
* dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + output_x] = + scalar_t(cubic_interp1d(coefficients[0], coefficients[1], + coefficients[2], coefficients[3], t_y)); + } + } +} + +template +void resizeGPU(const scalar_t *pIn_d, scalar_t *pOut_d, int batch, int channels, + int srcWidth, int srcHeight, int dstWidth, int dstHeight, + bool align_corners, cudaStream_t stream) { + float height_scale = float(srcHeight) / dstHeight; + float width_scale = float(srcWidth) / dstWidth; + if (align_corners && dstWidth > 1 && dstHeight > 1) { + height_scale = (float)(srcHeight - 1) / (dstHeight - 1); + width_scale = (float)(srcWidth - 1) / (dstWidth - 1); + } + int n = batch * dstWidth * dstHeight * channels; + resize_cubic_kernel_torch<<>>( + dstWidth * dstHeight, pIn_d, batch, channels, srcWidth, srcHeight, pOut_d, + dstWidth, dstHeight, align_corners, height_scale, width_scale); +} + +template +void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, + int channels, int in_height, int in_width, + int out_height, int out_width, bool align_corners, + cudaStream_t stream) { + resizeGPU(input, output, batch, channels, in_width, in_height, out_width, + out_height, align_corners, stream); +} + +template void bicubic_interpolate(const float *input, float *output, + int batch, int channels, int in_height, + int in_width, int out_height, + int out_width, bool align_corners, + cudaStream_t stream); diff --git a/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp new file mode 100644 index 000000000..a2e54d042 --- /dev/null +++ b/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp @@ -0,0 +1,12 @@ +#ifndef TRT_BICUBIC_INTERPOLATE_KERNEL_HPP +#define TRT_BICUBIC_INTERPOLATE_KERNEL_HPP +#include + +#include "common_cuda_helper.hpp" + +template +void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, + int channels, int in_height, int in_width, + int out_height, int out_width, bool align_corners, + cudaStream_t stream); +#endif // TRT_BICUBIC_INTERPOLATE_KERNEL_HPP diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_dynamic-32x32-512x512.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_dynamic-32x32-512x512.py index 932cabf9e..66ae67ce8 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_dynamic-32x32-512x512.py +++ b/configs/mmedit/super-resolution/super-resolution_tensorrt_dynamic-32x32-512x512.py @@ -1,9 +1,11 @@ _base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/tensorrt.py'] -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 32, 32], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 512, 512]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 32, 32], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 512, 512]))) + ]) diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_dynamic-32x32-512x512.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_dynamic-32x32-512x512.py index faa9013b5..7a8bcee9a 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_dynamic-32x32-512x512.py +++ b/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_dynamic-32x32-512x512.py @@ -1,11 +1,13 @@ _base_ = [ './super-resolution_dynamic.py', 
'../../_base_/backends/tensorrt_fp16.py' ] -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 32, 32], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 512, 512]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 32, 32], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 512, 512]))) + ]) diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_static-256x256.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_static-256x256.py index d828a38e7..141d9b1b1 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_static-256x256.py +++ b/configs/mmedit/super-resolution/super-resolution_tensorrt_fp16_static-256x256.py @@ -2,11 +2,13 @@ _base_ = [ './super-resolution_static.py', '../../_base_/backends/tensorrt_fp16.py' ] onnx_config = dict(input_shape=[256, 256]) -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 256, 256], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 256, 256]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 256, 256], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 256, 256]))) + ]) diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_dynamic-32x32-512x512.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_dynamic-32x32-512x512.py index ddd786c5b..1e3cb433f 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_dynamic-32x32-512x512.py +++ b/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_dynamic-32x32-512x512.py @@ -1,11 +1,13 @@ _base_ = [ './super-resolution_dynamic.py', '../../_base_/backends/tensorrt_int8.py' ] -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 32, 32], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 512, 512]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 32, 32], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 512, 512]))) + ]) diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_static-256x256.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_static-256x256.py index d8220c10b..759820398 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_static-256x256.py +++ b/configs/mmedit/super-resolution/super-resolution_tensorrt_int8_static-256x256.py @@ -2,11 +2,13 @@ _base_ = [ './super-resolution_static.py', '../../_base_/backends/tensorrt_int8.py' ] onnx_config = dict(input_shape=[256, 256]) -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 256, 256], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 256, 256]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 256, 256], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 256, 256]))) + ]) diff --git a/configs/mmedit/super-resolution/super-resolution_tensorrt_static-256x256.py b/configs/mmedit/super-resolution/super-resolution_tensorrt_static-256x256.py index 2b1ac65be..79e141a42 100644 --- a/configs/mmedit/super-resolution/super-resolution_tensorrt_static-256x256.py +++ 
b/configs/mmedit/super-resolution/super-resolution_tensorrt_static-256x256.py @@ -1,10 +1,12 @@ _base_ = ['./super-resolution_static.py', '../../_base_/backends/tensorrt.py'] onnx_config = dict(input_shape=[256, 256]) -backend_config = dict(model_inputs=[ - dict( - input_shapes=dict( - input=dict( - min_shape=[1, 3, 256, 256], - opt_shape=[1, 3, 256, 256], - max_shape=[1, 3, 256, 256]))) -]) +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 256, 256], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 256, 256]))) + ]) diff --git a/mmdeploy/codebase/mmedit/__init__.py b/mmdeploy/codebase/mmedit/__init__.py index bd3cceb0c..55855b48d 100644 --- a/mmdeploy/codebase/mmedit/__init__.py +++ b/mmdeploy/codebase/mmedit/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import MMEditing, SuperResolution -from .models import * # noqa: F401,F403 __all__ = ['MMEditing', 'SuperResolution'] diff --git a/mmdeploy/codebase/mmedit/models/__init__.py b/mmdeploy/codebase/mmedit/models/__init__.py deleted file mode 100644 index 9bda05072..000000000 --- a/mmdeploy/codebase/mmedit/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .backbones import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmedit/models/backbones/__init__.py b/mmdeploy/codebase/mmedit/models/backbones/__init__.py deleted file mode 100644 index 51955a89f..000000000 --- a/mmdeploy/codebase/mmedit/models/backbones/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .srcnn import SRCNN__tensorrt - -__all__ = ['SRCNN__tensorrt'] diff --git a/mmdeploy/codebase/mmedit/models/backbones/srcnn.py b/mmdeploy/codebase/mmedit/models/backbones/srcnn.py deleted file mode 100644 index 4497fccfc..000000000 --- a/mmdeploy/codebase/mmedit/models/backbones/srcnn.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from mmdeploy.core import MODULE_REWRITER - - -@MODULE_REWRITER.register_rewrite_module( - 'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt') -class SRCNN__tensorrt(nn.Module): - """Rewrite `SRCNN` for tensorrt backend. - - SRCNN has three conv layers. For each layer, we can define the - `in_channels`, `out_channels` and `kernel_size`.The input image will - first be upsampled with a bicubic upsampler, and then super-resolved - in the HR spatial size. - Because TensorRT doesn't support bicubic operator, when deployment we use - bilinear instead. According to the experiments, the precision may decrease - about 4%. - Paper: Learning a Deep Convolutional Network for Image Super-Resolution. - - Args: - module (nn.Module): Source SRCNN module. - channels (tuple[int]): A tuple of channel numbers for each layer - including channels of input and output . Default: (3, 64, 32, 3). - kernel_sizes (tuple[int]): A tuple of kernel sizes for each conv layer. - Default: (9, 1, 5). - upscale_factor (int): Upsampling factor. Default: 4. 
- """ - - def __init__(self, - module, - channels=(3, 64, 32, 3), - kernel_sizes=(9, 1, 5), - upscale_factor=4): - super(SRCNN__tensorrt, self).__init__() - - self._module = module - - module.img_upsampler = nn.Upsample( - scale_factor=module.upscale_factor, - mode='bilinear', - align_corners=False) - - def forward(self, *args, **kwargs): - """Run forward.""" - return self._module(*args, **kwargs) - - def init_weights(self, *args, **kwargs): - """Initialize weights.""" - return self._module.init_weights(*args, **kwargs) diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index d13c94ac4..5baf25b24 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .getattribute import tensor__getattribute__ncnn from .group_norm import group_norm__ncnn -from .interpolate import interpolate__ncnn +from .interpolate import interpolate__ncnn, interpolate__tensorrt from .linear import linear__ncnn from .repeat import tensor__repeat__tensorrt from .size import tensor__size__ncnn @@ -9,6 +9,6 @@ from .topk import topk__dynamic, topk__tensorrt __all__ = [ 'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn', - 'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn', - 'topk__dynamic', 'topk__tensorrt' + 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', + 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt' ] diff --git a/mmdeploy/pytorch/functions/interpolate.py b/mmdeploy/pytorch/functions/interpolate.py index 77663b4fb..452fb481e 100644 --- a/mmdeploy/pytorch/functions/interpolate.py +++ b/mmdeploy/pytorch/functions/interpolate.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +import logging from typing import Optional, Tuple, Union import torch +from torch.autograd import Function from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils.constants import Backend @FUNCTION_REWRITER.register_rewriter( @@ -36,3 +39,65 @@ def interpolate__ncnn(ctx, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) + + +@FUNCTION_REWRITER.register_rewriter( + 'torch.nn.functional.interpolate', + is_pytorch=True, + backend=Backend.TENSORRT.value) +def interpolate__tensorrt( + ctx, + input: torch.Tensor, + size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, + int]]] = None, + scale_factor: Optional[Union[float, Tuple[float]]] = None, + mode: str = 'bilinear', + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, +): + """Register default symbolic function for `interpolate`.""" + + class BicubicInterpolate(Function): + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def symbolic(g, input, scale_factor, align_corners): + """Symbolic function for creating onnx op.""" + return g.op( + 'mmdeploy::TRTBicubicInterpolate', + input, + scale_factor_f=scale_factor, + align_corners_i=align_corners) + + @staticmethod + def forward(g, input, scale_factor, align_corners): + """Run forward.""" + return ctx.origin_func( + input, + scale_factor=scale_factor, + mode='bicubic', + align_corners=align_corners) + + if 'bicubic' == mode: + input_size = input.shape + if isinstance(scale_factor, float): + scale_factor = [scale_factor, scale_factor] + if scale_factor is None: + logging.warning( + 'ResizeLayer in TensorRT allow dynamic input shape with shape ' + 'tensor. Which is not available for custom ops. 
Computed scale' + '_factor might be the right way to get final shape.') + scale_factor = [ + s_out / s_in for s_out, s_in in zip(size, input_size[2:]) + ] + return BicubicInterpolate.apply(input, scale_factor, align_corners) + else: + return ctx.origin_func( + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor) diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index b51d04e0a..ed40c8e04 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -90,6 +90,61 @@ def test_grid_sample(backend, save_dir=save_dir) +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize('dynamic_export', [True, False]) +@pytest.mark.parametrize('mode', ['bicubic', 'nearest']) +@pytest.mark.parametrize('align_corners', [True, False]) +@pytest.mark.parametrize('scale_factor', [2, 4]) +@pytest.mark.parametrize('n, c, h, w', [(2, 3, 5, 10)]) +def test_bicubic_interpolate(backend, + dynamic_export, + mode, + align_corners, + scale_factor, + n, + c, + h, + w, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + input = torch.randn(n, c, h, w) + if dynamic_export: + dynamic_axes = { + 'input': { + 0: 'n', + 2: 'h', + 3: 'w', + }, + 'output': { + 0: 'n', + 2: 'h', + 3: 'w', + }, + } + else: + dynamic_axes = None + + if mode == 'nearest': + align_corners = None + resize = nn.Upsample( + scale_factor=scale_factor, mode=mode, align_corners=align_corners) + expected_result = resize(input).cuda() + wrapped_model = WrapFunction(resize).eval() + + with RewriterContext(cfg={}, backend=backend.backend_name, opset=11): + backend.run_and_validate( + wrapped_model, [input], + 'bicubic_interpolate', + input_names=['input'], + dynamic_axes=dynamic_axes, + output_names=['output'], + save_dir=save_dir, + expected_result=expected_result) + + @pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT]) @pytest.mark.parametrize('in_channels,out_channels,stride,padding,' 'dilation,groups,deform_groups,kernel_size',
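
For a quick numerical cross-check of the plugin outside TensorRT, the sketch below re-implements the same sampling scheme as resize_cubic_kernel_torch, namely Keys cubic convolution with A = -0.75, the half-pixel (or align_corners) source mapping from area_pixel_compute_source_index, and edge-clamped reads, in plain NumPy and compares it against torch.nn.functional.interpolate. It is a minimal reference under those assumptions, not part of the patch; the helper names ref_bicubic, cubic_weights, and source_index are hypothetical.

    # Minimal NumPy reference for the bicubic kernel above (assumed helper names).
    import numpy as np


    def cubic_weights(t, A=-0.75):
        # Keys cubic-convolution weights for fractional offset t, mirroring
        # get_cubic_upsample_coefficients: [k2(t+1), k1(t), k1(1-t), k2(2-t)].
        def k1(x):  # branch for |x| <= 1
            return ((A + 2) * x - (A + 3)) * x * x + 1

        def k2(x):  # branch for 1 < |x| < 2
            return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A

        return np.array([k2(t + 1.0), k1(t), k1(1.0 - t), k2(2.0 - t)])


    def source_index(scale, dst_index, align_corners):
        # area_pixel_compute_source_index with cubic=True: negative indices
        # are kept so the 4-tap window can reach past the border.
        if align_corners:
            return scale * dst_index
        return scale * (dst_index + 0.5) - 0.5


    def ref_bicubic(img, out_h, out_w, align_corners=False):
        # img: (N, C, H, W) float array -> (N, C, out_h, out_w).
        n, c, h, w = img.shape
        if align_corners and out_h > 1 and out_w > 1:
            hs, ws = (h - 1) / (out_h - 1), (w - 1) / (out_w - 1)
        else:
            hs, ws = h / out_h, w / out_w
        out = np.empty((n, c, out_h, out_w), dtype=img.dtype)
        for oy in range(out_h):
            ry = source_index(hs, oy, align_corners)
            iy = int(np.floor(ry))
            wy = cubic_weights(ry - iy)
            ys = np.clip(np.arange(iy - 1, iy + 3), 0, h - 1)  # edge clamp
            for ox in range(out_w):
                rx = source_index(ws, ox, align_corners)
                ix = int(np.floor(rx))
                wx = cubic_weights(rx - ix)
                xs = np.clip(np.arange(ix - 1, ix + 3), 0, w - 1)
                patch = img[:, :, ys[:, None], xs[None, :]]  # (N, C, 4, 4)
                out[:, :, oy, ox] = np.einsum('abij,i,j->ab', patch, wy, wx)
        return out


    if __name__ == '__main__':
        import torch
        import torch.nn.functional as F
        x = np.random.rand(2, 3, 5, 10).astype(np.float32)
        got = ref_bicubic(x, 10, 20, align_corners=False)
        want = F.interpolate(
            torch.from_numpy(x), scale_factor=2, mode='bicubic',
            align_corners=False).numpy()
        print(abs(got - want).max())  # expect ~1e-6 if the logic matches

Because the CUDA kernel in this patch is itself adapted from PyTorch's UpSampleBicubic2d.cu, agreement with torch's bicubic output (up to float32 rounding) is the same property test_bicubic_interpolate checks end to end through the TensorRT engine.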