[Feature] Add trt ort mdcn plugins (#43)

* add trt mdcn plugin

* add onnxruntime mdcn op

* add mdcn trt ort plugins

* fix lint

* remove comment

* remove plugin condition lines

* apply new form

* use serialized_size
AllentDan 2021-08-25 10:17:21 +08:00 committed by GitHub
parent e5dc959276
commit 8fe8056080
12 changed files with 1310 additions and 22 deletions

View File

@ -8,7 +8,8 @@ link_directories(${ONNXRUNTIME_DIR}/lib)
# add plugin source
set(PLUGIN_LISTS grid_sample
-roi_align)
+roi_align
+modulated_deform_conv)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
file(GLOB PLUGIN_OPS_SRCS ${PLUGIN_ITER}/*.cpp ${PLUGIN_ITER}/*.cu)

View File

@ -0,0 +1,296 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "modulated_deform_conv.h"
#include <cmath>
#include <vector>
#include "ort_utils.h"
namespace mmlab {
float bilinear_interpolate_2d(const float *src, const int64_t src_h,
const int64_t src_w, const float h,
const float w) {
if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) {
return 0;
}
int64_t h_low = floor(h);
int64_t w_low = floor(w);
int64_t h_high = h_low + 1;
int64_t w_high = w_low + 1;
float lh = h - h_low;
float lw = w - w_low;
float hh = 1 - lh;
float hw = 1 - lw;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low];
float v2 = 0;
if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high];
float v3 = 0;
if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low];
float v4 = 0;
if (h_high <= src_h - 1 && w_high <= src_w - 1)
v4 = src[h_high * src_w + w_high];
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
// output: (channels * kernel_h * kernel_w, dst_h * dst_w)
void deformable_im2col_2d(const float *input, const float *offset,
const float *mask, const int64_t src_h,
const int64_t src_w, const int64_t kernel_h,
const int64_t kernel_w, const int64_t pad_h,
const int64_t pad_w, const int64_t stride_h,
const int64_t stride_w, const int64_t dilation_h,
const int64_t dilation_w, const int64_t channels,
const int64_t offset_groups, const int64_t dst_h,
const int64_t dst_w, const bool use_mask,
float *columns) {
const int64_t workload = channels * dst_h * dst_w;
for (int64_t index = 0; index != workload; ++index) {
const int64_t ow = index % dst_w;
const int64_t oh = (index / dst_w) % dst_h;
const int64_t ic = index / (dst_w * dst_h);
const int64_t oc = ic * kernel_h * kernel_w;
int64_t c_per_offset_grp = channels / offset_groups;
const int64_t grp_idx = ic / c_per_offset_grp;
auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow);
auto input_ptr = input + ic * (src_h * src_w);
auto offset_ptr =
offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w;
}
for (int64_t kh = 0; kh < kernel_h; ++kh) {
for (int64_t kw = 0; kw < kernel_w; ++kw) {
const int64_t mask_idx = kh * kernel_w + kw;
const int64_t offset_idx = 2 * mask_idx;
float mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow];
}
const float offset_h =
offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow];
const float offset_w =
offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow];
const float ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h;
const float iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w;
*columns_ptr = mask_value *
bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw);
columns_ptr += dst_h * dst_w;
}
}
}
}
void gemm_ref_fp32(const float *A, const float *B, const float *V,
const float *H, const int32_t trans_A, const int32_t trans_B,
const int32_t M, const int32_t N, const int32_t K,
const float alpha, const float beta, float *Y) {
if (!trans_A && !trans_B) { // MK, KN; NN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && !trans_B) { // KM, KN; TN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && trans_B) { // KM, NK; TT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (!trans_A && trans_B) { // MK, NK; NT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
}
void deformable_conv2d_ref_fp32(
const float *src, const float *offset, const float *mask,
const float *filter, const float *bias, const int64_t batch,
const int64_t src_c, const int64_t src_h, const int64_t src_w,
const int64_t dst_c, const int64_t dst_h, const int64_t dst_w,
const int64_t group, const int64_t offset_group, const int64_t channels,
const int64_t num_output, const int64_t kernel_h, const int64_t kernel_w,
const int64_t stride_h, const int64_t stride_w, const int64_t pad_h,
const int64_t pad_w, const int64_t dilation_h, const int64_t dilation_w,
float *columns, float *dst) {
const int64_t ic_per_gp = channels / group;
const int64_t oc_per_gp = num_output / group;
for (int64_t b = 0; b < batch; ++b) {
for (int64_t g = 0; g < group; ++g) {
deformable_im2col_2d(
src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w, src_h,
src_w, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, ic_per_gp, offset_group, dst_h, dst_w,
mask != nullptr, columns);
float *dst_ptr =
dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
if (bias != nullptr) {
const float *bias_ptr = bias + g * oc_per_gp;
for (int64_t oc = 0; oc < oc_per_gp; ++oc) {
for (int64_t hw = 0; hw < dst_h * dst_w; ++hw) {
dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc];
}
}
} else {
memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
}
gemm_ref_fp32(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w,
columns, nullptr, dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w,
ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr);
}
}
}
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(
OrtApi api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
std::vector<int64_t> stride =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
stride_height_ = stride[0];
stride_width_ = stride[1];
std::vector<int64_t> padding =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
padding_height_ = padding[0];
padding_width_ = padding[1];
std::vector<int64_t> dilation =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
dilation_height_ = dilation[0];
dilation_width_ = dilation[1];
deformable_group_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
// create allocator
allocator_ = Ort::AllocatorWithDefaultOptions();
}
void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) {
const int64_t stride_height = stride_height_;
const int64_t stride_width = stride_width_;
const int64_t padding_height = padding_height_;
const int64_t padding_width = padding_width_;
const int64_t dilation_height = dilation_height_;
const int64_t dilation_width = dilation_width_;
const int64_t deformable_group = deformable_group_;
const int64_t group = group_;
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
const float *input_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
const OrtValue *offset = ort_.KernelContext_GetInput(context, 1);
const float *offset_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(offset));
const OrtValue *mask = ort_.KernelContext_GetInput(context, 2);
const float *mask_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(mask));
const OrtValue *filter = ort_.KernelContext_GetInput(context, 3);
const float *filter_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(filter));
const OrtValue *bias = ort_.KernelContext_GetInput(context, 4);
const float *bias_data =
(bias != nullptr)
? reinterpret_cast<const float *>(ort_.GetTensorData<float>(bias))
: nullptr;
// const float *bias_data = nullptr;
OrtTensorDimensions input_dims(ort_, input);
OrtTensorDimensions filter_dims(ort_, filter);
int64_t batch = input_dims[0];
int64_t channels = input_dims[1];
int64_t in_height = input_dims[2];
int64_t in_width = input_dims[3];
int64_t num_output = filter_dims[0];
int64_t kernel_height = filter_dims[2];
int64_t kernel_width = filter_dims[3];
// get output memory
int64_t out_height = floor((in_height + 2 * padding_height -
dilation_height * (kernel_height - 1) - 1) /
stride_height +
1);
int64_t out_width = floor(
(in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) /
stride_width +
1);
std::vector<int64_t> output_dims = {batch, num_output, out_height, out_width};
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, output_dims.data(), output_dims.size());
float *out_ptr = ort_.GetTensorMutableData<float>(output);
// allocate tmp memory
int64_t column_len = (channels / group) * kernel_height * kernel_width *
out_height * out_width;
float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len);
deformable_conv2d_ref_fp32(
input_data, offset_data, mask_data, filter_data, bias_data, batch,
channels, in_height, in_width, num_output, out_height, out_width, group,
deformable_group, channels, num_output, kernel_height, kernel_width,
stride_height, stride_width, padding_height, padding_width,
dilation_height, dilation_width, columns, out_ptr);
}
REGISTER_ONNXRUNTIME_OPS(MMCVModulatedDeformConvOp);
} // namespace mmlab
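
As a sanity check on the shape computation in Compute above, the output extent follows the standard convolution arithmetic

    out_height = floor((in_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) / stride_height) + 1

so a 64x64 input with a 3x3 kernel, stride 1, padding 1 and dilation 1 keeps its spatial size: floor((64 + 2 - 2 - 1) / 1) + 1 = 64. The temporary column buffer allocated afterwards then holds (channels / group) * kernel_height * kernel_width rows of out_height * out_width values, reused for every batch element and group.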

View File

@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ONNXRUNTIME_MODULATED_DEFORM_CONV_H
#define ONNXRUNTIME_MODULATED_DEFORM_CONV_H
#include <onnxruntime_cxx_api.h>
namespace mmlab {
struct MMCVModulatedDeformConvKernel {
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
int64_t stride_height_;
int64_t stride_width_;
int64_t padding_height_;
int64_t padding_width_;
int64_t dilation_height_;
int64_t dilation_width_;
int64_t deformable_group_;
int64_t group_;
};
struct MMCVModulatedDeformConvOp
: Ort::CustomOpBase<MMCVModulatedDeformConvOp,
MMCVModulatedDeformConvKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
return new MMCVModulatedDeformConvKernel(api, info);
}
const char *GetName() const { return "MMCVModulatedDeformConv2d"; };
size_t GetInputTypeCount() const { return 5; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(
size_t index) const {
// The last input (index == 4) is optional, which is bias
if (index == 4)
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
// force cpu
const char *GetExecutionProviderType() const {
return "CPUExecutionProvider";
};
};
} // namespace mmlab
#endif
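
A minimal sketch of how the resulting custom-op library might be loaded on the onnxruntime side. The library filename, model path and input tensor name below are placeholders that depend on the build and export setup; the op itself is pinned to the CPU execution provider above.

import numpy as np
import onnxruntime as ort

session_options = ort.SessionOptions()
# Register the shared library built from these sources (filename is an assumption).
session_options.register_custom_ops_library("libmmlab_onnxruntime_ops.so")

# The op forces CPUExecutionProvider, so a CPU session is enough to exercise it.
sess = ort.InferenceSession("end2end.onnx", session_options,
                            providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 64, 64).astype(np.float32)  # NCHW float32, as the kernel expects
outputs = sess.run(None, {"input": dummy})  # assumes the graph input is named "input"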

View File

@ -53,6 +53,7 @@ set(PLUGIN_LISTS scatternd
roi_align
batched_nms
instance_norm
modulated_deform_conv
multi_level_roi_align
grid_sampler)

View File

@ -1,6 +1,7 @@
#ifndef COMMON_CUDA_HELPER
#define COMMON_CUDA_HELPER
#include <cublas_v2.h>
#include <cuda.h>
#include <algorithm>
@ -39,11 +40,20 @@ inline int GET_BLOCKS(const int N) {
* @param[in] stream cuda stream handle
*/
template <class scalar_t>
-void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
-int *permute, int src_dim, cudaStream_t stream = 0);
+void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size,
+int* permute, int src_dim, cudaStream_t stream = 0);
template <typename scalar_t>
-__device__ scalar_t bilinear_interpolate(const scalar_t *input,
+cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
+cublasOperation_t transb, int m, int n, int k,
+const scalar_t* alpha, const scalar_t* A, int lda,
+const scalar_t* B, int ldb, const scalar_t* beta,
+scalar_t* C, int ldc) {
+  return CUBLAS_STATUS_INTERNAL_ERROR;
+}
+template <typename scalar_t>
+__device__ scalar_t bilinear_interpolate(const scalar_t* input,
const int height, const int width,
scalar_t y, scalar_t x) {
// deal with cases that inverse elements are out of feature map boundary

View File

@ -78,3 +78,25 @@ cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
}
return CUDNN_STATUS_SUCCESS;
}
template <>
cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const float *alpha, const float *A,
int lda, const float *B, int ldb,
const float *beta, float *C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
int k, const half *alpha, const half *A,
int lda, const half *B, int ldb,
const half *beta, half *C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb,
beta, C, ldc);
}

View File

@ -0,0 +1,288 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_modulated_deform_conv.hpp"
#include <assert.h>
#include <chrono>
#include "trt_serialize.hpp"
using namespace nvinfer1;
void ModulatedDeformConvForwardCUDAKernelLauncher_float(
const float *input, const float *weight, const float *bias,
const float *offset, const float *mask, float *output, void *workspace,
int batch, int channels, int height, int width, int channels_out,
int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream);
namespace mmlab {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"};
} // namespace
ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
const std::string &name, const nvinfer1::Dims stride,
const nvinfer1::Dims padding, const nvinfer1::Dims dilation,
const int deformableGroup, const int group)
: TRTPluginBase(name),
mStride(stride),
mPadding(padding),
mDilation(dilation),
mDeformableGroup(deformableGroup),
mGroup(group) {
mWithBias = false;
}
ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(
const std::string name, const void *data, size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &mStride);
deserialize_value(&data, &length, &mPadding);
deserialize_value(&data, &length, &mDilation);
deserialize_value(&data, &length, &mDeformableGroup);
deserialize_value(&data, &length, &mGroup);
mWithBias = false;
}
ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone()
const TRT_NOEXCEPT {
ModulatedDeformableConvPluginDynamic *plugin =
new ModulatedDeformableConvPluginDynamic(
mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[3].d[0];
ret.d[2] = inputs[1].d[2];
ret.d[3] = inputs[1].d[3];
return ret;
}
bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
if (pos == 0) {
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
}
}
void ModulatedDeformableConvPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {
if (nbInputs == 5) {
mWithBias = true;
}
}
size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
int sizeof_dtype = mmlab::getElementSize(outputs[0].type);
int batch_size = inputs[0].dims.d[0];
int nInputPlane = inputs[0].dims.d[1];
int inputHeight = inputs[0].dims.d[2];
int inputWidth = inputs[0].dims.d[3];
int nOutputPlane = outputs[0].dims.d[1];
int outputHeight = outputs[0].dims.d[2];
int outputWidth = outputs[0].dims.d[3];
int kW = inputs[3].dims.d[2];
int kH = inputs[3].dims.d[3];
int im2col_step = std::min(32, batch_size);
size_t col_size = mmlab::getAlignedSize(nInputPlane * kW * kH * outputHeight *
outputWidth * sizeof_dtype);
return col_size;
}
int ModulatedDeformableConvPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
int batch = inputDesc[0].dims.d[0];
int channels = inputDesc[0].dims.d[1];
int height = inputDesc[0].dims.d[2];
int width = inputDesc[0].dims.d[3];
int channels_out = outputDesc[0].dims.d[1];
int kernel_h = inputDesc[3].dims.d[2];
int kernel_w = inputDesc[3].dims.d[3];
const void *x = inputs[0];
const void *offset = inputs[1];
const void *mask = inputs[2];
const void *weight = inputs[3];
const void *bias = mWithBias ? inputs[4] : nullptr;
void *output = outputs[0];
int im2col_step = std::min(batch, 32);
// TODO: add fp16 support
auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
ModulatedDeformConvForwardCUDAKernelLauncher_float(
(float *)x, (float *)weight, (float *)bias, (float *)offset,
(float *)mask, (float *)output, workSpace, batch, channels, height,
width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1],
mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup,
mDeformableGroup, im2col_step, m_cublas_handle, stream);
break;
default:
return 1;
break;
}
return 0;
}
nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}
// IPluginV2 Methods
const char *ModulatedDeformableConvPluginDynamic::getPluginType() const
TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const
TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
int ModulatedDeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
return 1;
}
size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const
TRT_NOEXCEPT {
return serialized_size(mStride) + serialized_size(mPadding) +
serialized_size(mDilation) + serialized_size(mDeformableGroup) +
serialized_size(mGroup);
}
void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const
TRT_NOEXCEPT {
serialize_value(&buffer, mStride);
serialize_value(&buffer, mPadding);
serialize_value(&buffer, mDilation);
serialize_value(&buffer, mDeformableGroup);
serialize_value(&buffer, mGroup);
}
void ModulatedDeformableConvPluginDynamic::attachToContext(
cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
m_cublas_handle = cublasContext;
}
void ModulatedDeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {}
////////////////////// creator /////////////////////////////
ModulatedDeformableConvPluginDynamicCreator::
ModulatedDeformableConvPluginDynamicCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const
TRT_NOEXCEPT {
return PLUGIN_NAME;
}
const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion()
const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}
nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
nvinfer1::Dims stride{2, {1, 1}};
nvinfer1::Dims padding{2, {0, 0}};
nvinfer1::Dims dilation{2, {1, 1}};
int deformableGroup = 1;
int group = 1;
for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);
if (field_name.compare("deformable_group") == 0) {
deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("group") == 0) {
group = static_cast<const int *>(fc->fields[i].data)[0];
}
if (field_name.compare("stride") == 0) {
stride.nbDims = 2;
stride.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
stride.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
if (field_name.compare("padding") == 0) {
padding.nbDims = 2;
padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
if (field_name.compare("dilation") == 0) {
dilation.nbDims = 2;
dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
}
ModulatedDeformableConvPluginDynamic *plugin =
new ModulatedDeformableConvPluginDynamic(name, stride, padding, dilation,
deformableGroup, group);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *
ModulatedDeformableConvPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin =
new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator);
} // namespace mmlab
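
A minimal sketch of how the registered plugin might be picked up when building an engine with the TensorRT Python API. The plugin library and model filenames are placeholders; loading the shared object runs REGISTER_TENSORRT_PLUGIN, after which the ONNX parser can resolve MMCVModulatedDeformConv2d nodes.

import ctypes
import tensorrt as trt

# Loading the library triggers the static plugin registration above.
ctypes.CDLL("libmmlab_tensorrt_ops.so")

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("end2end.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))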

View File

@ -0,0 +1,95 @@
#ifndef TRT_MODULATED_DEFORM_CONV_HPP
#define TRT_MODULATED_DEFORM_CONV_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmlab {
class ModulatedDeformableConvPluginDynamic : public TRTPluginBase {
public:
ModulatedDeformableConvPluginDynamic(const std::string &name,
const nvinfer1::Dims stride,
const nvinfer1::Dims padding,
const nvinfer1::Dims dilation,
const int deformableGroup,
const int group);
ModulatedDeformableConvPluginDynamic(const std::string name, const void *data,
size_t length);
ModulatedDeformableConvPluginDynamic() = delete;
~ModulatedDeformableConvPluginDynamic() TRT_NOEXCEPT override;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator)
TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
private:
nvinfer1::Dims mStride;
nvinfer1::Dims mPadding;
nvinfer1::Dims mDilation;
int mDeformableGroup;
int mGroup;
bool mWithBias;
cublasHandle_t m_cublas_handle;
};
class ModulatedDeformableConvPluginDynamicCreator
: public TRTPluginCreatorBase {
public:
ModulatedDeformableConvPluginDynamicCreator();
const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name,
const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(
const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmlab
#endif // TRT_MODULATED_DEFORM_CONV_HPP

View File

@ -0,0 +1,132 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <assert.h>
#include <cuda_fp16.h>
#include "common_cuda_helper.hpp"
#include "trt_modulated_deform_conv_kernel.hpp"
#include "trt_plugin_helper.hpp"
template <typename T>
void trt_modulated_deformable_im2col(
const T* data_im_, const T* data_offset_, const T* data_mask_,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, T* data_col_,
cudaStream_t stream) {
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
modulated_deformable_im2col_gpu_kernel<T>
<<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, channel_per_deformable_group, batch_size, channels,
deformable_group, height_col, width_col, data_col_);
cudaCheckError();
}
template <typename scalar_t>
__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias,
size_t step_batch, size_t step_channel,
size_t n) {
CUDA_1D_KERNEL_LOOP(index, n) {
output[index] += bias[(index % step_batch) / step_channel];
}
}
template <typename scalar_t>
static void output_add_bias(scalar_t* output, const scalar_t* bias,
size_t batch, size_t channel, size_t height,
size_t width, cudaStream_t stream) {
size_t step_channel = height * width;
size_t step_batch = step_channel * channel;
size_t n = step_batch * batch;
output_add_bias_kernel<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(
output, bias, step_batch, step_channel, n);
}
template <typename scalar_t>
void ModulatedDeformConvForwardCUDAKernelLauncher(
const scalar_t* input, const scalar_t* weight, const scalar_t* bias,
const scalar_t* offset, const scalar_t* mask, scalar_t* output,
void* workspace, int batch, int channels, int height, int width,
int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h,
int pad_w, int pad_h, int dilation_w, int dilation_h, int group,
int deformable_group, int im2col_step, cublasHandle_t cublas_handle,
cudaStream_t stream) {
size_t sizeof_dtype = sizeof(scalar_t);
bool with_bias = (bias != nullptr);
im2col_step = std::min(int(batch), im2col_step);
assert(batch % im2col_step == 0);
const int channels_kernel = channels / group;
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
scalar_t* columns = (scalar_t*)workspace;
const size_t input_step = channels * height * width;
const size_t offset_step =
deformable_group * kernel_h * kernel_w * 2 * height * width;
const size_t mask_step =
deformable_group * kernel_h * kernel_w * height * width;
const size_t out_step = channels_out * height_out * width_out;
const size_t out_group_step = out_step / group;
const size_t col_g_step =
channels * kernel_w * kernel_h / group * height_out * width_out;
const size_t weight_g_step =
channels_out / group * channels / group * kernel_h * kernel_w;
const int m = channels_out / group;
const int n = height_out * width_out;
const int k = channels / group * kernel_h * kernel_w;
scalar_t alpha = 1.;
scalar_t beta = 0.;
for (int b = 0; b < batch; b++) {
const scalar_t* input_start = input + b * input_step;
const scalar_t* offset_start = offset + b * offset_step;
const scalar_t* mask_start = mask + b * mask_step;
trt_modulated_deformable_im2col<scalar_t>(
input_start, offset_start, mask_start, 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, columns, stream);
for (int g = 0; g < group; g++) {
const scalar_t* weight_start = weight + g * weight_g_step;
scalar_t* col_start = columns + g * col_g_step;
scalar_t* out_buffer_start = output + b * out_step + g * out_group_step;
cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
&alpha, col_start, n, weight_start, k, &beta,
out_buffer_start, n);
cudaCheckError();
}
}
if (with_bias) {
output_add_bias<scalar_t>(output, bias, batch, channels_out, height_out,
width_out, stream);
}
}
void ModulatedDeformConvForwardCUDAKernelLauncher_float(
const float* input, const float* weight, const float* bias,
const float* offset, const float* mask, float* output, void* workspace,
int batch, int channels, int height, int width, int channels_out,
int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
int pad_h, int dilation_w, int dilation_h, int group, int deformable_group,
int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
ModulatedDeformConvForwardCUDAKernelLauncher<float>(
input, weight, bias, offset, mask, output, workspace, batch, channels,
height, width, channels_out, kernel_w, kernel_h, stride_w, stride_h,
pad_w, pad_h, dilation_w, dilation_h, group, deformable_group,
im2col_step, cublas_handle, stream);
}
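
A note on the per-group GEMM above: cuBLAS assumes column-major storage, while the weight, column and output buffers here are row-major. With m = channels_out / group, n = height_out * width_out and k = channels / group * kernel_h * kernel_w, the desired row-major product is

    Out (m x n) = W (m x k) * Col (k x n)

which is obtained by asking cuBLAS for the transposed product with swapped operands, Out^T = Col^T * W^T. That is exactly the cublasGemmWrap call with dimensions (n, m, k) and col_start passed as the first matrix, so no explicit transposes or extra copies are needed.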

View File

@ -0,0 +1,392 @@
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#ifndef TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
#define TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
#include <float.h>
#include "common_cuda_helper.hpp"
template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
const int height, const int width, T h, T w) {
int h_low = floorf(h);
int w_low = floorf(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = input[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = input[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = input[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__device__ T dmcn_get_gradient_weight(T argmax_h, T argmax_w, const int h,
const int w, const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename T>
__device__ T dmcn_get_coordinate_weight(T argmax_h, T argmax_w,
const int height, const int width,
const T *im_data, const int data_width,
const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floorf(argmax_h);
int argmax_w_low = floorf(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename T>
__global__ void modulated_deformable_im2col_gpu_kernel(
const int n, const T *data_im, const T *data_offset, const T *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int num_channels, const int deformable_group, const int height_col,
const int width_col, T *data_col) {
CUDA_1D_KERNEL_LOOP(index, n) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T *data_offset_ptr =
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im,
w_im);
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename T>
__global__ void modulated_deformable_col2im_gpu_kernel(
const int n, const T *data_col, const T *data_offset, const T *data_mask,
const int channels, const int height, const int width, const int kernel_h,
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int deformable_group, const int height_col, const int width_col,
T *grad_im) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
const T cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight =
dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data,
cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename T>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(
const int n, const T *data_col, const T *data_im, const T *data_offset,
const T *data_mask, const int channels, const int height, const int width,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, T *grad_offset, T *grad_mask) {
CUDA_1D_KERNEL_LOOP(index, n) {
T val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T *data_col_ptr = data_col + deformable_group_index *
channel_per_deformable_group *
batch_size * width_col * height_col;
const T *data_im_ptr =
data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w *
height * width;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
inv_h = inv_w = -2;
else
mval += data_col_ptr[col_pos] *
dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width,
height, width, inv_h, inv_w);
const T weight = dmcn_get_coordinate_weight(
inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
// deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
// height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
offset_c / 2) *
height_col +
h) *
width_col +
w] = mval;
}
}
#endif // TRT_MODULATED_DEFORM_CONV_KERNEL_CUH

View File

@ -2,6 +2,6 @@ _base_ = ['./base.py', '../_base_/backends/tensorrt.py']
tensorrt_params = dict(model_params=[
dict(
opt_shape_dict=dict(
-input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
+input=[[1, 3, 320, 320], [1, 3, 1024, 1824], [1, 3, 1024, 1824]]),
max_workspace_size=1 << 30)
])
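
The three shapes above are the usual min/opt/max triple for a dynamic-shape TensorRT engine; a minimal sketch of how they might map onto an optimization profile with the TensorRT Python API (the input tensor name follows the config, and the network would come from an ONNX parser as in the earlier sketch):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # matches max_workspace_size above

profile = builder.create_optimization_profile()
profile.set_shape("input",
                  (1, 3, 320, 320),     # min
                  (1, 3, 1024, 1824),   # opt
                  (1, 3, 1024, 1824))   # max
config.add_optimization_profile(profile)
# engine = builder.build_engine(network, config)  # network from the parser sketch above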

View File

@ -1,4 +1,3 @@
-import warnings
from typing import Iterable, Union
import mmcv
@ -177,14 +176,8 @@ class TensorRTDetector(DeployBaseTextDetector):
device_id: int,
show_score: bool = False):
super(TensorRTDetector, self).__init__(cfg, device_id, show_score)
-from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
-try:
-    load_tensorrt_plugin()
-except (ImportError, ModuleNotFoundError):
-    warnings.warn('If input model has custom op from mmcv, \
-        you may have to build mmcv with TensorRT from source.')
-model = TRTWrapper(
-    trt_file, input_names=['input'], output_names=['output'])
+from mmdeploy.apis.tensorrt import TRTWrapper
+model = TRTWrapper(trt_file)
self.model = model
def forward_of_backend(self,
@ -205,14 +198,8 @@ class TensorRTRecognizer(DeployBaseRecognizer):
device_id: int,
show_score: bool = False):
super(TensorRTRecognizer, self).__init__(cfg, device_id, show_score)
-from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
-try:
-    load_tensorrt_plugin()
-except (ImportError, ModuleNotFoundError):
-    warnings.warn('If input model has custom op from mmcv, \
-        you may have to build mmcv with TensorRT from source.')
-model = TRTWrapper(
-    trt_file, input_names=['input'], output_names=['output'])
+from mmdeploy.apis.tensorrt import TRTWrapper
+model = TRTWrapper(trt_file)
self.model = model
def forward_of_backend(self,