From 89253699dae6d7d6750631c6657e9f73596c495a Mon Sep 17 00:00:00 2001 From: Derry Lin <37007486+DerryHub@users.noreply.github.com> Date: Mon, 14 Nov 2022 19:09:10 +0800 Subject: [PATCH] [Fix] Fix DCN TensorRT plugin (#2408) --- mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp | 2 +- mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp | 2 +- .../csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp index b9b2439ba..62ec3a127 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_deform_conv.cpp @@ -286,7 +286,7 @@ nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin( deformableGroup = static_cast(fc->fields[i].data)[0]; } - if (field_name.compare("group") == 0) { + if (field_name.compare("groups") == 0) { group = static_cast(fc->fields[i].data)[0]; } diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp index 30ca758b8..0ed490d6b 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv.cpp @@ -258,7 +258,7 @@ nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin( deformableGroup = static_cast(fc->fields[i].data)[0]; } - if (field_name.compare("group") == 0) { + if (field_name.compare("groups") == 0) { group = static_cast(fc->fields[i].data)[0]; } diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu index f29a7a79d..3c5b723a0 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_modulated_deform_conv_kernel.cu @@ -75,9 +75,9 @@ void ModulatedDeformConvForwardCUDAKernelLauncher( const size_t input_step = channels * height * width; const size_t offset_step = - deformable_group * kernel_h * kernel_w * 2 * height * width; + deformable_group * kernel_h * kernel_w * 2 * height_out * width_out; const size_t mask_step = - deformable_group * kernel_h * kernel_w * height * width; + deformable_group * kernel_h * kernel_w * height_out * width_out; const size_t out_step = channels_out * height_out * width_out; const size_t out_group_step = out_step / group; const size_t col_g_step =