[Feature] Support reppoints TensorRT (#457)

* Support reppoints tensorrt

* add ut and docs

* update zh_cn documents

* update document
q.yao 2022-05-18 11:54:45 +08:00 committed by GitHub
parent 0ce7c83c63
commit a4b7bced55
7 changed files with 319 additions and 2 deletions


@ -1033,8 +1033,24 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
<td align="center">40.0</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td>$MMDET_DIR/configs/gfl/gfl_r50_fpn_1x_coco.py</td>
</tr>
<tr>
<td align="center">RepPoints</td>
<td align="center">Object Detection</td>
<td align="center">COCO2017</td>
<td align="center">box AP</td>
<td align="center">37.0</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">36.9</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
<td>$MMDET_DIR/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py</td>
</tr>
<tr>
<td align="center" rowspan="2">Mask R-CNN</td>
<td align="center" rowspan="2">Instance Segmentation</td>
@ -1877,3 +1893,5 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
- Mask AP of Mask R-CNN drops by 1% on the backends. The main reason is that PyTorch interpolates the predicted masks directly to the original image, whereas the other backends first interpolate them to the model's preprocessed input image and then to the original image.
- MMPose models are tested with `flip_test` explicitly set to `False` in model configs.
- Some models might get low accuracy in fp16 mode. Please adjust the model to avoid value overflow.


@ -23,6 +23,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
| Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
| RepPoints | ObjectDetection | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
| Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |


@ -712,6 +712,19 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py">RepPoints</a></td>
<td align="center">Object Detection</td>
<td align="center">COCO2017</td>
<td align="center">box AP</td>
<td align="center">37.0</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">36.9</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py">Mask R-CNN</a></td>
<td align="center" rowspan="2">Instance Segmentation</td>
@ -1478,4 +1491,5 @@ GPU: ncnn, TensorRT, PPLNN
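- Because some datasets in codebases such as MMDet contain images of various resolutions, the speed benchmark is obtained with static configs in MMDeploy, while the performance benchmark is obtained with dynamic configs
- Some int8 performance benchmarks of TensorRT require Nvidia cards with tensor cores; otherwise performance drops sharply
- DBNet uses `nearest` interpolation in the model `neck`, for which TensorRT-7 applies a strategy completely different from Pytorch. To stay compatible with TensorRT-7, we rewrote the `neck` to use `bilinear` interpolation, which improves detection performance. To get performance that matches Pytorch, TensorRT-8+ is recommended, since its interpolation methods are the same as Pytorch's.
- For mmpose models, it is in the model config file that `flip_test` needs to be set to `False`
- For mmpose models, `flip_test` needs to be set to `False` in the model config file
- Some models may lose considerable accuracy in fp16 mode; please adjust the model as the situation requires.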


@ -18,6 +18,7 @@
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) |
| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) |
| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) |


@ -3,6 +3,7 @@ from .base_dense_head import (base_dense_head__get_bbox,
base_dense_head__get_bboxes__ncnn)
from .fovea_head import fovea_head__get_bboxes
from .gfl_head import gfl_head__get_bbox
from .reppoints_head import reppoints_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
from .ssd_head import ssd_head__get_bboxes__ncnn
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
@ -14,5 +15,5 @@ __all__ = [
    'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
    'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn',
    'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn',
    'gfl_head__get_bbox'
    'gfl_head__get_bbox', 'reppoints_head__get_bboxes'
]


@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
                                     multiclass_nms,
                                     pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


def _bbox_pre_decode(points: torch.Tensor, bbox_pred: torch.Tensor,
                     stride: torch.Tensor):
    """compute real bboxes."""
    points = points[..., :2]
    bbox_pos_center = torch.cat([points, points], dim=-1)
    bboxes = bbox_pred * stride + bbox_pos_center
    return bboxes


def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]):
    """clamp bbox."""
    x1 = bboxes[..., 0].clamp(min=0, max=max_shape[1])
    y1 = bboxes[..., 1].clamp(min=0, max=max_shape[0])
    x2 = bboxes[..., 2].clamp(min=0, max=max_shape[1])
    y2 = bboxes[..., 3].clamp(min=0, max=max_shape[0])
    decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
    return decoded_bboxes


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.RepPointsHead.points2bbox')
def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
    """Rewrite of `points2bbox` in `RepPointsHead`.

    Using `self.moment_transfer` in `points2bbox` causes an error:
    RuntimeError: Input, output and indices must be on the current device
    """
    # Temporarily swap the attribute for a plain tensor copy, call the
    # original function, then restore the original attribute.
    moment_transfer = self.moment_transfer
    delattr(self, 'moment_transfer')
    self.moment_transfer = torch.tensor(moment_transfer.data)
    ret = ctx.origin_func(self, pts, y_first=y_first)
    self.moment_transfer = moment_transfer
    return ret


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.RepPointsHead.get_bboxes')
def reppoints_head__get_bboxes(ctx,
                               self,
                               cls_scores,
                               bbox_preds,
                               score_factors=None,
                               img_metas=None,
                               cfg=None,
                               rescale=None,
                               **kwargs):
    """Rewrite `get_bboxes` of `RepPointsHead` for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        ctx (ContextCaller): The context with additional information.
        self (RepPointsHead): The instance of the class RepPointsHead.
        cls_scores (list[Tensor]): Box scores for each scale level
            with shape (N, num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level with shape (N, num_anchors * 4, H, W).
        score_factors (list[Tensor], Optional): Score factor for
            all scale level, each is a 4D-tensor, has shape
            (batch_size, num_priors * 1, H, W). Default None.
        img_metas (list[dict]): Meta information of the image, e.g.,
            image size, scaling factor, etc.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        rescale (bool): If True, return boxes in original image space.
            Default: False.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels),
            `dets` of shape [N, num_det, 5] and `labels` of shape
            [N, num_det].
    """
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    num_levels = len(cls_scores)

    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)
    mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors]

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    assert img_metas is not None
    img_shape = img_metas[0]['img_shape']

    assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
    batch_size = cls_scores[0].shape[0]
    # Fall back to the head's test_cfg only when no cfg is passed in,
    # as the docstring describes.
    cfg = self.test_cfg if cfg is None else cfg
    pre_topk = cfg.get('nms_pre', -1)

    mlvl_valid_bboxes = []
    mlvl_valid_scores = []
    for level_idx, (cls_score, bbox_pred, priors) in enumerate(
            zip(mlvl_cls_scores, mlvl_bbox_preds, mlvl_priors)):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

        scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                                       self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = scores.sigmoid()
        else:
            scores = scores.softmax(-1)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        if not is_dynamic_flag:
            priors = priors.data
        if pre_topk > 0:
            priors = pad_with_value_if_necessary(priors, 1, pre_topk)
            bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
            scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)

            nms_pre_score = scores
            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = nms_pre_score.max(-1)
            else:
                max_scores, _ = nms_pre_score[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size, device=bbox_pred.device).unsqueeze(-1)
            prior_inds = batch_inds.new_zeros((1, 1))
            priors = priors[prior_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]

        # Scale the predictions by the level stride and shift by the priors.
        bbox_pred = _bbox_pre_decode(priors, bbox_pred,
                                     self.point_strides[level_idx])
        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_valid_scores.append(scores)

    batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_scores = torch.cat(mlvl_valid_scores, dim=1)
    batch_bboxes = _bbox_post_decode(
        bboxes=batch_mlvl_bboxes_pred, max_shape=img_shape)

    if not self.use_sigmoid_cls:
        batch_scores = batch_scores[..., :self.num_classes]

    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    return multiclass_nms(
        batch_bboxes,
        batch_scores,
        max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
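
The decode path in `reppoints_head__get_bboxes` scales each predicted box by its level's stride, shifts it by the prior point, and then clamps the result to the image. The plain-Python sketch below walks through that arithmetic for a single box. `decode_box` and `clamp_box` are hypothetical names used only for illustration, and the example values are made up; the rewrite itself works on batched `torch.Tensor`s.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]


def decode_box(point: Sequence[float], pred: Sequence[float],
               stride: float) -> Box:
    """Scale (x1, y1, x2, y2) predictions by the stride, shift by the point.

    Only the first two entries of `point` are used, mirroring
    `points[..., :2]` in `_bbox_pre_decode`.
    """
    px, py = point[0], point[1]
    return (pred[0] * stride + px, pred[1] * stride + py,
            pred[2] * stride + px, pred[3] * stride + py)


def clamp_box(box: Box, max_shape: Sequence[int]) -> Box:
    """Clamp a box to the image; `max_shape` is (height, width, ...)."""
    height, width = float(max_shape[0]), float(max_shape[1])
    x1, y1, x2, y2 = box
    return (min(max(x1, 0.0), width), min(max(y1, 0.0), height),
            min(max(x2, 0.0), width), min(max(y2, 0.0), height))


# Illustrative values: a prior point at (64, 64) on a stride-8 level.
box = decode_box(point=(64.0, 64.0), pred=(-10.0, -2.0, 3.0, 10.0), stride=8)
print(box)                                       # decoded, may leave the image
print(clamp_box(box, max_shape=(128, 128, 3)))   # clamped to 128 x 128
```

Clamping happens once, after the boxes of all levels are concatenated, which is why the rewrite keeps it in a separate helper from the per-level decode.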


@ -148,6 +148,23 @@ def get_rpn_head_model():
    return model


def get_reppoints_head_model():
    """Reppoints Head Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    from mmdet.models.dense_heads import RepPointsHead
    model = RepPointsHead(num_classes=4, in_channels=1, test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


def get_single_roi_extractor():
    """SingleRoIExtractor Config."""
    from mmdet.models.roi_heads import SingleRoIExtractor
@ -1462,3 +1479,101 @@ def test_ssd_head_get_bboxes__ncnn(is_dynamic: bool):
rewrite_outputs = rewrite_outputs[0]
assert rewrite_outputs.shape[-1] == 6


@pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')])
def test_reppoints_head_get_bboxes(backend_type: Backend, ir_type: str):
    """Test get_bboxes rewrite of RepPointsHead."""
    check_backend(backend_type)
    dense_head = get_reppoints_head_model()
    dense_head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    deploy_cfg = get_deploy_cfg(backend_type, ir_type)
    output_names = get_ir_config(deploy_cfg).get('output_names', None)

    # the cls_score sizes: (1, 4, 32, 32), (1, 4, 16, 16),
    # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2).
    # the bbox_preds sizes: (1, 4, 32, 32), (1, 4, 16, 16),
    # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2)
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # to get outputs of pytorch model
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'img_metas': img_metas
    }
    model_outputs = get_model_outputs(dense_head, 'get_bboxes', model_inputs)

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        dense_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    if is_backend_output:
        if isinstance(rewrite_outputs, dict):
            rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
        for model_output, rewrite_output in zip(model_outputs[0],
                                                rewrite_outputs):
            model_output = model_output.squeeze().cpu().numpy()
            rewrite_output = rewrite_output.squeeze()
            # The rewritten and original code apply different NMS strategies,
            # so truncate the original output to the rewritten output's
            # length before comparing.
            assert np.allclose(
                model_output[:rewrite_output.shape[0]],
                rewrite_output,
                rtol=1e-03,
                atol=1e-05)
    else:
        assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')])
def test_reppoints_head_points2bbox(backend_type: Backend, ir_type: str):
    """Test points2bbox rewrite of RepPointsHead."""
    check_backend(backend_type)
    dense_head = get_reppoints_head_model()
    dense_head.cpu().eval()
    output_names = ['output']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(
                input_shape=None,
                input_names=['pts'],
                output_names=output_names)))

    # pts has shape (1, 18, 16, 16): two coordinates for each of the
    # 9 points that RepPointsHead uses by default, on a 16 x 16 feature map.
    seed_everything(1234)
    pts = torch.rand(1, 18, 16, 16)

    # to get outputs of onnx model after rewrite
    wrapped_model = WrapModel(dense_head, 'points2bbox', y_first=True)
    rewrite_inputs = {'pts': pts}
    _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
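
The comparison in `test_reppoints_head_get_bboxes` truncates the PyTorch output to the number of rows the rewritten model returns before checking closeness, because the two paths apply different NMS strategies. The plain-Python sketch below shows that check without NumPy. `rows_allclose` is a hypothetical helper, not part of the test suite, and it assumes the `numpy.allclose` criterion `|a - b| <= atol + rtol * |b|`; the detection rows are made-up values.

```python
from typing import Sequence


def rows_allclose(reference: Sequence[Sequence[float]],
                  rewritten: Sequence[Sequence[float]],
                  rtol: float = 1e-3,
                  atol: float = 1e-5) -> bool:
    """Compare `rewritten` against the leading rows of `reference`."""
    truncated = reference[:len(rewritten)]
    if len(truncated) != len(rewritten):
        # The reference has fewer rows than the rewritten output.
        return False
    return all(
        abs(a - b) <= atol + rtol * abs(b)
        for ref_row, new_row in zip(truncated, rewritten)
        for a, b in zip(ref_row, new_row))


# Each row is (x1, y1, x2, y2, score); the reference keeps one extra box.
reference = [[10.0, 20.0, 50.0, 60.0, 0.9],
             [5.0, 5.0, 30.0, 30.0, 0.8],
             [0.0, 0.0, 1.0, 1.0, 0.1]]
rewritten = [[10.001, 20.0, 50.0, 60.0, 0.9],
             [5.0, 5.0, 30.0, 30.0, 0.8]]
print(rows_allclose(reference, rewritten))
```

Truncating only works when both outputs are sorted the same way (here, by descending score), so the extra rows sit at the end of the longer output.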