mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Dbnet performance of trt8 (#278)
* compatible trt version for dbnet * judge inside rewrite
This commit is contained in:
parent
bd2867178d
commit
105acc9de9
@ -1,2 +1,2 @@
|
|||||||
[settings]
|
[settings]
|
||||||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision
|
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision,version
|
||||||
|
@ -673,7 +673,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc
|
|||||||
<td class="tg-0lax">model config file</td>
|
<td class="tg-0lax">model config file</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td class="tg-nrix" rowspan="3">DBNet</td>
|
<td class="tg-nrix" rowspan="3">DBNet*</td>
|
||||||
<td class="tg-nrix" rowspan="3">TextDetection</td>
|
<td class="tg-nrix" rowspan="3">TextDetection</td>
|
||||||
<td class="tg-baqh">recall</td>
|
<td class="tg-baqh">recall</td>
|
||||||
<td class="tg-baqh">0.7310</td>
|
<td class="tg-baqh">0.7310</td>
|
||||||
@ -823,6 +823,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](doc
|
|||||||
|
|
||||||
|
|
||||||
### Notes
|
### Notes
|
||||||
As some datasets contains images with various resolutions in codebase like MMDet. The speed benchmark is gained through static configs in MMDeploy, while the performance benchmark is gained through dynamic ones.
|
- As some datasets contain images with various resolutions in codebases like MMDet, the speed benchmark is obtained through static configs in MMDeploy, while the performance benchmark is obtained through dynamic ones.
|
||||||
|
|
||||||
Some int8 performance benchmarks of tensorrt require nvidia cards with tensor core, or the performance would drop heavily.
|
- Some int8 performance benchmarks of TensorRT require NVIDIA cards with tensor cores; otherwise the performance would drop heavily.
|
||||||
|
|
||||||
|
- DBNet uses the interpolate mode `nearest` in the neck of the model, for which TensorRT-7 applies a quite different strategy from PyTorch. To make the repository compatible with TensorRT-7, we rewrite the neck to use the interpolate mode `bilinear`, which improves final detection performance. To get performance that matches PyTorch, TensorRT-8+ is recommended, whose interpolate methods are all the same as PyTorch's.
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import version
|
||||||
|
|
||||||
from mmdeploy.core import FUNCTION_REWRITER
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
|||||||
"""Rewrite `forward` of FPNC for tensorrt backend.
|
"""Rewrite `forward` of FPNC for tensorrt backend.
|
||||||
|
|
||||||
Rewrite this function to replace nearest upsampling with bilinear
|
Rewrite this function to replace nearest upsampling with bilinear
|
||||||
upsampling. Tensorrt backend applies different nearest sampling strategy
|
upsampling. TensorRT-7 backend applies different nearest sampling strategy
|
||||||
from pytorch, which heavily influenced the final performance.
|
from pytorch, which heavily influenced the final performance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -24,6 +25,10 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
|||||||
outs (Tensor): A feature map output from FPNC. The tensor shape
|
outs (Tensor): A feature map output from FPNC. The tensor shape
|
||||||
(N, C, H, W).
|
(N, C, H, W).
|
||||||
"""
|
"""
|
||||||
|
# TensorRT version 8+ matches the upsampling with pytorch
|
||||||
|
import tensorrt as trt
|
||||||
|
apply_rewrite = version.parse(trt.__version__) < version.parse('8')
|
||||||
|
mode = 'bilinear' if apply_rewrite else 'nearest'
|
||||||
|
|
||||||
assert len(inputs) == len(self.in_channels)
|
assert len(inputs) == len(self.in_channels)
|
||||||
# build laterals
|
# build laterals
|
||||||
@ -36,7 +41,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
|||||||
for i in range(used_backbone_levels - 1, 0, -1):
|
for i in range(used_backbone_levels - 1, 0, -1):
|
||||||
prev_shape = laterals[i - 1].shape[2:]
|
prev_shape = laterals[i - 1].shape[2:]
|
||||||
laterals[i - 1] += F.interpolate(
|
laterals[i - 1] += F.interpolate(
|
||||||
laterals[i], size=prev_shape, mode='bilinear')
|
laterals[i], size=prev_shape, mode=mode)
|
||||||
# build outputs
|
# build outputs
|
||||||
# part 1: from original levels
|
# part 1: from original levels
|
||||||
outs = [
|
outs = [
|
||||||
@ -44,8 +49,7 @@ def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for i, out in enumerate(outs):
|
for i, out in enumerate(outs):
|
||||||
outs[i] = F.interpolate(
|
outs[i] = F.interpolate(outs[i], size=outs[0].shape[2:], mode=mode)
|
||||||
outs[i], size=outs[0].shape[2:], mode='bilinear')
|
|
||||||
out = torch.cat(outs, dim=1)
|
out = torch.cat(outs, dim=1)
|
||||||
|
|
||||||
if self.conv_after_concat:
|
if self.conv_after_concat:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user