mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] fix trt_multi_level_roi_align plugin of tensorrt in T4 platform (#170)
* fix trt_multi_level_roi_align for tensorrt in T4 platform * fix typo * resolve comments
This commit is contained in:
parent
4e168ee1c7
commit
12eec39340
@ -52,8 +52,7 @@ nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const
|
|||||||
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions(
|
||||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
// warning, nbInputs should equal to mFeatmapStrides.size() + 1
|
||||||
|
|
||||||
nvinfer1::DimsExprs ret;
|
nvinfer1::DimsExprs ret;
|
||||||
ret.nbDims = 4;
|
ret.nbDims = 4;
|
||||||
ret.d[0] = inputs[0].d[0];
|
ret.d[0] = inputs[0].d[0];
|
||||||
@ -77,7 +76,9 @@ void TRTMultiLevelRoiAlign::configurePlugin(
|
|||||||
int nbOutputs) TRT_NOEXCEPT {
|
int nbOutputs) TRT_NOEXCEPT {
|
||||||
// Validate input arguments
|
// Validate input arguments
|
||||||
ASSERT(nbOutputs == 1);
|
ASSERT(nbOutputs == 1);
|
||||||
ASSERT(nbInputs == mFeatmapStrides.size() + 1);
|
ASSERT(nbInputs >= 1);
|
||||||
|
mFeatmapStrides = std::vector<float>(mFeatmapStrides.begin(),
|
||||||
|
mFeatmapStrides.begin() + nbInputs - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
|
size_t TRTMultiLevelRoiAlign::getWorkspaceSize(
|
||||||
@ -203,11 +204,7 @@ nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin(
|
|||||||
} else if (field_name.compare("finest_scale") == 0) {
|
} else if (field_name.compare("finest_scale") == 0) {
|
||||||
finestScale = static_cast<const int *>(fc->fields[i].data)[0];
|
finestScale = static_cast<const int *>(fc->fields[i].data)[0];
|
||||||
} else if (field_name.compare("featmap_strides") == 0) {
|
} else if (field_name.compare("featmap_strides") == 0) {
|
||||||
#if NV_TENSORRT_MAJOR > 7
|
|
||||||
int data_size = (fc->fields[i].length);
|
int data_size = (fc->fields[i].length);
|
||||||
#else
|
|
||||||
int data_size = (fc->fields[i].length) / sizeof(float);
|
|
||||||
#endif
|
|
||||||
const float *data_start = static_cast<const float *>(fc->fields[i].data);
|
const float *data_start = static_cast<const float *>(fc->fields[i].data);
|
||||||
featmapStrides = std::vector<float>(data_start, data_start + data_size);
|
featmapStrides = std::vector<float>(data_start, data_start + data_size);
|
||||||
} else if (field_name.compare("aligned") == 0) {
|
} else if (field_name.compare("aligned") == 0) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user