Compare commits

...

3 Commits

Author SHA1 Message Date
RunningLeon 5a9ac8765d
Fix readthedocs (#2134)
* [Fix]: limit urllib3 for readthedocs (#2070)

* fix readthedocs for zh_cn

* fix
2023-05-31 15:18:47 +08:00
AllentDan 6cd77c66b7
add deform conv v3 plugin (#1872)
* add deform conv v3 plugin

* update doc

* resolve comments

* update description
2023-05-23 10:22:47 +08:00
RunningLeon 335ef8648d
fix mmseg exportation for out_channels=1 (#1997) 2023-05-04 12:51:05 +08:00
16 changed files with 712 additions and 37 deletions

View File

@ -28,7 +28,6 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.7]
torch: [1.8.0, 1.9.0]
mmcv: [1.4.2]
include:
@ -40,22 +39,22 @@ jobs:
torchvision: 0.10.0
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
run: |
python -m pip install --upgrade pip
python -V
python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMCV
run: |
pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
python -c 'import mmcv; print(mmcv.__version__)'
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
pip install -U numpy
python -m pip install -U numpy
python -m pip install rapidfuzz==2.15.1
python -m pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
run: rm -rf .eggs && python -m pip install -e .
- name: Run python unittests and generate coverage report
run: |
coverage run --branch --source mmdeploy -m pytest -rsE tests
@ -139,6 +138,7 @@ jobs:
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu102/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .
@ -174,6 +174,7 @@ jobs:
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .

View File

@ -40,32 +40,27 @@ jobs:
- name: Install system dependencies
run: |
apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC
apt-get update && apt-get install -y git
apt-get update && apt-get install -y git wget
- name: Install dependencies
run: |
python -V
python -m pip install -U pip
python -m pip install mmcv-full==${{matrix.mmcv}} -f https://download.openmmlab.com/mmcv/dist/cu111/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -m pip install rapidfuzz==2.15.1
- name: Install mmcls
run: |
cd ~
git clone https://github.com/open-mmlab/mmclassification.git
git clone -b v0.23.0 --depth 1 https://github.com/open-mmlab/mmclassification.git
cd mmclassification
git checkout v0.23.0
python3 -m pip install -e .
cd -
- name: Install ppq
run: |
cd ~
python -m pip install protobuf==3.20.0
git clone https://github.com/openppl-public/ppq
git clone -b v0.6.6 --depth 1 https://github.com/openppl-public/ppq
cd ppq
git checkout edbecf44c7b203515640e4f4119c000a1b66b33a
python3 -m pip install -r requirements.txt
python3 setup.py install
cd -
- name: Run tests
run: |
echo $(pwd)
export PYTHONPATH=${PWD}/ppq:${PYTHONPATH}
python3 .github/scripts/quantize_to_ncnn.py

View File

@ -0,0 +1,279 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_deform_conv_v3.hpp"
#include <assert.h>
#include <chrono>
#include "trt_deform_conv_v3_kernel.hpp"
#include "trt_plugin_helper.hpp"
#include "trt_serialize.hpp"
using namespace nvinfer1;
namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"TRTDCNv3"};
} // namespace
TRTDCNv3::TRTDCNv3(const std::string &name, int kernel_h, int kernel_w, int stride_h, int stride_w,
int pad_h, int pad_w, int dilation_h, int dilation_w, int group,
int group_channels, float offset_scale, int im2col_step)
: TRTPluginBase(name),
kernel_h_(kernel_h),
kernel_w_(kernel_w),
stride_h_(stride_h),
stride_w_(stride_w),
pad_h_(pad_h),
pad_w_(pad_w),
dilation_h_(dilation_h),
dilation_w_(dilation_w),
group_(group),
group_channels_(group_channels),
offset_scale_(offset_scale),
im2col_step_(im2col_step) {}
TRTDCNv3::TRTDCNv3(const std::string name, const void *data, size_t length) : TRTPluginBase(name) {
deserialize_value(&data, &length, &kernel_h_);
deserialize_value(&data, &length, &kernel_w_);
deserialize_value(&data, &length, &stride_h_);
deserialize_value(&data, &length, &stride_w_);
deserialize_value(&data, &length, &pad_h_);
deserialize_value(&data, &length, &pad_w_);
deserialize_value(&data, &length, &dilation_h_);
deserialize_value(&data, &length, &dilation_w_);
deserialize_value(&data, &length, &group_);
deserialize_value(&data, &length, &group_channels_);
deserialize_value(&data, &length, &offset_scale_);
deserialize_value(&data, &length, &im2col_step_);
}
nvinfer1::IPluginV2DynamicExt *TRTDCNv3::clone() const TRT_NOEXCEPT {
TRTDCNv3 *plugin =
new TRTDCNv3(mLayerName, kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_,
dilation_h_, dilation_w_, group_, group_channels_, offset_scale_, im2col_step_);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
const nvinfer1::IDimensionExpr *output_size(const nvinfer1::IDimensionExpr &input, int pad,
int dilation, int kernel, int stride,
nvinfer1::IExprBuilder &exprBuilder) {
// out_expand = 2×padding[0]dilation[0]×(kernel_size[0]1)+1
auto out_expand = exprBuilder.constant(2 * pad - dilation * (kernel - 1) + 1);
// out = out_expand + input
auto out_before_div = exprBuilder.operation(DimensionOperation::kSUM, input, *out_expand);
// out = out / stride
auto out_before_sub = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *out_before_div,
*(exprBuilder.constant(stride)));
// out -=1
auto out =
exprBuilder.operation(DimensionOperation::kSUB, *out_before_sub, *(exprBuilder.constant(1)));
return out;
}
nvinfer1::DimsExprs TRTDCNv3::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
ret.d[3] = exprBuilder.constant(group_ * group_channels_);
ret.d[1] = output_size(*inputs[0].d[1], pad_h_, dilation_h_, kernel_h_, stride_h_, exprBuilder);
ret.d[2] = output_size(*inputs[0].d[2], pad_w_, dilation_w_, kernel_w_, stride_w_, exprBuilder);
return ret;
}
bool TRTDCNv3::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
if (pos == 0) {
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
}
}
void TRTDCNv3::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}
size_t TRTDCNv3::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}
int TRTDCNv3::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
int batch = inputDesc[0].dims.d[0];
int height = inputDesc[0].dims.d[1];
int width = inputDesc[0].dims.d[2];
int channels = inputDesc[0].dims.d[3];
int height_out = outputDesc[0].dims.d[1];
int width_out = outputDesc[0].dims.d[2];
int channels_out = outputDesc[0].dims.d[3];
const void *input = inputs[0];
const void *offset = inputs[1];
const void *mask = inputs[2];
void *output = outputs[0];
// TODO: add fp16 support
auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
DeformConvv3ForwardCUDAKernelLauncher<float>(
(float *)input, (float *)offset, (float *)mask, (float *)output, workSpace, batch,
channels, height, width, channels_out, kernel_w_, kernel_h_, stride_w_, stride_h_, pad_w_,
pad_h_, dilation_w_, dilation_h_, group_, group_channels_, offset_scale_, im2col_step_,
stream);
break;
default:
return 1;
break;
}
return 0;
}
nvinfer1::DataType TRTDCNv3::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}
// IPluginV2 Methods
const char *TRTDCNv3::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *TRTDCNv3::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
int TRTDCNv3::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t TRTDCNv3::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(kernel_h_) + serialized_size(kernel_w_) + serialized_size(stride_h_) +
serialized_size(stride_w_) + serialized_size(pad_h_) + serialized_size(pad_w_) +
serialized_size(dilation_h_) + serialized_size(dilation_w_) + serialized_size(group_) +
serialized_size(group_channels_) + serialized_size(offset_scale_) +
serialized_size(im2col_step_);
}
void TRTDCNv3::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, kernel_h_);
serialize_value(&buffer, kernel_w_);
serialize_value(&buffer, stride_h_);
serialize_value(&buffer, stride_w_);
serialize_value(&buffer, pad_h_);
serialize_value(&buffer, pad_w_);
serialize_value(&buffer, dilation_h_);
serialize_value(&buffer, dilation_w_);
serialize_value(&buffer, group_);
serialize_value(&buffer, group_channels_);
serialize_value(&buffer, offset_scale_);
serialize_value(&buffer, im2col_step_);
}
////////////////////// creator /////////////////////////////
TRTDCNv3Creator::TRTDCNv3Creator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("kernel_h"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("kernel_w"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("pad_h"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("pad_w"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation_h"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation_w"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("group"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("group_channels"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("offset_scale"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("im2col_step"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *TRTDCNv3Creator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *TRTDCNv3Creator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
nvinfer1::IPluginV2 *TRTDCNv3Creator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
nvinfer1::Dims size{2, {1, 1}};
int kernel_h = 3;
int kernel_w = 3;
int stride_h = 1;
int stride_w = 1;
int pad_h = 1;
int pad_w = 1;
int dilation_h = 1;
int dilation_w = 1;
int group = 28;
int group_channels = 16;
float offset_scale = 1;
int im2col_step = 256;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("kernel_h") == 0) {
kernel_h = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("kernel_w") == 0) {
kernel_w = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("stride_h") == 0) {
stride_h = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("stride_w") == 0) {
stride_w = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("pad_h") == 0) {
pad_h = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("pad_w") == 0) {
pad_w = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("dilation_h") == 0) {
dilation_h = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("dilation_w") == 0) {
dilation_w = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("group") == 0) {
group = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("group_channels") == 0) {
group_channels = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("offset_scale") == 0) {
offset_scale = static_cast<const float *>(fc->fields[i].data)[0];
}
if (field_name.compare("im2col_step") == 0) {
im2col_step = static_cast<const int *>(fc->fields[i].data)[0];
}
}
TRTDCNv3 *plugin =
new TRTDCNv3(name, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, group_channels, offset_scale, im2col_step);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *TRTDCNv3Creator::deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new TRTDCNv3(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(TRTDCNv3Creator);
} // namespace mmdeploy

View File

@ -0,0 +1,78 @@
#ifndef TRT_DEFORM_CONV_V3_HPP
#define TRT_DEFORM_CONV_V3_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
class TRTDCNv3 : public TRTPluginBase {
public:
TRTDCNv3(const std::string &name, int kernel_h, int kernel_w, int stride_h, int stride_w,
int pad_h, int pad_w, int dilation_h, int dilation_w, int group, int group_channels,
float offset_scale, int im2col_step);
TRTDCNv3(const std::string name, const void *data, size_t length);
TRTDCNv3() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
private:
int kernel_h_;
int kernel_w_;
int stride_h_;
int stride_w_;
int pad_h_;
int pad_w_;
int dilation_h_;
int dilation_w_;
int group_;
int group_channels_;
float offset_scale_;
int im2col_step_;
};
class TRTDCNv3Creator : public TRTPluginCreatorBase {
public:
TRTDCNv3Creator();
const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_DEFORM_CONV_V3_HPP

View File

@ -0,0 +1,186 @@
// Modified from
// https://github.com/pytorch/pytorch/blob/6adbe044e39c8e8db158d91e151aa6dead6e9aa4/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu
#include <cuda_fp16.h>
#include <stdio.h>
#include <algorithm>
#include <cmath>
#include <vector>
#include "common_cuda_helper.hpp"
#include "trt_deform_conv_v3_kernel.hpp"
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
const static int CUDA_NUM_THREADS = 256;
inline int GET_BLOCKS(const int N, const int num_threads) {
return (N + num_threads - 1) / num_threads;
}
template <typename scalar_t>
__device__ scalar_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data, const int &height,
const int &width, const int &group,
const int &group_channels, const scalar_t &h,
const scalar_t &w, const int &g, const int &c) {
const int h_low = floor(h);
const int w_low = floor(w);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const scalar_t lh = h - h_low;
const scalar_t lw = w - w_low;
const scalar_t hh = 1 - lh, hw = 1 - lw;
const int w_stride = group * group_channels;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int base_ptr = g * group_channels + c;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__global__ void dcnv3_im2col_gpu_kernel(
const int num_kernels, const scalar_t *data_im, const scalar_t *data_offset,
const scalar_t *data_mask, scalar_t *data_col, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels, const int height_in,
const int width_in, const int height_out, const int width_out, const scalar_t offset_scale) {
CUDA_KERNEL_LOOP(index, num_kernels) {
int _temp = index;
const int c_col = _temp % group_channels;
_temp /= group_channels;
const int sampling_index = _temp;
const int g_col = _temp % group;
_temp /= group;
const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + (_temp % width_out) * stride_w;
_temp /= width_out;
const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + (_temp % height_out) * stride_h;
_temp /= height_out;
const int b_col = _temp;
const int input_size = height_in * width_in;
scalar_t *data_col_ptr = data_col + index;
const int kernel_size = kernel_h * kernel_w;
int data_weight_ptr = sampling_index * kernel_size;
int data_loc_w_ptr = data_weight_ptr << 1;
const int qid_stride = group * group_channels;
scalar_t col = 0;
const scalar_t *data_im_ptr = data_im + b_col * input_size * qid_stride;
// top-left
const scalar_t p0_w_ = p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale;
const scalar_t p0_h_ = p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale;
for (int i = 0; i < kernel_w; ++i) {
for (int j = 0; j < kernel_h; ++j) {
const scalar_t offset_w = data_offset[data_loc_w_ptr];
const scalar_t offset_h = data_offset[data_loc_w_ptr + 1];
const scalar_t loc_w = p0_w_ + (i * dilation_w + offset_w) * offset_scale;
const scalar_t loc_h = p0_h_ + (j * dilation_h + offset_h) * offset_scale;
const scalar_t weight = data_mask[data_weight_ptr];
if (loc_h > -1 && loc_w > -1 && loc_h < height_in && loc_w < width_in) {
col += dcnv3_im2col_bilinear(data_im_ptr, height_in, width_in, group, group_channels,
loc_h, loc_w, g_col, c_col) *
weight;
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
*data_col_ptr = col;
}
}
template <typename scalar_t>
void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im, const scalar_t *data_offset,
const scalar_t *data_mask, scalar_t *data_col, const int kernel_h,
const int kernel_w, const int stride_h, const int stride_w, const int pad_h,
const int pad_w, const int dilation_h, const int dilation_w, const int group,
const int group_channels, const int batch_n, const int height_in,
const int width_in, const int height_out, const int width_out,
const scalar_t offset_scale) {
const int num_kernels = batch_n * height_out * width_out * group * group_channels;
const int num_actual_kernels = batch_n * height_out * width_out * group * group_channels;
const int num_threads = CUDA_NUM_THREADS;
dcnv3_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_im, data_offset, data_mask, data_col, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group, group_channels, height_in,
width_in, height_out, width_out, offset_scale);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in dcnv3_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void dcnv3_cuda_forward(const scalar_t *input, const scalar_t *offset, const scalar_t *mask,
scalar_t *output, int batch, int channels, int height_in, int width_in,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int group_channels,
const float offset_scale, const int im2col_step, cudaStream_t stream) {
const int height_out = (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int im2col_step_ = std::min(batch, im2col_step);
const int batch_n = im2col_step_;
auto per_input_size = height_in * width_in * group * group_channels;
auto per_output_size = height_out * width_out * group * group_channels;
auto per_offset_size = height_out * width_out * group * kernel_h * kernel_w * 2;
auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w;
for (int n = 0; n < batch / im2col_step_; ++n) {
// there is only im2col in dcnv3. The same as
// https://github.com/OpenGVLab/InternImage/blob/4fb17721a0f9ab9fb28c7ed48ac1667a247c6da4/classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu#L71
dcnv3_im2col_cuda<scalar_t>(
stream, input + n * im2col_step_ * per_input_size,
offset + n * im2col_step_ * per_offset_size, mask + n * im2col_step_ * per_mask_size,
output + n * im2col_step_ * per_output_size, kernel_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, group_channels, batch_n, height_in, width_in,
height_out, width_out, offset_scale);
}
}
template <typename scalar_t>
void DeformConvv3ForwardCUDAKernelLauncher(
const scalar_t *input, const scalar_t *offset, const scalar_t *mask, scalar_t *output,
void *workspace, int batch, int channels, int height, int width, int channels_out, int kernel_w,
int kernel_h, int stride_w, int stride_h, int pad_w, int pad_h, int dilation_w, int dilation_h,
int group, int group_channel, float offset_scale, int im2col_step, cudaStream_t stream) {
dcnv3_cuda_forward(input, offset, mask, output, batch, channels, height, width, kernel_h,
kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
group_channel, offset_scale, im2col_step, stream);
}
template void DeformConvv3ForwardCUDAKernelLauncher<float>(
const float *input, const float *offset, const float *mask, float *output, void *workspace,
int batch, int channels, int height, int width, int channels_out, int kernel_w, int kernel_h,
int stride_w, int stride_h, int pad_w, int pad_h, int dilation_w, int dilation_h, int group,
int group_channel, float offset_scale, int im2col_step, cudaStream_t stream);

View File

@ -0,0 +1,14 @@
#ifndef TRT_DEFORM_CONV_V3_KERNEL_HPP
#define TRT_DEFORM_CONV_V3_KERNEL_HPP
#include <cuda_runtime.h>
#include "common_cuda_helper.hpp"
template <typename scalar_t>
void DeformConvv3ForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* offset, const scalar_t* mask, scalar_t* output,
void* workspace, int batch, int channels, int height, int width, int channels_out, int kernel_w,
int kernel_h, int stride_w, int stride_h, int pad_w, int pad_h, int dilation_w, int dilation_h,
int group, int group_channel, float offset_scale, int im2col_step, cudaStream_t stream);
#endif // TRT_DEFORM_CONV_V3_KERNEL_HPP

View File

@ -69,6 +69,12 @@
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)
- [TRTDCNv3](#trtdcnv3)
- [Description](#description-11)
- [Parameters](#parameters-11)
- [Inputs](#inputs-11)
- [Outputs](#outputs-11)
- [Type Constraints](#type-constraints-11)
<!-- TOC -->
@ -489,3 +495,51 @@ None
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear)
### TRTDCNv3
#### Description
TensorRT deformable convolution v3 is used to support [InternImage](https://github.com/OpenGVLab/InternImage). The op
contains only im2col logic even though it is named convolution. For more detail, you may refer to [InternImage](https://github.com/OpenGVLab/InternImage/blob/4fb17721a0f9ab9fb28c7ed48ac1667a247c6da4/classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu#L71)
#### Parameters
| Type | Parameter | Description |
| ------- | ---------------- | --------------------------------- |
| `int` | `kernel_h` | The kernel size of h dim. |
| `int` | `kernel_w` | The kernel size of w dim. |
| `int` | `stride_h` | The stride size of h dim. |
| `int` | `stride_w` | The stride size of w dim. |
| `int` | `pad_h` | The padding size of h dim. |
| `int` | `pad_w` | The padding size of w dim. |
| `int` | `dilation_h` | The dilation size of h dim. |
| `int` | `dilation_w` | The dilation size of w dim. |
| `int` | `group` | The group nums. |
| `int` | `group_channels` | The number of channels per group. |
| `float` | `offset_scale` | The offset cale. |
| `int` | `im2col_step` | The step for img2col. |
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
</dl>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear), tensor(int32, Linear)

View File

@ -69,6 +69,12 @@
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)
- [TRTDCNv3](#trtdcnv3)
- [Description](#description-11)
- [Parameters](#parameters-11)
- [Inputs](#inputs-11)
- [Outputs](#outputs-11)
- [Type Constraints](#type-constraints-11)
<!-- TOC -->
@ -489,3 +495,51 @@ None
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear)
### TRTDCNv3
#### Description
TensorRT deformable convolution v3 is used to support [InternImage](https://github.com/OpenGVLab/InternImage). The op
contains only im2col logic even though it is named convolution. For more detail, you may refer to [InternImage](https://github.com/OpenGVLab/InternImage/blob/4fb17721a0f9ab9fb28c7ed48ac1667a247c6da4/classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu#L71)
#### Parameters
| Type | Parameter | Description |
| ------- | ---------------- | --------------------------------- |
| `int` | `kernel_h` | The kernel size of h dim. |
| `int` | `kernel_w` | The kernel size of w dim. |
| `int` | `stride_h` | The stride size of h dim. |
| `int` | `stride_w` | The stride size of w dim. |
| `int` | `pad_h` | The padding size of h dim. |
| `int` | `pad_w` | The padding size of w dim. |
| `int` | `dilation_h` | The dilation size of h dim. |
| `int` | `dilation_w` | The dilation size of w dim. |
| `int` | `group` | The group nums. |
| `int` | `group_channels` | The number of channels per group. |
| `float` | `offset_scale` | The offset cale. |
| `int` | `im2col_step` | The step for img2col. |
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
</dl>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>A 4-D Tensor, with shape of [batch, height, width, channels].</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear), tensor(int32, Linear)

View File

@ -182,7 +182,9 @@ with torch.no_grad():
执行上述脚本我们导出成功了一个ONNX模型 `srcnn.onnx`。用[netron](https://netron.app/)打开这个模型可视化如下:
![](../../../resources/tutorial/srcnn.svg)
<div align="center">
<img src="https://user-images.githubusercontent.com/28671653/241883709-e21d60d0-1b1d-4665-af14-9c1240484773.png"/>
</div>
直接将该模型转换成TensorRT模型也是不可行的这是因为TensorRT还无法解析 `DynamicTRTResize` 节点。而想要解析该节点我们必须为TensorRT添加c++代码,实现该插件。
@ -274,7 +276,9 @@ class DynamicTRTResizeCreator : public TRTPluginCreatorBase {
在这样一份头文件中DynamicTRTResize类进行了如下的套娃继承
![](../../../resources/tutorial/IPluginV2DynamicExt.svg)
<div align="center">
<img src="https://user-images.githubusercontent.com/28671653/241883700-0bee87a0-6d6a-478b-8a71-983a4e47b670.png"/>
</div>
从上面的图片和代码中我们发现,插件类`DynamicTRTResize`中我们定义了私有变量`mAlignCorners`,该变量表示是否`align corners`。此外只要实现构造析构函数和TensoRT中三个基类的方法即可。其中构造函数有二分别用于创建插件和反序列化插件。而基类方法中

View File

@ -286,8 +286,13 @@ class Segmentation(BaseTask):
postprocess = self.model_cfg.model.decode_head
if isinstance(postprocess, list):
postprocess = postprocess[-1]
postprocess = postprocess.copy()
with_argmax = get_codebase_config(self.deploy_cfg).get(
'with_argmax', True)
# set with_argmax=True for this special case
if postprocess['num_classes'] == 2 and \
postprocess['out_channels'] == 1:
with_argmax = True
postprocess['with_argmax'] = with_argmax
return postprocess

View File

@ -25,10 +25,14 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
"""
seg_logit = self.encode_decode(img, img_meta)
seg_logit = F.softmax(seg_logit, dim=1)
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
return seg_logit
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
if self.out_channels == 1:
seg_logit = F.sigmoid(seg_logit)
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
else:
seg_pred = F.softmax(seg_logit, dim=1)
if get_codebase_config(ctx.cfg).get('with_argmax', True):
seg_pred = seg_pred.argmax(dim=1, keepdim=True)
return seg_pred
@ -51,5 +55,10 @@ def encoder_decoder__simple_test__rknn(ctx, self, img, img_meta, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, C, H, W].
"""
seg_logit = self.encode_decode(img, img_meta)
seg_logit = F.softmax(seg_logit, dim=1)
return seg_logit
if self.out_channels == 1:
seg_logit = F.sigmoid(seg_logit)
seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit)
else:
seg_pred = F.softmax(seg_logit, dim=1)
return seg_pred

View File

@ -4,5 +4,5 @@ mmdet>=2.19.0,<=2.20.0
mmedit<1.0.0
mmocr>=0.3.0,<=0.4.1
mmpose>=0.24.0,<=0.25.1
mmrazor>=0.3.0
mmrazor>=0.3.0,<=0.3.1
mmsegmentation<1.0.0

View File

@ -1,11 +1,11 @@
h5py
mmcls>=0.21.0,<=0.23.0
mmdet>=2.19.0,<=2.20.0
mmedit
mmedit<1.0.0
mmocr>=0.3.0,<=0.4.1
mmpose>=0.24.0,<=0.25.1
mmrazor>=0.3.0
mmsegmentation
mmrazor>=0.3.0,<=0.3.1
mmsegmentation<1.0.0
onnxruntime>=1.8.0
openvino-dev>=2022.3.0
tqdm

View File

@ -4,3 +4,4 @@ onnx>=1.8.0
opencv-python==4.5.4.60
sphinxcontrib-mermaid
torch
urllib3<2.0.0

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 11 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 15 KiB