From 12eec39340c7409f3f87b1e8865c3bb1e031eee5 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Thu, 4 Nov 2021 15:53:48 +0800 Subject: [PATCH] [Fix] fix trt_multi_level_roi_align plugin of tensorrt in T4 platform (#170) * fix trt_multi_level_roi_align for tensorrt in T4 platform * fix typo * resolve comments --- .../trt_multi_level_roi_align.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp b/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp index 84c07ab83..7a4b881c4 100644 --- a/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp +++ b/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp @@ -52,8 +52,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - ASSERT(nbInputs == mFeatmapStrides.size() + 1); - + // warning, nbInputs should equal to mFeatmapStrides.size() + 1 nvinfer1::DimsExprs ret; ret.nbDims = 4; ret.d[0] = inputs[0].d[0]; @@ -77,7 +76,9 @@ void TRTMultiLevelRoiAlign::configurePlugin( int nbOutputs) TRT_NOEXCEPT { // Validate input arguments ASSERT(nbOutputs == 1); - ASSERT(nbInputs == mFeatmapStrides.size() + 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = std::vector(mFeatmapStrides.begin(), + mFeatmapStrides.begin() + nbInputs - 1); } size_t TRTMultiLevelRoiAlign::getWorkspaceSize( @@ -203,11 +204,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( } else if (field_name.compare("finest_scale") == 0) { finestScale = static_cast(fc->fields[i].data)[0]; } else if (field_name.compare("featmap_strides") == 0) { -#if NV_TENSORRT_MAJOR > 7 int data_size = (fc->fields[i].length); -#else - int data_size = (fc->fields[i].length) / sizeof(float); -#endif const float *data_start = static_cast(fc->fields[i].data); featmapStrides = std::vector(data_start, data_start + data_size); } else if (field_name.compare("aligned") == 0) {