[Enhancement] Add bicubic resize plugin for tensorrt (#238)
* save codes * enable export fake bicubic interpolate op to onnx * save codes * enable bicubic interpolate trt plugin * static export * enable visualize but need align acc * use torch bicubic upsample * add unit tests for bicubic interpolate * fix unit tests * change mmedit config * remove useless comments * remove useless comments * resolve comments * fix lint * clang-format Co-authored-by: grimoire <yaoqian@sensetime.com>pull/12/head
parent
3b97f64385
commit
66d5cddbdc
|
@ -0,0 +1,203 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "trt_bicubic_interpolate.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "trt_bicubic_interpolate_kernel.hpp"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
#include "trt_serialize.hpp"
|
||||
using namespace nvinfer1;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"TRTBicubicInterpolate"};
|
||||
} // namespace
|
||||
|
||||
TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name,
|
||||
std::vector<float> scale_factor,
|
||||
bool align_corners)
|
||||
: TRTPluginBase(name),
|
||||
mScaleFactor(scale_factor),
|
||||
mAlignCorners(align_corners) {}
|
||||
|
||||
TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name,
|
||||
const void *data, size_t length)
|
||||
: TRTPluginBase(name) {
|
||||
deserialize_value(&data, &length, &mScaleFactor);
|
||||
deserialize_value(&data, &length, &mAlignCorners);
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *TRTBicubicInterpolate::clone() const
|
||||
TRT_NOEXCEPT {
|
||||
TRTBicubicInterpolate *plugin =
|
||||
new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 4;
|
||||
ret.d[0] = inputs[0].d[0];
|
||||
ret.d[1] = inputs[0].d[1];
|
||||
auto height = exprBuilder.constant(mScaleFactor[0]);
|
||||
auto width = exprBuilder.constant(mScaleFactor[1]);
|
||||
auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2],
|
||||
*height);
|
||||
auto d3 =
|
||||
exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width);
|
||||
ret.d[2] = d2;
|
||||
ret.d[3] = d3;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool TRTBicubicInterpolate::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 0) {
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
|
||||
} else {
|
||||
return ioDesc[pos].type == ioDesc[0].type &&
|
||||
ioDesc[pos].format == ioDesc[0].format;
|
||||
}
|
||||
}
|
||||
|
||||
void TRTBicubicInterpolate::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
size_t TRTBicubicInterpolate::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs,
|
||||
void *const *outputs, void *workSpace,
|
||||
cudaStream_t stream) TRT_NOEXCEPT {
|
||||
int batch = inputDesc[0].dims.d[0];
|
||||
int channels = inputDesc[0].dims.d[1];
|
||||
int height = inputDesc[0].dims.d[2];
|
||||
int width = inputDesc[0].dims.d[3];
|
||||
|
||||
int height_out = outputDesc[0].dims.d[2];
|
||||
int width_out = outputDesc[0].dims.d[3];
|
||||
const void *x = inputs[0];
|
||||
void *output = outputs[0];
|
||||
|
||||
// TODO: add fp16 support
|
||||
auto data_type = inputDesc[0].type;
|
||||
switch (data_type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
bicubic_interpolate<float>((float *)x, (float *)output, batch, channels,
|
||||
height, width, height_out, width_out,
|
||||
mAlignCorners, stream);
|
||||
break;
|
||||
default:
|
||||
return 1;
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT {
|
||||
return serialized_size(mScaleFactor) + serialized_size(mAlignCorners);
|
||||
}
|
||||
|
||||
void TRTBicubicInterpolate::serialize(void *buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mScaleFactor);
|
||||
serialize_value(&buffer, mAlignCorners);
|
||||
}
|
||||
|
||||
////////////////////// creator /////////////////////////////
|
||||
|
||||
TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() {
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners"));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *TRTBicubicInterpolateCreator::getPluginVersion() const
|
||||
TRT_NOEXCEPT {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
nvinfer1::Dims size{2, {1, 1}};
|
||||
std::vector<float> scale_factor;
|
||||
bool align_corners = 1;
|
||||
|
||||
for (int i = 0; i < fc->nbFields; i++) {
|
||||
if (fc->fields[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string field_name(fc->fields[i].name);
|
||||
|
||||
if (field_name.compare("scale_factor") == 0) {
|
||||
int data_size = (fc->fields[i].length);
|
||||
if (data_size != 2) {
|
||||
data_size = data_size / sizeof(float);
|
||||
}
|
||||
ASSERT(data_size == 2)
|
||||
const float *data_start = static_cast<const float *>(fc->fields[i].data);
|
||||
scale_factor = std::vector<float>(data_start, data_start + data_size);
|
||||
}
|
||||
|
||||
if (field_name.compare("align_corners") == 0) {
|
||||
align_corners = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
}
|
||||
|
||||
TRTBicubicInterpolate *plugin =
|
||||
new TRTBicubicInterpolate(name, scale_factor, align_corners);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator);
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,76 @@
|
|||
#ifndef TRT_BICUBIC_INTERPOLATE_HPP
|
||||
#define TRT_BICUBIC_INTERPOLATE_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_base.hpp"
|
||||
namespace mmdeploy {
|
||||
class TRTBicubicInterpolate : public TRTPluginBase {
|
||||
public:
|
||||
TRTBicubicInterpolate(const std::string &name,
|
||||
std::vector<float> scale_factor, bool align_corners);
|
||||
|
||||
TRTBicubicInterpolate(const std::string name, const void *data,
|
||||
size_t length);
|
||||
|
||||
TRTBicubicInterpolate() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
std::vector<float> mScaleFactor;
|
||||
bool mAlignCorners;
|
||||
};
|
||||
|
||||
class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
TRTBicubicInterpolateCreator();
|
||||
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name,
|
||||
const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(
|
||||
const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmdeploy
|
||||
#endif // TRT_BICUBIC_INTERPOLATE_HPP
|
|
@ -0,0 +1,181 @@
|
|||
// Modified from
|
||||
// https://github.com/pytorch/pytorch/blob/6adbe044e39c8e8db158d91e151aa6dead6e9aa4/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_bicubic_interpolate_kernel.hpp"
|
||||
|
||||
// Based on
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x,
|
||||
scalar_t A) {
|
||||
return ((A + 2) * x - (A + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x,
|
||||
scalar_t A) {
|
||||
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ static void get_cubic_upsample_coefficients(
|
||||
scalar_t coeffs[4], scalar_t t) {
|
||||
scalar_t A = -0.75;
|
||||
|
||||
scalar_t x1 = t;
|
||||
coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
|
||||
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
|
||||
|
||||
// opposite coefficients
|
||||
scalar_t x2 = 1.0 - t;
|
||||
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
|
||||
coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ static scalar_t cubic_interp1d(
|
||||
scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) {
|
||||
scalar_t coeffs[4];
|
||||
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
|
||||
|
||||
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
|
||||
}
|
||||
|
||||
/* Used by UpSampleBicubic2d.cu */
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ static scalar_t upsample_get_value_bounded(
|
||||
const scalar_t *data, int batch, int channel, int batchsize, int channels,
|
||||
int height, int width, int y, int x) {
|
||||
int access_y = max(min(y, height - 1), 0);
|
||||
int access_x = max(min(x, width - 1), 0);
|
||||
return data[batch * channels * height * width + channel * height * width +
|
||||
access_y * width + access_x];
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ __forceinline__ scalar_t area_pixel_compute_source_index(
|
||||
scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) {
|
||||
if (align_corners) {
|
||||
return scale * dst_index;
|
||||
} else {
|
||||
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5;
|
||||
// [Note] Follow Opencv resize logic:
|
||||
// We allow negative src_idx here and later will use
|
||||
// dx = src_idx - floorf(src_idx)
|
||||
// to compute the "distance"(which affects weights).
|
||||
// For linear modes, weight distribution doesn't matter
|
||||
// for negative indices as they use 2 pixels to interpolate.
|
||||
// For example, [-1, 0], they both use pixel 0 value so it
|
||||
// doesn't affect if we bound the src_idx to 0 or not.
|
||||
// TODO: Our current linear mode impls use unbound indices
|
||||
// where we should and then remove this cubic flag.
|
||||
// This matters in cubic mode, as we might need [-1, 0, 1, 2]
|
||||
// to interpolate and the weights can be affected.
|
||||
return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// cubic interpolation pytorch
|
||||
template <typename scalar_t>
|
||||
__global__ void resize_cubic_kernel_torch(
|
||||
const int num_elements, const scalar_t *src, const int batchsize,
|
||||
const int channels, int srcWidth, int srcHeight, scalar_t *dst,
|
||||
int dstWidth, int dstHeight, bool align_corners, float height_scale,
|
||||
float width_scale) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index >= num_elements) {
|
||||
return;
|
||||
}
|
||||
// Special case: input and output are the same size, just copy
|
||||
const int output_x = index % dstWidth;
|
||||
const int output_y = index / dstWidth;
|
||||
|
||||
if (srcHeight == dstHeight && srcWidth == dstWidth) {
|
||||
for (int n = 0; n < batchsize; n++) {
|
||||
for (int c = 0; c < channels; c++) {
|
||||
const scalar_t val =
|
||||
src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
|
||||
output_y * dstWidth + output_x];
|
||||
dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
|
||||
output_y * dstWidth + output_x] = val;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Interpolation kernel
|
||||
scalar_t real_x = area_pixel_compute_source_index(
|
||||
width_scale, output_x, align_corners, /*cubic=*/true);
|
||||
int in_x = floorf(real_x);
|
||||
scalar_t t_x = real_x - in_x;
|
||||
|
||||
scalar_t real_y = area_pixel_compute_source_index(
|
||||
height_scale, output_y, align_corners, /*cubic=*/true);
|
||||
int in_y = floorf(real_y);
|
||||
scalar_t t_y = real_y - in_y;
|
||||
|
||||
for (int n = 0; n < batchsize; n++) {
|
||||
for (int c = 0; c < channels; c++) {
|
||||
scalar_t coefficients[4];
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
coefficients[k] = cubic_interp1d<scalar_t>(
|
||||
upsample_get_value_bounded(src, n, c, batchsize, channels,
|
||||
srcHeight, srcWidth, in_y - 1 + k,
|
||||
in_x - 1),
|
||||
upsample_get_value_bounded(src, n, c, batchsize, channels,
|
||||
srcHeight, srcWidth, in_y - 1 + k,
|
||||
in_x + 0),
|
||||
upsample_get_value_bounded(src, n, c, batchsize, channels,
|
||||
srcHeight, srcWidth, in_y - 1 + k,
|
||||
in_x + 1),
|
||||
upsample_get_value_bounded(src, n, c, batchsize, channels,
|
||||
srcHeight, srcWidth, in_y - 1 + k,
|
||||
in_x + 2),
|
||||
t_x);
|
||||
}
|
||||
|
||||
dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth +
|
||||
output_y * dstWidth + output_x] =
|
||||
scalar_t(cubic_interp1d(coefficients[0], coefficients[1],
|
||||
coefficients[2], coefficients[3], t_y));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void resizeGPU(const scalar_t *pIn_d, scalar_t *pOut_d, int batch, int channels,
|
||||
int srcWidth, int srcHeight, int dstWidth, int dstHeight,
|
||||
bool align_corners, cudaStream_t stream) {
|
||||
float height_scale = float(srcHeight) / dstHeight;
|
||||
float width_scale = float(srcWidth) / dstWidth;
|
||||
if (align_corners && dstWidth > 1 && dstHeight > 1) {
|
||||
height_scale = (float)(srcHeight - 1) / (dstHeight - 1);
|
||||
width_scale = (float)(srcWidth - 1) / (dstWidth - 1);
|
||||
}
|
||||
int n = batch * dstWidth * dstHeight * channels;
|
||||
resize_cubic_kernel_torch<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
dstWidth * dstHeight, pIn_d, batch, channels, srcWidth, srcHeight, pOut_d,
|
||||
dstWidth, dstHeight, align_corners, height_scale, width_scale);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch,
|
||||
int channels, int in_height, int in_width,
|
||||
int out_height, int out_width, bool align_corners,
|
||||
cudaStream_t stream) {
|
||||
resizeGPU(input, output, batch, channels, in_width, in_height, out_width,
|
||||
out_height, align_corners, stream);
|
||||
}
|
||||
|
||||
template void bicubic_interpolate<float>(const float *input, float *output,
|
||||
int batch, int channels, int in_height,
|
||||
int in_width, int out_height,
|
||||
int out_width, bool align_corners,
|
||||
cudaStream_t stream);
|
|
@ -0,0 +1,12 @@
|
|||
#ifndef TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
|
||||
#define TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
|
||||
template <typename scalar_t>
|
||||
void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch,
|
||||
int channels, int in_height, int in_width,
|
||||
int out_height, int out_width, bool align_corners,
|
||||
cudaStream_t stream);
|
||||
#endif // TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
|
|
@ -1,9 +1,11 @@
|
|||
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/tensorrt.py']
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
_base_ = [
|
||||
'./super-resolution_dynamic.py', '../../_base_/backends/tensorrt_fp16.py'
|
||||
]
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
|
|
|
@ -2,11 +2,13 @@ _base_ = [
|
|||
'./super-resolution_static.py', '../../_base_/backends/tensorrt_fp16.py'
|
||||
]
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
_base_ = [
|
||||
'./super-resolution_dynamic.py', '../../_base_/backends/tensorrt_int8.py'
|
||||
]
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 32, 32],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 512, 512])))
|
||||
])
|
||||
|
|
|
@ -2,11 +2,13 @@ _base_ = [
|
|||
'./super-resolution_static.py', '../../_base_/backends/tensorrt_int8.py'
|
||||
]
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
_base_ = ['./super-resolution_static.py', '../../_base_/backends/tensorrt.py']
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256])))
|
||||
])
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .deploy import MMEditing, SuperResolution
|
||||
from .models import * # noqa: F401,F403
|
||||
|
||||
__all__ = ['MMEditing', 'SuperResolution']
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .backbones import * # noqa: F401,F403
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .srcnn import SRCNN__tensorrt
|
||||
|
||||
__all__ = ['SRCNN__tensorrt']
|
|
@ -1,50 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
|
||||
from mmdeploy.core import MODULE_REWRITER
|
||||
|
||||
|
||||
@MODULE_REWRITER.register_rewrite_module(
|
||||
'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
|
||||
class SRCNN__tensorrt(nn.Module):
|
||||
"""Rewrite `SRCNN` for tensorrt backend.
|
||||
|
||||
SRCNN has three conv layers. For each layer, we can define the
|
||||
`in_channels`, `out_channels` and `kernel_size`.The input image will
|
||||
first be upsampled with a bicubic upsampler, and then super-resolved
|
||||
in the HR spatial size.
|
||||
Because TensorRT doesn't support bicubic operator, when deployment we use
|
||||
bilinear instead. According to the experiments, the precision may decrease
|
||||
about 4%.
|
||||
Paper: Learning a Deep Convolutional Network for Image Super-Resolution.
|
||||
|
||||
Args:
|
||||
module (nn.Module): Source SRCNN module.
|
||||
channels (tuple[int]): A tuple of channel numbers for each layer
|
||||
including channels of input and output . Default: (3, 64, 32, 3).
|
||||
kernel_sizes (tuple[int]): A tuple of kernel sizes for each conv layer.
|
||||
Default: (9, 1, 5).
|
||||
upscale_factor (int): Upsampling factor. Default: 4.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module,
|
||||
channels=(3, 64, 32, 3),
|
||||
kernel_sizes=(9, 1, 5),
|
||||
upscale_factor=4):
|
||||
super(SRCNN__tensorrt, self).__init__()
|
||||
|
||||
self._module = module
|
||||
|
||||
module.img_upsampler = nn.Upsample(
|
||||
scale_factor=module.upscale_factor,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Run forward."""
|
||||
return self._module(*args, **kwargs)
|
||||
|
||||
def init_weights(self, *args, **kwargs):
|
||||
"""Initialize weights."""
|
||||
return self._module.init_weights(*args, **kwargs)
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .getattribute import tensor__getattribute__ncnn
|
||||
from .group_norm import group_norm__ncnn
|
||||
from .interpolate import interpolate__ncnn
|
||||
from .interpolate import interpolate__ncnn, interpolate__tensorrt
|
||||
from .linear import linear__ncnn
|
||||
from .repeat import tensor__repeat__tensorrt
|
||||
from .size import tensor__size__ncnn
|
||||
|
@ -9,6 +9,6 @@ from .topk import topk__dynamic, topk__tensorrt
|
|||
|
||||
__all__ = [
|
||||
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
|
||||
'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn',
|
||||
'topk__dynamic', 'topk__tensorrt'
|
||||
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
|
||||
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt'
|
||||
]
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
@ -36,3 +39,65 @@ def interpolate__ncnn(ctx,
|
|||
mode=mode,
|
||||
align_corners=align_corners,
|
||||
recompute_scale_factor=recompute_scale_factor)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'torch.nn.functional.interpolate',
|
||||
is_pytorch=True,
|
||||
backend=Backend.TENSORRT.value)
|
||||
def interpolate__tensorrt(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
|
||||
int]]] = None,
|
||||
scale_factor: Optional[Union[float, Tuple[float]]] = None,
|
||||
mode: str = 'bilinear',
|
||||
align_corners: Optional[bool] = None,
|
||||
recompute_scale_factor: Optional[bool] = None,
|
||||
):
|
||||
"""Register default symbolic function for `interpolate`."""
|
||||
|
||||
class BicubicInterpolate(Function):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input, scale_factor, align_corners):
|
||||
"""Symbolic function for creating onnx op."""
|
||||
return g.op(
|
||||
'mmdeploy::TRTBicubicInterpolate',
|
||||
input,
|
||||
scale_factor_f=scale_factor,
|
||||
align_corners_i=align_corners)
|
||||
|
||||
@staticmethod
|
||||
def forward(g, input, scale_factor, align_corners):
|
||||
"""Run forward."""
|
||||
return ctx.origin_func(
|
||||
input,
|
||||
scale_factor=scale_factor,
|
||||
mode='bicubic',
|
||||
align_corners=align_corners)
|
||||
|
||||
if 'bicubic' == mode:
|
||||
input_size = input.shape
|
||||
if isinstance(scale_factor, float):
|
||||
scale_factor = [scale_factor, scale_factor]
|
||||
if scale_factor is None:
|
||||
logging.warning(
|
||||
'ResizeLayer in TensorRT allow dynamic input shape with shape '
|
||||
'tensor. Which is not available for custom ops. Computed scale'
|
||||
'_factor might be the right way to get final shape.')
|
||||
scale_factor = [
|
||||
s_out / s_in for s_out, s_in in zip(size, input_size[2:])
|
||||
]
|
||||
return BicubicInterpolate.apply(input, scale_factor, align_corners)
|
||||
else:
|
||||
return ctx.origin_func(
|
||||
input,
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners,
|
||||
recompute_scale_factor=recompute_scale_factor)
|
||||
|
|
|
@ -90,6 +90,61 @@ def test_grid_sample(backend,
|
|||
save_dir=save_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
|
||||
@pytest.mark.parametrize('dynamic_export', [True, False])
|
||||
@pytest.mark.parametrize('mode', ['bicubic', 'nearest'])
|
||||
@pytest.mark.parametrize('align_corners', [True, False])
|
||||
@pytest.mark.parametrize('scale_factor', [2, 4])
|
||||
@pytest.mark.parametrize('n, c, h, w', [(2, 3, 5, 10)])
|
||||
def test_bicubic_interpolate(backend,
|
||||
dynamic_export,
|
||||
mode,
|
||||
align_corners,
|
||||
scale_factor,
|
||||
n,
|
||||
c,
|
||||
h,
|
||||
w,
|
||||
input_list=None,
|
||||
save_dir=None):
|
||||
backend.check_env()
|
||||
|
||||
if input_list is None:
|
||||
input = torch.randn(n, c, h, w)
|
||||
if dynamic_export:
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'n',
|
||||
2: 'h',
|
||||
3: 'w',
|
||||
},
|
||||
'output': {
|
||||
0: 'n',
|
||||
2: 'h',
|
||||
3: 'w',
|
||||
},
|
||||
}
|
||||
else:
|
||||
dynamic_axes = None
|
||||
|
||||
if mode == 'nearest':
|
||||
align_corners = None
|
||||
resize = nn.Upsample(
|
||||
scale_factor=scale_factor, mode=mode, align_corners=align_corners)
|
||||
expected_result = resize(input).cuda()
|
||||
wrapped_model = WrapFunction(resize).eval()
|
||||
|
||||
with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input],
|
||||
'bicubic_interpolate',
|
||||
input_names=['input'],
|
||||
dynamic_axes=dynamic_axes,
|
||||
output_names=['output'],
|
||||
save_dir=save_dir,
|
||||
expected_result=expected_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT])
|
||||
@pytest.mark.parametrize('in_channels,out_channels,stride,padding,'
|
||||
'dilation,groups,deform_groups,kernel_size',
|
||||
|
|
Loading…
Reference in New Issue