mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Dbnet performance of trt8 (#278)
* compatible trt version for dbnet * judge inside rewrite
This commit is contained in:
parent
bd2867178d
commit
105acc9de9
@ -1,2 +1,2 @@
|
||||
[settings]
|
||||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision
|
||||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision,version
|
||||
|
@ -673,7 +673,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc
|
||||
<td class="tg-0lax">model config file</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="tg-nrix" rowspan="3">DBNet</td>
|
||||
<td class="tg-nrix" rowspan="3">DBNet*</td>
|
||||
<td class="tg-nrix" rowspan="3">TextDetection</td>
|
||||
<td class="tg-baqh">recall</td>
|
||||
<td class="tg-baqh">0.7310</td>
|
||||
@ -823,6 +823,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc
|
||||
|
||||
|
||||
### Notes
|
||||
As some datasets contains images with various resolutions in codebase like MMDet. The speed benchmark is gained through static configs in MMDeploy, while the performance benchmark is gained through dynamic ones.
|
||||
- As some datasets contains images with various resolutions in codebase like MMDet. The speed benchmark is gained through static configs in MMDeploy, while the performance benchmark is gained through dynamic ones.
|
||||
|
||||
Some int8 performance benchmarks of tensorrt require nvidia cards with tensor core, or the performance would drop heavily.
|
||||
- Some int8 performance benchmarks of TensorRT require nvidia cards with tensor core, or the performance would drop heavily.
|
||||
|
||||
- DBNet uses the interpolate mode `nearest` in the neck of the model, which TensorRT-7 applies quite different strategy from pytorch. To make the repository compatible with TensorRT-7, we rewrite the neck to use the interpolate mode `bilinear` which improves final detection performance. To get the matched performance with Pytorch, TensorRT-8+ is recommended, which the interpolate methods are all the same as Pytorch.
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import version
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@ -11,7 +12,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
||||
"""Rewrite `forward` of FPNC for tensorrt backend.
|
||||
|
||||
Rewrite this function to replace nearest upsampling with bilinear
|
||||
upsampling. Tensorrt backend applies different nearest sampling strategy
|
||||
upsampling. TensorRT-7 backend applies different nearest sampling strategy
|
||||
from pytorch, which heavily influenced the final performance.
|
||||
|
||||
Args:
|
||||
@ -24,6 +25,10 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
||||
outs (Tensor): A feature map output from FPNC. The tensor shape
|
||||
(N, C, H, W).
|
||||
"""
|
||||
# TensorRT version 8+ matches the upsampling with pytorch
|
||||
import tensorrt as trt
|
||||
apply_rewrite = version.parse(trt.__version__) < version.parse('8')
|
||||
mode = 'bilinear' if apply_rewrite else 'nearest'
|
||||
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
# build laterals
|
||||
@ -36,7 +41,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] += F.interpolate(
|
||||
laterals[i], size=prev_shape, mode='bilinear')
|
||||
laterals[i], size=prev_shape, mode=mode)
|
||||
# build outputs
|
||||
# part 1: from original levels
|
||||
outs = [
|
||||
@ -44,8 +49,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
||||
]
|
||||
|
||||
for i, out in enumerate(outs):
|
||||
outs[i] = F.interpolate(
|
||||
outs[i], size=outs[0].shape[2:], mode='bilinear')
|
||||
outs[i] = F.interpolate(outs[i], size=outs[0].shape[2:], mode=mode)
|
||||
out = torch.cat(outs, dim=1)
|
||||
|
||||
if self.conv_after_concat:
|
||||
|
Loading…
x
Reference in New Issue
Block a user