[Fix] fix trt_multi_level_roi_align plugin of tensorrt in T4 platform (#170)

* fix trt_multi_level_roi_align for tensorrt in T4 platform

* fix typo

* resolve comments
This commit is contained in:
AllentDan 2021-11-04 15:53:48 +08:00 committed by GitHub
parent 4e168ee1c7
commit 12eec39340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,8 +52,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
// warning, nbInputs should equal to mFeatmapStrides.size() + 1
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
@ -77,7 +76,9 @@ void TRTMultiLevelRoiAlign::configurePlugin(
int nbOutputs) TRT_NOEXCEPT {
// Validate input arguments
ASSERT(nbOutputs == 1);
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
ASSERT(nbInputs >= 1);
mFeatmapStrides = std::vector<float>(mFeatmapStrides.begin(),
mFeatmapStrides.begin() + nbInputs - 1);
}
size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
@ -203,11 +204,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
} else if (field_name.compare("finest_scale") == 0) {
finestScale = static_cast<const int *>(fc->fields[i].data)[0];
} else if (field_name.compare("featmap_strides") == 0) {
#if NV_TENSORRT_MAJOR > 7
int data_size = (fc->fields[i].length);
#else
int data_size = (fc->fields[i].length) / sizeof(float);
#endif
const float *data_start = static_cast<const float *>(fc->fields[i].data);
featmapStrides = std::vector<float>(data_start, data_start + data_size);
} else if (field_name.compare("aligned") == 0) {