mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] fix trt_multi_level_roi_align plugin of tensorrt in T4 platform (#170)
* fix trt_multi_level_roi_align for tensorrt in T4 platform * fix typo * resolve comments
This commit is contained in:
parent
4e168ee1c7
commit
12eec39340
@ -52,8 +52,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const
|
||||
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
||||
|
||||
// warning, nbInputs should equal to mFeatmapStrides.size() + 1
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 4;
|
||||
ret.d[0] = inputs[0].d[0];
|
||||
@ -77,7 +76,9 @@ void TRTMultiLevelRoiAlign::configurePlugin(
|
||||
int nbOutputs) TRT_NOEXCEPT {
|
||||
// Validate input arguments
|
||||
ASSERT(nbOutputs == 1);
|
||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
||||
ASSERT(nbInputs >= 1);
|
||||
mFeatmapStrides = std::vector<float>(mFeatmapStrides.begin(),
|
||||
mFeatmapStrides.begin() + nbInputs - 1);
|
||||
}
|
||||
|
||||
size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
|
||||
@ -203,11 +204,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
|
||||
} else if (field_name.compare("finest_scale") == 0) {
|
||||
finestScale = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
} else if (field_name.compare("featmap_strides") == 0) {
|
||||
#if NV_TENSORRT_MAJOR > 7
|
||||
int data_size = (fc->fields[i].length);
|
||||
#else
|
||||
int data_size = (fc->fields[i].length) / sizeof(float);
|
||||
#endif
|
||||
const float *data_start = static_cast<const float *>(fc->fields[i].data);
|
||||
featmapStrides = std::vector<float>(data_start, data_start + data_size);
|
||||
} else if (field_name.compare("aligned") == 0) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user