From a4b7bced55f3b012d2dceb660a94e81507caf191 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 18 May 2022 11:54:45 +0800 Subject: [PATCH] [Feature] Support reppoints TensorRT (#457) * Support reppoints tensorrt * add ut and docs * update zh_cn documents * update document --- docs/en/benchmark.md | 18 ++ docs/en/codebases/mmdet.md | 1 + docs/zh_cn/03-benchmark/benchmark.md | 16 +- docs/zh_cn/03-benchmark/supported_models.md | 1 + .../mmdet/models/dense_heads/__init__.py | 3 +- .../models/dense_heads/reppoints_head.py | 167 ++++++++++++++++++ .../test_mmdet/test_mmdet_models.py | 115 ++++++++++++ 7 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py diff --git a/docs/en/benchmark.md b/docs/en/benchmark.md index 47264ba90..97b514ec7 100644 --- a/docs/en/benchmark.md +++ b/docs/en/benchmark.md @@ -1033,8 +1033,24 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut 40.0 - - + - $MMDET_DIR/configs/gfl/gfl_r50_fpn_1x_coco.py + + RepPoints + Object Detection + COCO2017 + box AP + 37.0 + - + - + 36.9 + - + - + - + - + $MMDET_DIR/configs/reppoints/reppoints_moment_r50_fpn_1x_coco.py + Mask R-CNN Instance Segmentation @@ -1877,3 +1893,5 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut - Mask AP of Mask R-CNN drops by 1% for the backend. The main reason is that the predicted masks are directly interpolated to original image in PyTorch, while they are at first interpolated to the preprocessed input image of the model and then to original image in other backends. - MMPose models are tested with `flip_test` explicitly set to `False` in model configs. + +- Some models might get low accuracy in fp16 mode. Please adjust the model to avoid value overflow. diff --git a/docs/en/codebases/mmdet.md b/docs/en/codebases/mmdet.md index 062fba55e..0808f8db9 100644 --- a/docs/en/codebases/mmdet.md +++ b/docs/en/codebases/mmdet.md @@ -23,6 +23,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/ | Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | | Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | | GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | +| RepPoints | ObjectDetection | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) | | Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md index 86d7d1796..dc354e9c9 100644 --- a/docs/zh_cn/03-benchmark/benchmark.md +++ b/docs/zh_cn/03-benchmark/benchmark.md @@ -712,6 +712,19 @@ GPU: ncnn, TensorRT, PPLNN - - + + RepPoints + Object Detection + COCO2017 + box AP + 37.0 + - + - + 36.9 + - + - + - + Mask R-CNN Instance Segmentation @@ -1478,4 +1491,5 @@ GPU: ncnn, TensorRT, PPLNN - 由于某些数据集在代码库中包含各种分辨率的图像,例如 MMDet,速度基准是通过 MMDeploy 中的静态配置获得的,而性能基准是通过动态配置获得的 - TensorRT 的一些 int8 性能基准测试需要有 tensor core 的 Nvidia 卡,否则性能会大幅下降 - DBNet 在模型 `neck` 使用了`nearest` 插值,TensorRT-7 用了与 Pytorch 完全不同的策略。为了使与 TensorRT-7 兼容,我们重写了`neck`以使用`bilinear`插值,这提高了检测性能。为了获得与 Pytorch 匹配的性能,推荐使用 TensorRT-8+,其插值方法与 Pytorch 相同。 -- 对于 mmpose 模型,是在模型配置文件中 `flip_test` 需设置为 `False` +- 对于 mmpose 模型,在模型配置文件中 `flip_test` 需设置为 `False` +- 部分模型在 fp16 模式下可能存在较大的精度损失,请根据具体情况对模型进行调整。 diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index c252a87c2..16761b169 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -18,6 +18,7 @@ | Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | +| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) | | ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) | | ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) | | SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) | diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 9043d4426..aabf0c45e 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -3,6 +3,7 @@ from .base_dense_head import (base_dense_head__get_bbox, base_dense_head__get_bboxes__ncnn) from .fovea_head import fovea_head__get_bboxes from .gfl_head import gfl_head__get_bbox +from .reppoints_head import reppoints_head__get_bboxes from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn from .ssd_head import ssd_head__get_bboxes__ncnn from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn @@ -14,5 +15,5 @@ __all__ = [ 'yolox_head__get_bboxes', 'base_dense_head__get_bbox', 'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn', 'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn', - 'gfl_head__get_bbox' + 'gfl_head__get_bbox', 'reppoints_head__get_bboxes' ] diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py new file mode 100644 index 000000000..a6a4c2b11 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch + +from mmdeploy.codebase.mmdet import (get_post_processing_params, + multiclass_nms, + pad_with_value_if_necessary) +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +def _bbox_pre_decode(points: torch.Tensor, bbox_pred: torch.Tensor, + stride: torch.Tensor): + """compute real bboxes.""" + points = points[..., :2] + bbox_pos_center = torch.cat([points, points], dim=-1) + bboxes = bbox_pred * stride + bbox_pos_center + return bboxes + + +def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]): + """clamp bbox.""" + x1 = bboxes[..., 0].clamp(min=0, max=max_shape[1]) + y1 = bboxes[..., 1].clamp(min=0, max=max_shape[0]) + x2 = bboxes[..., 2].clamp(min=0, max=max_shape[1]) + y2 = bboxes[..., 3].clamp(min=0, max=max_shape[0]) + decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1) + return decoded_bboxes + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.dense_heads.RepPointsHead.points2bbox') +def reppoints_head__points2bbox(ctx, self, pts, y_first=True): + """Rewrite of `points2bbox` in `RepPointsHead`. + + Use `self.moment_transfer` in `points2bbox` will cause error: + RuntimeError: Input, output and indices must be on the current device + """ + moment_transfer = self.moment_transfer + delattr(self, 'moment_transfer') + self.moment_transfer = torch.tensor(moment_transfer.data) + ret = ctx.origin_func(self, pts, y_first=y_first) + self.moment_transfer = moment_transfer + return ret + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.dense_heads.RepPointsHead.get_bboxes') +def reppoints_head__get_bboxes(ctx, + self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=None, + **kwargs): + """Rewrite `get_bboxes` of `RepPointsHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + self (RepPointsHead): The instance of the class RepPointsHead. + cls_scores (list[Tensor]): Box scores for each scale level + with shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict]): Meta information of the image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. Default: None. + rescale (bool): If True, return boxes in original image space. + Default: False. + + Returns: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + """ + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors] + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + assert img_metas is not None + img_shape = img_metas[0]['img_shape'] + + assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) + batch_size = cls_scores[0].shape[0] + cfg = self.test_cfg + pre_topk = cfg.get('nms_pre', -1) + + mlvl_valid_bboxes = [] + mlvl_valid_scores = [] + + for level_idx, (cls_score, bbox_pred, priors) in enumerate( + zip(mlvl_cls_scores, mlvl_bbox_preds, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, + self.cls_out_channels) + if self.use_sigmoid_cls: + scores = scores.sigmoid() + else: + scores = scores.softmax(-1) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) + if not is_dynamic_flag: + priors = priors.data + if pre_topk > 0: + priors = pad_with_value_if_necessary(priors, 1, pre_topk) + bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk) + scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.) + + nms_pre_score = scores + + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores, _ = nms_pre_score.max(-1) + else: + max_scores, _ = nms_pre_score[..., :-1].max(-1) + _, topk_inds = max_scores.topk(pre_topk) + batch_inds = torch.arange( + batch_size, device=bbox_pred.device).unsqueeze(-1) + prior_inds = batch_inds.new_zeros((1, 1)) + priors = priors[prior_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + + bbox_pred = _bbox_pre_decode(priors, bbox_pred, + self.point_strides[level_idx]) + mlvl_valid_bboxes.append(bbox_pred) + mlvl_valid_scores.append(scores) + + batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1) + batch_scores = torch.cat(mlvl_valid_scores, dim=1) + batch_bboxes = _bbox_post_decode( + bboxes=batch_mlvl_bboxes_pred, max_shape=img_shape) + + if not self.use_sigmoid_cls: + batch_scores = batch_scores[..., :self.num_classes] + + post_params = get_post_processing_params(deploy_cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + return multiclass_nms( + batch_bboxes, + batch_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index b3a1a9417..def6a6283 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -148,6 +148,23 @@ def get_rpn_head_model(): return model +def get_reppoints_head_model(): + """Reppoints Head Config.""" + test_cfg = mmcv.Config( + dict( + deploy_nms_pre=0, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) + + from mmdet.models.dense_heads import RepPointsHead + model = RepPointsHead(num_classes=4, in_channels=1, test_cfg=test_cfg) + + model.requires_grad_(False) + return model + + def get_single_roi_extractor(): """SingleRoIExtractor Config.""" from mmdet.models.roi_heads import SingleRoIExtractor @@ -1462,3 +1479,101 @@ def test_ssd_head_get_bboxes__ncnn(is_dynamic: bool): rewrite_outputs = rewrite_outputs[0] assert rewrite_outputs.shape[-1] == 6 + + +@pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')]) +def test_reppoints_head_get_bboxes(backend_type: Backend, ir_type: str): + """Test get_bboxes rewrite of base dense head.""" + check_backend(backend_type) + dense_head = get_reppoints_head_model() + dense_head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + deploy_cfg = get_deploy_cfg(backend_type, ir_type) + output_names = get_ir_config(deploy_cfg).get('output_names', None) + + # the cls_score's size: (1, 4, 32, 32), (1, 4, 16, 16), + # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2). + # the bboxes's size: (1, 4, 32, 32), (1, 4, 16, 16), + # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2) + seed_everything(1234) + cls_score = [ + torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(5678) + bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + + # to get outputs of pytorch model + model_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + 'img_metas': img_metas + } + model_outputs = get_model_outputs(dense_head, 'get_bboxes', model_inputs) + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.Tensor([s, s]) + wrapped_model = WrapModel( + dense_head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + if isinstance(rewrite_outputs, dict): + rewrite_outputs = convert_to_list(rewrite_outputs, output_names) + for model_output, rewrite_output in zip(model_outputs[0], + rewrite_outputs): + model_output = model_output.squeeze().cpu().numpy() + rewrite_output = rewrite_output.squeeze() + # hard code to make two tensors with the same shape + # rewrite and original codes applied different nms strategy + assert np.allclose( + model_output[:rewrite_output.shape[0]], + rewrite_output, + rtol=1e-03, + atol=1e-05) + else: + assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')]) +def test_reppoints_head_points2bbox(backend_type: Backend, ir_type: str): + """Test get_bboxes rewrite of base dense head.""" + check_backend(backend_type) + dense_head = get_reppoints_head_model() + dense_head.cpu().eval() + output_names = ['output'] + + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict( + input_shape=None, + input_names=['pts'], + output_names=output_names))) + + # the cls_score's size: (1, 4, 32, 32), (1, 4, 16, 16), + # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2). + # the bboxes's size: (1, 4, 32, 32), (1, 4, 16, 16), + # (1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2) + seed_everything(1234) + pts = torch.rand(1, 18, 16, 16) + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel(dense_head, 'points2bbox', y_first=True) + rewrite_inputs = {'pts': pts} + _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg)