mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: add modulated deformable conv TensorRT support (#1078)
* add modulated dcn, better dcn plugin
* clangformat
* update documentation
parent 1b59409e94
commit 004c00675f
@@ -51,6 +51,12 @@
    - [Inputs](#inputs-7)
    - [Outputs](#outputs-7)
    - [Type Constraints](#type-constraints-7)
  - [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d)
    - [Description](#description-8)
    - [Parameters](#parameters-8)
    - [Inputs](#inputs-8)
    - [Outputs](#outputs-8)
    - [Type Constraints](#type-constraints-8)

<!-- TOC -->

@@ -345,3 +351,45 @@ y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance a

### Type Constraints

- T:tensor(float32, Linear)

## MMCVModulatedDeformConv2d

### Description

Perform Modulated Deformable Convolution on the input feature map. Read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for more details.

### Parameters

| Type           | Parameter          | Description                                                                                |
| -------------- | ------------------ | ------------------------------------------------------------------------------------------ |
| `list of ints` | `stride`           | The stride of the convolving kernel, (sH, sW).                                              |
| `list of ints` | `padding`          | Paddings on both sides of the input, (padH, padW).                                          |
| `list of ints` | `dilation`         | The spacing between kernel elements, (dH, dW).                                              |
| `int`          | `deformable_group` | Groups of deformable offset.                                                                |
| `int`          | `group`            | Split the input into groups; `input_channel` should be divisible by the number of groups.   |

### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, and inH and inW are the height and width of the data.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>Input offset; 4-D tensor of shape (N, deformable_group * 2 * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>Input mask; 4-D tensor of shape (N, deformable_group * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the mask and output.</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
<dt><tt>inputs[4]</tt>: T, optional</dt>
<dd>Input bias; 1-D tensor of shape (output_channel).</dd>
</dl>

### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
</dl>

### Type Constraints

- T:tensor(float32, Linear)

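The offset and mask channel counts above, together with the usual convolution output-size formula used by the kernel later in this diff, fully determine the expected tensor shapes. A minimal sketch (not part of this commit); `mdcn_shapes` is a hypothetical helper name that merely restates those rules:

```python
# Hypothetical helper, not part of this PR: restates the documented shape
# rules for MMCVModulatedDeformConv2d.
def mdcn_shapes(n, c, in_h, in_w, out_c, k_h, k_w,
                stride=(1, 1), padding=(0, 0), dilation=(1, 1),
                deformable_group=1):
    """Return the expected offset, mask and output shapes."""
    out_h = (in_h + 2 * padding[0] - (dilation[0] * (k_h - 1) + 1)) // stride[0] + 1
    out_w = (in_w + 2 * padding[1] - (dilation[1] * (k_w - 1) + 1)) // stride[1] + 1
    offset_shape = (n, deformable_group * 2 * k_h * k_w, out_h, out_w)
    mask_shape = (n, deformable_group * k_h * k_w, out_h, out_w)
    output_shape = (n, out_c, out_h, out_w)
    return offset_shape, mask_shape, output_shape


# Example: a 1x1x3x3 input with a 2x2 kernel and padding 1 (the setting used
# in the unit test added by this PR) yields a 4x4 output.
print(mdcn_shapes(1, 1, 3, 3, out_c=1, k_h=2, k_w=2, padding=(1, 1)))
# ((1, 8, 4, 4), (1, 4, 4, 4), (1, 1, 4, 4))
```
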
@@ -31,9 +31,10 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | 1.3.5 |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | 1.3.5 |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | 1.3.5 |
| MMCVModulatedDeformConv2d | [MMCVModulatedDeformConv2d](./tensorrt_custom_ops.md#mmcvmodulateddeformconv2d) | master |

Notes

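Not part of the diff: a minimal deployment sketch using the `mmcv.tensorrt` helpers that the unit test later in this commit exercises (`onnx2trt`, `save_trt_engine`, `TRTWrapper`). The ONNX file path and the input shape are placeholders; mmcv is assumed to be built with the TensorRT plugins.

```python
# Sketch only: mirrors the flow used in tests/test_tensorrt.py of this PR.
import onnx
from mmcv.tensorrt import TRTWrapper, onnx2trt, save_trt_engine

onnx_model = onnx.load('model.onnx')  # placeholder path
# min / opt / max shapes for the dynamic input named 'input'
opt_shape_dict = {'input': [[1, 3, 224, 224]] * 3}

trt_engine = onnx2trt(
    onnx_model,
    opt_shape_dict,
    fp16_mode=False,
    max_workspace_size=1 << 30)
save_trt_engine(trt_engine, 'model.trt')

trt_model = TRTWrapper('model.trt', ['input'], ['output'])
# outputs = trt_model({'input': some_cuda_tensor})
```
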
@@ -66,11 +66,16 @@
#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH

#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else  // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#else  // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif
#endif  // MMCV_USE_PARROTS
#endif  // MMCV_WITH_TRT

template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,

@@ -1,3 +1,5 @@
#include <cublas_v2.h>

#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

@@ -64,3 +66,25 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
                                   int *permute, int src_dim,
                                   cudaStream_t stream);

template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
                                     cublasOperation_t transa,
                                     cublasOperation_t transb, int m, int n,
                                     int k, const float *alpha, const float *A,
                                     int lda, const float *B, int ldb,
                                     const float *beta, float *C, int ldc) {
  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
                     beta, C, ldc);
}

template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
                                    cublasOperation_t transa,
                                    cublasOperation_t transb, int m, int n,
                                    int k, const half *alpha, const half *A,
                                    int lda, const half *B, int ldb,
                                    const half *beta, half *C, int ldc) {
  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
                     beta, C, ldc);
}

@@ -32,9 +32,7 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(
      mDilation(dilation),
      mDeformableGroup(deformableGroup),
      mGroup(group),
      mIm2colStep(im2colStep) {
  cublasCreate(&m_cublas_handle);
}
      mIm2colStep(im2colStep) {}

DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
                                                         const void *data,

@@ -46,12 +44,8 @@ DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name,
  deserialize_value(&data, &length, &mDeformableGroup);
  deserialize_value(&data, &length, &mGroup);
  deserialize_value(&data, &length, &mIm2colStep);
  cublasCreate(&m_cublas_handle);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {
  // destroy cublas handle
  cublasDestroy(m_cublas_handle);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const {
  DeformableConvPluginDynamic *plugin =

@@ -127,11 +121,6 @@ int DeformableConvPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *inputDesc,
    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
    void *const *outputs, void *workSpace, cudaStream_t stream) {
  if (m_cuda_stream != stream) {
    cublasSetStream(m_cublas_handle, stream);
    m_cuda_stream = stream;
  }

  int batch_size = inputDesc[0].dims.d[0];
  int inputChannel = inputDesc[0].dims.d[1];
  int inputHeight = inputDesc[0].dims.d[2];

@@ -204,6 +193,14 @@ void DeformableConvPluginDynamic::destroy() {
  delete this;
}

void DeformableConvPluginDynamic::attachToContext(
    cudnnContext *cudnnContext, cublasContext *cublasContext,
    nvinfer1::IGpuAllocator *gpuAllocator) {
  m_cublas_handle = cublasContext;
}

void DeformableConvPluginDynamic::detachFromContext() {}

void DeformableConvPluginDynamic::setPluginNamespace(const char *libNamespace) {
  mNamespace = libNamespace;
}

@@ -1,4 +1,3 @@
#include <cublas_v2.h>
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"

@@ -32,38 +31,6 @@ void trt_deformable_im2col(const T* data_input, const T* data_offset,
  cudaCheckError();
}

// used to switch gemm between fp32 and fp16
template <typename scalar_t>
cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n, int k,
                              const scalar_t* alpha, const scalar_t* A, int lda,
                              const scalar_t* B, int ldb, const scalar_t* beta,
                              scalar_t* C, int ldc) {
  return CUBLAS_STATUS_INTERNAL_ERROR;
}

template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
                                     cublasOperation_t transa,
                                     cublasOperation_t transb, int m, int n,
                                     int k, const float* alpha, const float* A,
                                     int lda, const float* B, int ldb,
                                     const float* beta, float* C, int ldc) {
  cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
              ldc);
}

template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
                                    cublasOperation_t transa,
                                    cublasOperation_t transb, int m, int n,
                                    int k, const half* alpha, const half* A,
                                    int lda, const half* B, int ldb,
                                    const half* beta, half* C, int ldc) {
  cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C,
              ldc);
}

template <typename scalar_t>
void DeformConvForwardCUDAKernelLauncher(
    const scalar_t* input, const scalar_t* weight, const scalar_t* offset,

@@ -0,0 +1,307 @@
#include "trt_modulated_deform_conv.hpp"

#include <assert.h>

#include <chrono>

#include "trt_serialize.hpp"

void ModulatedDeformConvForwardCUDAKernelLauncher_float(
    const float *input, const float *weight, const float *bias,
    const float *offset, const float *mask, float *output, void *workspace,
    int batch, int channels, int height, int width, int channels_out,
    int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
    int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
    int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream);

namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"};
}  // namespace

nvinfer1::PluginFieldCollection
    ModulatedDeformableConvPluginDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
    ModulatedDeformableConvPluginDynamicCreator::mPluginAttributes;

ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
    const std::string &name, const nvinfer1::Dims stride,
    const nvinfer1::Dims padding, const nvinfer1::Dims dilation,
    const int deformableGroup, const int group)
    : mLayerName(name),
      mStride(stride),
      mPadding(padding),
      mDilation(dilation),
      mDeformableGroup(deformableGroup),
      mGroup(group) {
  mWithBias = false;
}

ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
    const std::string name, const void *data, size_t length)
    : mLayerName(name) {
  deserialize_value(&data, &length, &mStride);
  deserialize_value(&data, &length, &mPadding);
  deserialize_value(&data, &length, &mDilation);
  deserialize_value(&data, &length, &mDeformableGroup);
  deserialize_value(&data, &length, &mGroup);
  mWithBias = false;
}
ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone()
    const {
  ModulatedDeformableConvPluginDynamic *plugin =
      new ModulatedDeformableConvPluginDynamic(
          mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup);
  plugin->setPluginNamespace(getPluginNamespace());

  return plugin;
}

nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) {
  nvinfer1::DimsExprs ret;
  ret.nbDims = 4;
  ret.d[0] = inputs[0].d[0];
  ret.d[1] = inputs[3].d[0];

  ret.d[2] = inputs[1].d[2];
  ret.d[3] = inputs[1].d[3];

  return ret;
}

bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
    int nbOutputs) {
  if (pos == 0) {
    return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
            inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);

  } else {
    return inOut[pos].type == inOut[0].type &&
           inOut[pos].format == inOut[0].format;
  }
}

void ModulatedDeformableConvPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {
  if (nbInputs == 5) {
    mWithBias = true;
  }
}

size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
  int sizeof_dtype = mmcv::getElementSize(outputs[0].type);

  int batch_size = inputs[0].dims.d[0];
  int nInputPlane = inputs[0].dims.d[1];
  int inputHeight = inputs[0].dims.d[2];
  int inputWidth = inputs[0].dims.d[3];

  int nOutputPlane = outputs[0].dims.d[1];
  int outputHeight = outputs[0].dims.d[2];
  int outputWidth = outputs[0].dims.d[3];

  int kW = inputs[3].dims.d[2];
  int kH = inputs[3].dims.d[3];
  int im2col_step = std::min(32, batch_size);

  size_t col_size = mmcv::getAlignedSize(nInputPlane * kW * kH * outputHeight *
                                         outputWidth * sizeof_dtype);

  return col_size;
}

int ModulatedDeformableConvPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *inputDesc,
    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
    void *const *outputs, void *workSpace, cudaStream_t stream) {
  int batch = inputDesc[0].dims.d[0];
  int channels = inputDesc[0].dims.d[1];
  int height = inputDesc[0].dims.d[2];
  int width = inputDesc[0].dims.d[3];
  int channels_out = outputDesc[0].dims.d[1];
  int kernel_h = inputDesc[3].dims.d[2];
  int kernel_w = inputDesc[3].dims.d[3];

  const void *x = inputs[0];
  const void *offset = inputs[1];
  const void *mask = inputs[2];
  const void *weight = inputs[3];
  const void *bias = mWithBias ? inputs[4] : nullptr;
  void *output = outputs[0];
  int im2col_step = std::min(batch, 32);

  // TODO: add fp16 support
  auto data_type = inputDesc[0].type;
  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      ModulatedDeformConvForwardCUDAKernelLauncher_float(
          (float *)x, (float *)weight, (float *)bias, (float *)offset,
          (float *)mask, (float *)output, workSpace, batch, channels, height,
          width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1],
          mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup,
          mDeformableGroup, im2col_step, m_cublas_handle, stream);
      break;
    default:
      return 1;
      break;
  }

  return 0;
}

nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *ModulatedDeformableConvPluginDynamic::getPluginType() const {
  return PLUGIN_NAME;
}

const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const {
  return PLUGIN_VERSION;
}

int ModulatedDeformableConvPluginDynamic::getNbOutputs() const { return 1; }

int ModulatedDeformableConvPluginDynamic::initialize() { return 0; }

void ModulatedDeformableConvPluginDynamic::terminate() {}

size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const {
  return sizeof(mStride) + sizeof(mPadding) + sizeof(mDilation) +
         sizeof(mDeformableGroup) + sizeof(mGroup);
}

void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const {
  serialize_value(&buffer, mStride);
  serialize_value(&buffer, mPadding);
  serialize_value(&buffer, mDilation);
  serialize_value(&buffer, mDeformableGroup);
  serialize_value(&buffer, mGroup);
}

void ModulatedDeformableConvPluginDynamic::destroy() {
  // This gets called when the network containing plugin is destroyed
  delete this;
}

void ModulatedDeformableConvPluginDynamic::attachToContext(
    cudnnContext *cudnnContext, cublasContext *cublasContext,
    nvinfer1::IGpuAllocator *gpuAllocator) {
  m_cublas_handle = cublasContext;
}

void ModulatedDeformableConvPluginDynamic::detachFromContext() {}

void ModulatedDeformableConvPluginDynamic::setPluginNamespace(
    const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ModulatedDeformableConvPluginDynamic::getPluginNamespace() const {
  return mNamespace.c_str();
}

////////////////////// creator /////////////////////////////

ModulatedDeformableConvPluginDynamicCreator::
    ModulatedDeformableConvPluginDynamicCreator() {
  mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const {
  return PLUGIN_NAME;
}

const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion()
    const {
  return PLUGIN_VERSION;
}

const nvinfer1::PluginFieldCollection *
ModulatedDeformableConvPluginDynamicCreator::getFieldNames() {
  return &mFC;
}

nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) {
  nvinfer1::Dims stride{2, {1, 1}};
  nvinfer1::Dims padding{2, {0, 0}};
  nvinfer1::Dims dilation{2, {1, 1}};
  int deformableGroup = 1;
  int group = 1;

  for (int i = 0; i < fc->nbFields; i++) {
    if (fc->fields[i].data == nullptr) {
      continue;
    }
    std::string field_name(fc->fields[i].name);

    if (field_name.compare("deformable_group") == 0) {
      deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
    }

    if (field_name.compare("group") == 0) {
      group = static_cast<const int *>(fc->fields[i].data)[0];
    }

    if (field_name.compare("stride") == 0) {
      stride.nbDims = 2;
      stride.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
      stride.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
    }

    if (field_name.compare("padding") == 0) {
      padding.nbDims = 2;
      padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
      padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
    }

    if (field_name.compare("dilation") == 0) {
      dilation.nbDims = 2;
      dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
      dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
    }
  }

  ModulatedDeformableConvPluginDynamic *plugin =
      new ModulatedDeformableConvPluginDynamic(name, stride, padding, dilation,
                                               deformableGroup, group);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *
ModulatedDeformableConvPluginDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) {
  auto plugin =
      new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

void ModulatedDeformableConvPluginDynamicCreator::setPluginNamespace(
    const char *libNamespace) {
  mNamespace = libNamespace;
}

const char *ModulatedDeformableConvPluginDynamicCreator::getPluginNamespace()
    const {
  return mNamespace.c_str();
}

@@ -0,0 +1,133 @@
#include <assert.h>
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
#include "modulated_deform_conv_cuda_kernel.cuh"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

template <typename T>
void trt_modulated_deformable_im2col(
    const T* data_im_, const T* data_offset_, const T* data_mask_,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, T* data_col_,
    cudaStream_t stream) {
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  modulated_deformable_im2col_gpu_kernel<T>
      <<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
          num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
          kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
          dilation_w, channel_per_deformable_group, batch_size, channels,
          deformable_group, height_col, width_col, data_col_);

  cudaCheckError();
}

template <typename scalar_t>
__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias,
                                       size_t step_batch, size_t step_channel,
                                       size_t n) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    output[index] += bias[(index % step_batch) / step_channel];
  }
}

template <typename scalar_t>
static void output_add_bias(scalar_t* output, const scalar_t* bias,
                            size_t batch, size_t channel, size_t height,
                            size_t width, cudaStream_t stream) {
  size_t step_channel = height * width;
  size_t step_batch = step_channel * channel;
  size_t n = step_batch * batch;
  output_add_bias_kernel<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(
      output, bias, step_batch, step_channel, n);
}

template <typename scalar_t>
void ModulatedDeformConvForwardCUDAKernelLauncher(
    const scalar_t* input, const scalar_t* weight, const scalar_t* bias,
    const scalar_t* offset, const scalar_t* mask, scalar_t* output,
    void* workspace, int batch, int channels, int height, int width,
    int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h,
    int pad_w, int pad_h, int dilation_w, int dilation_h, int group,
    int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
    cudaStream_t stream) {
  size_t sizeof_dtype = sizeof(scalar_t);
  bool with_bias = (bias != nullptr);

  im2col_step = std::min(int(batch), im2col_step);
  assert(batch % im2col_step == 0);
  const int channels_kernel = channels / group;

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  scalar_t* columns = (scalar_t*)workspace;

  const size_t input_step = channels * height * width;
  const size_t offset_step =
      deformable_group * kernel_h * kernel_w * 2 * height * width;
  const size_t mask_step =
      deformable_group * kernel_h * kernel_w * height * width;
  const size_t out_step = channels_out * height_out * width_out;
  const size_t out_group_step = out_step / group;
  const size_t col_g_step =
      channels * kernel_w * kernel_h / group * height_out * width_out;
  const size_t weight_g_step =
      channels_out / group * channels / group * kernel_h * kernel_w;

  const int m = channels_out / group;
  const int n = height_out * width_out;
  const int k = channels / group * kernel_h * kernel_w;
  scalar_t alpha = 1.;
  scalar_t beta = 0.;

  for (int b = 0; b < batch; b++) {
    const scalar_t* input_start = input + b * input_step;
    const scalar_t* offset_start = offset + b * offset_step;
    const scalar_t* mask_start = mask + b * mask_step;
    trt_modulated_deformable_im2col<scalar_t>(
        input_start, offset_start, mask_start, 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, columns, stream);

    for (int g = 0; g < group; g++) {
      const scalar_t* weight_start = weight + g * weight_g_step;
      scalar_t* col_start = columns + g * col_g_step;
      scalar_t* out_buffer_start = output + b * out_step + g * out_group_step;

      // cudaMemsetAsync(out_buffer_start, 0, 1, stream);
      cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
                               &alpha, col_start, n, weight_start, k, &beta,
                               out_buffer_start, n);
      cudaCheckError();
    }
  }

  if (with_bias) {
    output_add_bias<scalar_t>(output, bias, batch, channels_out, height_out,
                              width_out, stream);
  }
}

void ModulatedDeformConvForwardCUDAKernelLauncher_float(
    const float* input, const float* weight, const float* bias,
    const float* offset, const float* mask, float* output, void* workspace,
    int batch, int channels, int height, int width, int channels_out,
    int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
    int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
    int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
  ModulatedDeformConvForwardCUDAKernelLauncher<float>(
      input, weight, bias, offset, mask, output, workspace, batch, channels,
      height, width, channels_out, kernel_w, kernel_h, stride_w, stride_h,
      pad_w, pad_h, dilation_w, dilation_h, group, deformable_group,
      im2col_step, cublas_handle, stream);
}

@@ -4,6 +4,7 @@
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
#include "trt_instance_norm.hpp"
#include "trt_modulated_deform_conv.hpp"
#include "trt_nms.hpp"
#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"

@@ -12,6 +13,7 @@ REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);

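As an aside (not part of this commit): once the compiled plugin library is loaded, the new creator should be visible in TensorRT's plugin registry. A hedged sketch, assuming `mmcv.tensorrt.load_tensorrt_plugin` is available to dlopen the plugin library built with these sources:

```python
# Sketch: verify that the MMCVModulatedDeformConv2d creator got registered.
import tensorrt as trt
from mmcv.tensorrt import load_tensorrt_plugin  # assumed helper in mmcv

load_tensorrt_plugin()  # loads libmmcv's TensorRT plugin shared library
creator_names = [c.name for c in trt.get_plugin_registry().plugin_creator_list]
assert 'MMCVModulatedDeformConv2d' in creator_names
```
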
@@ -1,5 +1,6 @@
#ifndef TRT_CUDA_HELPER_HPP
#define TRT_CUDA_HELPER_HPP
#include <cublas_v2.h>

#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

@@ -24,7 +25,16 @@
 * @param[in] stream cuda stream handle
 */
template <class scalar_t>
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
                   int *permute, int src_dim, cudaStream_t stream = 0);
void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size,
                   int* permute, int src_dim, cudaStream_t stream = 0);

template <typename scalar_t>
cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n, int k,
                              const scalar_t* alpha, const scalar_t* A, int lda,
                              const scalar_t* B, int ldb, const scalar_t* beta,
                              scalar_t* C, int ldc) {
  return CUBLAS_STATUS_INTERNAL_ERROR;
}

#endif  // TRT_CUDA_HELPER_HPP

@@ -44,6 +44,9 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs, void *const *outputs, void *workspace,
              cudaStream_t stream) override;
  void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
                       nvinfer1::IGpuAllocator *gpuAllocator) override;
  void detachFromContext() override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index,

@@ -74,7 +77,6 @@ class DeformableConvPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
  int mIm2colStep;

  cublasHandle_t m_cublas_handle;
  cudaStream_t m_cuda_stream;

 protected:
  // To prevent compiler warnings.

@@ -0,0 +1,120 @@
#ifndef TRT_MODULATED_DEFORM_CONV_HPP
#define TRT_MODULATED_DEFORM_CONV_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_helper.hpp"

class ModulatedDeformableConvPluginDynamic
    : public nvinfer1::IPluginV2DynamicExt {
 public:
  ModulatedDeformableConvPluginDynamic(const std::string &name,
                                       const nvinfer1::Dims stride,
                                       const nvinfer1::Dims padding,
                                       const nvinfer1::Dims dilation,
                                       const int deformableGroup,
                                       const int group);

  ModulatedDeformableConvPluginDynamic(const std::string name, const void *data,
                                       size_t length);

  ModulatedDeformableConvPluginDynamic() = delete;

  ~ModulatedDeformableConvPluginDynamic();

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const override;
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
      nvinfer1::IExprBuilder &exprBuilder) override;
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *inOut,
                                 int nbInputs, int nbOutputs) override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs, void *const *outputs, void *workspace,
              cudaStream_t stream) override;
  void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
                       nvinfer1::IGpuAllocator *gpuAllocator) override;
  void detachFromContext() override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const override;

  // IPluginV2 Methods
  const char *getPluginType() const override;
  const char *getPluginVersion() const override;
  int getNbOutputs() const override;
  int initialize() override;
  void terminate() override;
  size_t getSerializationSize() const override;
  void serialize(void *buffer) const override;
  void destroy() override;
  void setPluginNamespace(const char *pluginNamespace) override;
  const char *getPluginNamespace() const override;

 private:
  const std::string mLayerName;
  std::string mNamespace;

  nvinfer1::Dims mStride;
  nvinfer1::Dims mPadding;
  nvinfer1::Dims mDilation;
  int mDeformableGroup;
  int mGroup;
  bool mWithBias;

  cublasHandle_t m_cublas_handle;

 protected:
  // To prevent compiler warnings.
  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
  using nvinfer1::IPluginV2DynamicExt::enqueue;
  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};

class ModulatedDeformableConvPluginDynamicCreator
    : public nvinfer1::IPluginCreator {
 public:
  ModulatedDeformableConvPluginDynamicCreator();

  const char *getPluginName() const override;

  const char *getPluginVersion() const override;

  const nvinfer1::PluginFieldCollection *getFieldNames() override;

  nvinfer1::IPluginV2 *createPlugin(
      const char *name, const nvinfer1::PluginFieldCollection *fc) override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serialData,
                                         size_t serialLength) override;

  void setPluginNamespace(const char *pluginNamespace) override;

  const char *getPluginNamespace() const override;

 private:
  static nvinfer1::PluginFieldCollection mFC;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::string mNamespace;
};
#endif  // TRT_MODULATED_DEFORM_CONV_HPP

@@ -20,13 +20,12 @@ class ModulatedDeformConv2dFunction(Function):
    @staticmethod
    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
                 dilation, groups, deform_groups):
        input_tensors = [input, offset, mask, weight]
        if bias is not None:
            input_tensors.append(bias)
        return g.op(
            'MMCVModulatedDeformConv2d',
            input,
            offset,
            mask,
            weight,
            bias,
            'mmcv::MMCVModulatedDeformConv2d',
            *input_tensors,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,

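Not part of the diff: a quick way to observe the effect of the `symbolic` change above, where `bias` becomes an optional fifth input rather than a mandatory one (the plugin's `configurePlugin` checks `nbInputs == 5` accordingly). The file name and the assumption that the deform-conv node is the last graph node are illustrative only:

```python
# Sketch: export without bias and count the custom node's inputs.
import onnx
import torch
from mmcv.ops import ModulatedDeformConv2dPack

model = ModulatedDeformConv2dPack(
    1, 1, kernel_size=(2, 2), padding=1, deform_groups=1,
    bias=False).cuda().eval()
x = torch.rand(1, 1, 3, 3).cuda()
torch.onnx.export(model, (x, ), 'mdcn.onnx', opset_version=11)

node = onnx.load('mdcn.onnx').graph.node[-1]  # assumed to be the deform conv
print(node.op_type, len(node.input))  # MMCVModulatedDeformConv2d, 4 (no bias)
```
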
@@ -406,6 +406,77 @@ def test_deform_conv():
    assert torch.allclose(pytorch_results, trt_results)


@pytest.mark.parametrize('with_bias', [True, False])
def test_modulated_deform_conv(with_bias):
    try:
        from mmcv.ops import ModulatedDeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]

    x = torch.Tensor(input).cuda()
    model = ModulatedDeformConv2dPack(
        1,
        1,
        kernel_size=(2, 2),
        stride=1,
        padding=1,
        deform_groups=1,
        bias=with_bias)
    model.weight.data.fill_(1.)
    model.type(torch.float32)
    model = model.cuda().eval()

    input_names = ['input']
    output_names = ['output']

    with torch.no_grad():
        torch.onnx.export(
            model, (x.clone(), ),
            onnx_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            input_names=input_names,
            output_names=output_names,
            opset_version=11)

    onnx_model = onnx.load(onnx_file)

    # create trt engine and wrapper
    opt_shape_dict = {
        'input': [list(x.shape), list(x.shape),
                  list(x.shape)],
    }
    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)

    save_trt_engine(trt_engine, trt_file)
    trt_model = TRTWrapper(trt_file, input_names, output_names)

    with torch.no_grad():
        trt_outputs = trt_model({'input': x.clone()})
        trt_results = trt_outputs['output']

    # compute pytorch_output
    with torch.no_grad():
        pytorch_results = model(x.clone())

    # allclose
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    if os.path.exists(trt_file):
        os.remove(trt_file)
    torch.testing.assert_allclose(pytorch_results, trt_results)


@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])