[Enhancement] Add bicubic resize plugin for TensorRT (#238)

* save codes

* enable export fake bicubic interpolate op to onnx

* save codes

* enable bicubic interpolate trt plugin

* static export

* enable visualize but need align acc

* use torch bicubic upsample

* add unit tests for bicubic interpolate

* fix unit tests

* change mmedit config

* remove useless comments

* remove useless comments

* resolve comments

* fix lint

* clang-format

Co-authored-by: grimoire <yaoqian@sensetime.com>
AllentDan 2021-12-01 16:31:10 +08:00 committed by GitHub
parent 3b97f64385
commit 66d5cddbdc
17 changed files with 655 additions and 108 deletions

@@ -0,0 +1,203 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_bicubic_interpolate.hpp"
#include <assert.h>
#include <chrono>
#include "trt_bicubic_interpolate_kernel.hpp"
#include "trt_plugin_helper.hpp"
#include "trt_serialize.hpp"
using namespace nvinfer1;
namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"TRTBicubicInterpolate"};
} // namespace
TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name,
std::vector<float> scale_factor,
bool align_corners)
: TRTPluginBase(name),
mScaleFactor(scale_factor),
mAlignCorners(align_corners) {}
TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name,
const void *data, size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &mScaleFactor);
deserialize_value(&data, &length, &mAlignCorners);
}
nvinfer1::IPluginV2DynamicExt *TRTBicubicInterpolate::clone() const
TRT_NOEXCEPT {
TRTBicubicInterpolate *plugin =
new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
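// Note: IExprBuilder::constant() takes an int32_t, so a fractional scale
// factor is truncated when building these output shape expressions.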
auto height = exprBuilder.constant(mScaleFactor[0]);
auto width = exprBuilder.constant(mScaleFactor[1]);
auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2],
*height);
auto d3 =
exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width);
ret.d[2] = d2;
ret.d[3] = d3;
return ret;
}
bool TRTBicubicInterpolate::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
if (pos == 0) {
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return ioDesc[pos].type == ioDesc[0].type &&
ioDesc[pos].format == ioDesc[0].format;
}
}
void TRTBicubicInterpolate::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}
size_t TRTBicubicInterpolate::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}
int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs,
void *const *outputs, void *workSpace,
cudaStream_t stream) TRT_NOEXCEPT {
int batch = inputDesc[0].dims.d[0];
int channels = inputDesc[0].dims.d[1];
int height = inputDesc[0].dims.d[2];
int width = inputDesc[0].dims.d[3];
int height_out = outputDesc[0].dims.d[2];
int width_out = outputDesc[0].dims.d[3];
const void *x = inputs[0];
void *output = outputs[0];
// TODO: add fp16 support
auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
bicubic_interpolate<float>(static_cast<const float *>(x),
static_cast<float *>(output), batch, channels,
height, width, height_out, width_out,
mAlignCorners, stream);
break;
default:
return 1;
}
return 0;
}
nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}
// IPluginV2 Methods
const char *TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mScaleFactor) + serialized_size(mAlignCorners);
}
void TRTBicubicInterpolate::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mScaleFactor);
serialize_value(&buffer, mAlignCorners);
}
////////////////////// creator /////////////////////////////
TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *TRTBicubicInterpolateCreator::getPluginVersion() const
TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
std::vector<float> scale_factor;
bool align_corners = true;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("scale_factor") == 0) {
// `length` may be reported as the number of elements or as the number
// of bytes, depending on how the fields were populated; normalize it
// to an element count before checking.
int data_size = fc->fields[i].length;
if (data_size != 2) {
data_size = data_size / sizeof(float);
}
ASSERT(data_size == 2);
const float *data_start = static_cast<const float *>(fc->fields[i].data);
scale_factor = std::vector<float>(data_start, data_start + data_size);
}
if (field_name.compare("align_corners") == 0) {
align_corners = static_cast<const int *>(fc->fields[i].data)[0];
}
}
TRTBicubicInterpolate *plugin =
new TRTBicubicInterpolate(name, scale_factor, align_corners);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::deserializePlugin(
const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator);
} // namespace mmdeploy
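
For reference, the creator above parses the `scale_factor` and `align_corners` plugin fields that the ONNX exporter (see the `interpolate` rewriter later in this diff) attaches to the custom node. Below is a minimal sketch of that node built with the `onnx` Python package; the tensor names and shapes are illustrative, not part of this PR:

import onnx
from onnx import TensorProto, helper

# A hypothetical 1x3x32x32 input upsampled 2x in both spatial dimensions.
node = helper.make_node(
    'TRTBicubicInterpolate',
    inputs=['input'],
    outputs=['output'],
    domain='mmdeploy',
    scale_factor=[2.0, 2.0],  # read back by the creator as two floats
    align_corners=0)  # read back as an int and cast to bool

graph = helper.make_graph(
    [node], 'bicubic_example',
    [helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 32, 32])],
    [helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 3, 64, 64])])
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid('', 11),
                   helper.make_opsetid('mmdeploy', 1)])
onnx.save(model, 'bicubic_example.onnx')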

@@ -0,0 +1,76 @@
#ifndef TRT_BICUBIC_INTERPOLATE_HPP
#define TRT_BICUBIC_INTERPOLATE_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
class TRTBicubicInterpolate : public TRTPluginBase {
public:
TRTBicubicInterpolate(const std::string &name,
std::vector<float> scale_factor, bool align_corners);
TRTBicubicInterpolate(const std::string &name, const void *data,
size_t length);
TRTBicubicInterpolate() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *ioDesc,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
private:
std::vector<float> mScaleFactor;
bool mAlignCorners;
};
class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase {
public:
TRTBicubicInterpolateCreator();
const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name,
const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(
const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_BICUBIC_INTERPOLATE_HPP
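
Since the creator is registered through `REGISTER_TENSORRT_PLUGIN`, the plugin can also be instantiated from the TensorRT Python API once the mmdeploy ops library is loaded. A rough sketch; the library path is a placeholder and the field values are examples:

import ctypes

import numpy as np
import tensorrt as trt

ctypes.CDLL('path/to/the/mmdeploy/tensorrt/ops/library.so')  # placeholder

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, '')

registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator('TRTBicubicInterpolate', '1')
fields = trt.PluginFieldCollection([
    trt.PluginField('scale_factor', np.array([2.0, 2.0], dtype=np.float32),
                    trt.PluginFieldType.FLOAT32),
    trt.PluginField('align_corners', np.array([0], dtype=np.int32),
                    trt.PluginFieldType.INT32),
])
plugin = creator.create_plugin('bicubic', fields)
# The plugin can now be attached with network.add_plugin_v2([input], plugin).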

@@ -0,0 +1,181 @@
// Modified from
// https://github.com/pytorch/pytorch/blob/6adbe044e39c8e8db158d91e151aa6dead6e9aa4/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
#include <cuda_fp16.h>
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <vector>
#include "common_cuda_helper.hpp"
#include "trt_bicubic_interpolate_kernel.hpp"
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x,
scalar_t A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename scalar_t>
__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x,
scalar_t A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename scalar_t>
__device__ __forceinline__ static void get_cubic_upsample_coefficients(
scalar_t coeffs[4], scalar_t t) {
scalar_t A = -0.75;
scalar_t x1 = t;
coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
// opposite coefficients
scalar_t x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}
template <typename scalar_t>
__device__ __forceinline__ static scalar_t cubic_interp1d(
scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) {
scalar_t coeffs[4];
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
/* Used by UpSampleBicubic2d.cu */
template <typename scalar_t>
__device__ __forceinline__ static scalar_t upsample_get_value_bounded(
const scalar_t *data, int batch, int channel, int batchsize, int channels,
int height, int width, int y, int x) {
int access_y = max(min(y, height - 1), 0);
int access_x = max(min(x, width - 1), 0);
return data[batch * channels * height * width + channel * height * width +
access_y * width + access_x];
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t area_pixel_compute_source_index(
scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5;
// [Note] Follow Opencv resize logic:
// We allow negative src_idx here and later will use
// dx = src_idx - floorf(src_idx)
// to compute the "distance"(which affects weights).
// For linear modes, weight distribution doesn't matter
// for negative indices as they use 2 pixels to interpolate.
// For example, [-1, 0], they both use pixel 0 value so it
// doesn't affect if we bound the src_idx to 0 or not.
// TODO: Our current linear mode impls use unbound indices
// where we should and then remove this cubic flag.
// This matters in cubic mode, as we might need [-1, 0, 1, 2]
// to interpolate and the weights can be affected.
return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx;
}
}
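
The note above is easier to follow with the index math written out. Here is a small Python mirror of `area_pixel_compute_source_index` for the cubic case, as a sketch:

def source_index(scale, dst_index, align_corners, cubic=True):
    # Python mirror of area_pixel_compute_source_index above.
    if align_corners:
        return scale * dst_index
    src_idx = scale * (dst_index + 0.5) - 0.5
    # Cubic mode keeps negative indices so the 4-tap weights stay correct;
    # linear modes may clamp to 0 without changing the result.
    return src_idx if cubic or src_idx >= 0 else 0.0

# 2x upsampling (scale = 0.5) with align_corners=False: output pixel 0
# samples around source index -0.25, i.e. slightly outside the image.
assert source_index(0.5, 0, False) == -0.25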
// Bicubic resize kernel, following the PyTorch implementation
template <typename scalar_t>
__global__ void resize_cubic_kernel_torch(
const int num_elements, const scalar_t *src, const int batchsize,
const int channels, int srcWidth, int srcHeight, scalar_t *dst,
int dstWidth, int dstHeight, bool align_corners, float height_scale,
float width_scale) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index >= num_elements) {
return;
}
const int output_x = index % dstWidth;
const int output_y = index / dstWidth;
// Special case: input and output are the same size, just copy
if (srcHeight == dstHeight && srcWidth == dstWidth) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; c++) {
const scalar_t val =
src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
output_y * dstWidth + output_x];
dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
output_y * dstWidth + output_x] = val;
}
}
return;
}
// Interpolation kernel
scalar_t real_x = area_pixel_compute_source_index(
width_scale, output_x, align_corners, /*cubic=*/true);
int in_x = floorf(real_x);
scalar_t t_x = real_x - in_x;
scalar_t real_y = area_pixel_compute_source_index(
height_scale, output_y, align_corners, /*cubic=*/true);
int in_y = floorf(real_y);
scalar_t t_y = real_y - in_y;
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; c++) {
scalar_t coefficients[4];
for (int k = 0; k < 4; k++) {
coefficients[k] = cubic_interp1d<scalar_t>(
upsample_get_value_bounded(src, n, c, batchsize, channels,
srcHeight, srcWidth, in_y - 1 + k,
in_x - 1),
upsample_get_value_bounded(src, n, c, batchsize, channels,
srcHeight, srcWidth, in_y - 1 + k,
in_x + 0),
upsample_get_value_bounded(src, n, c, batchsize, channels,
srcHeight, srcWidth, in_y - 1 + k,
in_x + 1),
upsample_get_value_bounded(src, n, c, batchsize, channels,
srcHeight, srcWidth, in_y - 1 + k,
in_x + 2),
t_x);
}
dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
output_y * dstWidth + output_x] =
scalar_t(cubic_interp1d(coefficients[0], coefficients[1],
coefficients[2], coefficients[3], t_y));
}
}
}
template <typename scalar_t>
void resizeGPU(const scalar_t *pIn_d, scalar_t *pOut_d, int batch, int channels,
int srcWidth, int srcHeight, int dstWidth, int dstHeight,
bool align_corners, cudaStream_t stream) {
float height_scale = float(srcHeight) / dstHeight;
float width_scale = float(srcWidth) / dstWidth;
if (align_corners && dstWidth > 1 && dstHeight > 1) {
height_scale = (float)(srcHeight - 1) / (dstHeight - 1);
width_scale = (float)(srcWidth - 1) / (dstWidth - 1);
}
// Each thread handles one output (x, y) position and loops over batch and
// channel internally, so only dstWidth * dstHeight threads are needed.
int n = dstWidth * dstHeight;
resize_cubic_kernel_torch<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(
n, pIn_d, batch, channels, srcWidth, srcHeight, pOut_d, dstWidth,
dstHeight, align_corners, height_scale, width_scale);
}
template <typename scalar_t>
void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch,
int channels, int in_height, int in_width,
int out_height, int out_width, bool align_corners,
cudaStream_t stream) {
resizeGPU(input, output, batch, channels, in_width, in_height, out_width,
out_height, align_corners, stream);
}
template void bicubic_interpolate<float>(const float *input, float *output,
int batch, int channels, int in_height,
int in_width, int out_height,
int out_width, bool align_corners,
cudaStream_t stream);
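
For reference, `cubic_convolution1` and `cubic_convolution2` above are the two branches of the Keys cubic convolution kernel with A = -0.75, the same value PyTorch and OpenCV use:

W(x) =
\begin{cases}
(A+2)\,\lvert x\rvert^{3} - (A+3)\,\lvert x\rvert^{2} + 1, & \lvert x\rvert \le 1 \\
A\,\lvert x\rvert^{3} - 5A\,\lvert x\rvert^{2} + 8A\,\lvert x\rvert - 4A, & 1 < \lvert x\rvert < 2 \\
0, & \text{otherwise}
\end{cases}

`get_cubic_upsample_coefficients` evaluates W at the four tap distances t+1, t, 1-t and 2-t for a fractional offset t in [0, 1).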

@@ -0,0 +1,12 @@
#ifndef TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
#define TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
#include <cuda_runtime.h>
#include "common_cuda_helper.hpp"
template <typename scalar_t>
void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch,
int channels, int in_height, int in_width,
int out_height, int out_width, bool align_corners,
cudaStream_t stream);
#endif // TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
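
Since the kernel follows PyTorch's UpSampleBicubic2d, its output can be sanity-checked against `torch.nn.functional.interpolate` on the same input, which is essentially what the unit test at the end of this diff automates:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5, 10)
# Reference result the TensorRT plugin is expected to reproduce.
ref = F.interpolate(x, scale_factor=2, mode='bicubic', align_corners=False)
print(ref.shape)  # torch.Size([2, 3, 10, 20])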

@@ -1,9 +1,11 @@
 _base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/tensorrt.py']
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 32, 32],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 512, 512])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 32, 32],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 512, 512])))
+    ])
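
These deploy configs are plain mmcv configs, so the merged result (the `_base_` files plus the overrides above) can be inspected directly. A quick sketch; the config path is a placeholder:

from mmcv import Config

cfg = Config.fromfile('path/to/super-resolution_tensorrt_dynamic-32x32-512x512.py')
print(cfg.backend_config.common_config.max_workspace_size)  # 1073741824, i.e. 1 << 30
print(cfg.backend_config.model_inputs[0].input_shapes.input.max_shape)  # [1, 3, 512, 512]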

@@ -1,11 +1,13 @@
 _base_ = [
     './super-resolution_dynamic.py', '../../_base_/backends/tensorrt_fp16.py'
 ]
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 32, 32],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 512, 512])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 32, 32],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 512, 512])))
+    ])

@@ -2,11 +2,13 @@ _base_ = [
     './super-resolution_static.py', '../../_base_/backends/tensorrt_fp16.py'
 ]
 onnx_config = dict(input_shape=[256, 256])
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 256, 256],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 256, 256])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 256, 256],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 256, 256])))
+    ])

@@ -1,11 +1,13 @@
 _base_ = [
     './super-resolution_dynamic.py', '../../_base_/backends/tensorrt_int8.py'
 ]
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 32, 32],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 512, 512])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 32, 32],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 512, 512])))
+    ])

@@ -2,11 +2,13 @@ _base_ = [
     './super-resolution_static.py', '../../_base_/backends/tensorrt_int8.py'
 ]
 onnx_config = dict(input_shape=[256, 256])
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 256, 256],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 256, 256])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 256, 256],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 256, 256])))
+    ])

@@ -1,10 +1,12 @@
 _base_ = ['./super-resolution_static.py', '../../_base_/backends/tensorrt.py']
 onnx_config = dict(input_shape=[256, 256])
-backend_config = dict(model_inputs=[
-    dict(
-        input_shapes=dict(
-            input=dict(
-                min_shape=[1, 3, 256, 256],
-                opt_shape=[1, 3, 256, 256],
-                max_shape=[1, 3, 256, 256])))
-])
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 256, 256],
+                    opt_shape=[1, 3, 256, 256],
+                    max_shape=[1, 3, 256, 256])))
+    ])

@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .deploy import MMEditing, SuperResolution
-from .models import *  # noqa: F401,F403
 __all__ = ['MMEditing', 'SuperResolution']

@@ -1,2 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403

@@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .srcnn import SRCNN__tensorrt
__all__ = ['SRCNN__tensorrt']

@@ -1,50 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmdeploy.core import MODULE_REWRITER


@MODULE_REWRITER.register_rewrite_module(
    'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNN__tensorrt(nn.Module):
    """Rewrite `SRCNN` for the tensorrt backend.

    SRCNN has three conv layers. For each layer, we can define the
    `in_channels`, `out_channels` and `kernel_size`. The input image is
    first upsampled with a bicubic upsampler, then super-resolved
    at the HR spatial size.

    Because TensorRT does not support the bicubic operator, we use
    bilinear instead at deployment time. According to our experiments,
    the precision may decrease by about 4%.

    Paper: Learning a Deep Convolutional Network for Image Super-Resolution.

    Args:
        module (nn.Module): Source SRCNN module.
        channels (tuple[int]): A tuple of channel numbers for each layer,
            including the channels of the input and output.
            Default: (3, 64, 32, 3).
        kernel_sizes (tuple[int]): A tuple of kernel sizes for each conv
            layer. Default: (9, 1, 5).
        upscale_factor (int): Upsampling factor. Default: 4.
    """

    def __init__(self,
                 module,
                 channels=(3, 64, 32, 3),
                 kernel_sizes=(9, 1, 5),
                 upscale_factor=4):
        super(SRCNN__tensorrt, self).__init__()
        self._module = module
        module.img_upsampler = nn.Upsample(
            scale_factor=module.upscale_factor,
            mode='bilinear',
            align_corners=False)

    def forward(self, *args, **kwargs):
        """Run forward."""
        return self._module(*args, **kwargs)

    def init_weights(self, *args, **kwargs):
        """Initialize weights."""
        return self._module.init_weights(*args, **kwargs)

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .getattribute import tensor__getattribute__ncnn
 from .group_norm import group_norm__ncnn
-from .interpolate import interpolate__ncnn
+from .interpolate import interpolate__ncnn, interpolate__tensorrt
 from .linear import linear__ncnn
 from .repeat import tensor__repeat__tensorrt
 from .size import tensor__size__ncnn
@@ -9,6 +9,6 @@ from .topk import topk__dynamic, topk__tensorrt
 __all__ = [
     'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
-    'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn',
-    'topk__dynamic', 'topk__tensorrt'
+    'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
+    'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt'
 ]

@@ -1,9 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from typing import Optional, Tuple, Union
 import torch
+from torch.autograd import Function
 from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils.constants import Backend
@FUNCTION_REWRITER.register_rewriter(
@@ -36,3 +39,65 @@ def interpolate__ncnn(ctx,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor)
@FUNCTION_REWRITER.register_rewriter(
    'torch.nn.functional.interpolate',
    is_pytorch=True,
    backend=Backend.TENSORRT.value)
def interpolate__tensorrt(
    ctx,
    input: torch.Tensor,
    size: Optional[Union[int, Tuple[int], Tuple[int, int],
                         Tuple[int, int, int]]] = None,
    scale_factor: Optional[Union[float, Tuple[float]]] = None,
    mode: str = 'bilinear',
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
):
    """Rewrite `interpolate` for the TensorRT backend.

    Bicubic mode is exported as the custom `mmdeploy::TRTBicubicInterpolate`
    op; all other modes fall back to the original implementation.
    """

    class BicubicInterpolate(Function):

        @staticmethod
        def symbolic(g, input, scale_factor, align_corners):
            """Symbolic function for creating the onnx op."""
            return g.op(
                'mmdeploy::TRTBicubicInterpolate',
                input,
                scale_factor_f=scale_factor,
                align_corners_i=align_corners)

        @staticmethod
        def forward(g, input, scale_factor, align_corners):
            """Run forward. The first argument is the autograd context;
            it is named `g` to avoid shadowing the rewriter `ctx`."""
            return ctx.origin_func(
                input,
                scale_factor=scale_factor,
                mode='bicubic',
                align_corners=align_corners)

    if mode == 'bicubic':
        input_size = input.shape
        if isinstance(scale_factor, float):
            scale_factor = [scale_factor, scale_factor]
        if scale_factor is None:
            logging.warning(
                'The TensorRT Resize layer supports dynamic input shapes '
                'only through a shape tensor, which is not available for '
                'custom ops. Computing scale_factor from `size` instead, '
                'which should yield the intended output shape.')
            scale_factor = [
                s_out / s_in for s_out, s_in in zip(size, input_size[2:])
            ]
        return BicubicInterpolate.apply(input, scale_factor, align_corners)
    else:
        return ctx.origin_func(
            input,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor)
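
With this rewriter registered, a bicubic `nn.Upsample` exports to the custom op without touching model code. A minimal sketch (module and file names are illustrative):

import torch
import torch.nn as nn

from mmdeploy.core import RewriterContext

model = nn.Upsample(
    scale_factor=2, mode='bicubic', align_corners=False).eval()
x = torch.randn(1, 3, 32, 32)

# Inside the context, F.interpolate is replaced by interpolate__tensorrt,
# so the ONNX graph contains an mmdeploy::TRTBicubicInterpolate node.
with RewriterContext(cfg={}, backend='tensorrt', opset=11):
    torch.onnx.export(
        model,
        x,
        'bicubic.onnx',
        input_names=['input'],
        output_names=['output'],
        opset_version=11)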

@@ -90,6 +90,61 @@ def test_grid_sample(backend,
        save_dir=save_dir)

@pytest.mark.parametrize('backend', [TEST_TENSORRT])
@pytest.mark.parametrize('dynamic_export', [True, False])
@pytest.mark.parametrize('mode', ['bicubic', 'nearest'])
@pytest.mark.parametrize('align_corners', [True, False])
@pytest.mark.parametrize('scale_factor', [2, 4])
@pytest.mark.parametrize('n, c, h, w', [(2, 3, 5, 10)])
def test_bicubic_interpolate(backend,
                             dynamic_export,
                             mode,
                             align_corners,
                             scale_factor,
                             n,
                             c,
                             h,
                             w,
                             input_list=None,
                             save_dir=None):
    backend.check_env()
    if input_list is None:
        input = torch.randn(n, c, h, w)
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }
    else:
        dynamic_axes = None
    if mode == 'nearest':
        align_corners = None
    resize = nn.Upsample(
        scale_factor=scale_factor, mode=mode, align_corners=align_corners)
    expected_result = resize(input).cuda()
    wrapped_model = WrapFunction(resize).eval()

    with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
        backend.run_and_validate(
            wrapped_model, [input],
            'bicubic_interpolate',
            input_names=['input'],
            dynamic_axes=dynamic_axes,
            output_names=['output'],
            save_dir=save_dir,
            expected_result=expected_result)
@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
@pytest.mark.parametrize('in_channels,out_channels,stride,padding,'
                         'dilation,groups,deform_groups,kernel_size',