From 6e91614171cd6fd6119aa6a1f5f4c697116f75b7 Mon Sep 17 00:00:00 2001 From: DDGRCF Date: Thu, 22 Dec 2022 00:02:25 +0800 Subject: [PATCH] CodeCamp #102: support SOLO deployment with OpenVINO (#1454) * feat: support solo mmdet3.x openvino 2.0 api * feat: support solo mmdet3.x openvino * fix: lint * fix: add solo head test * docs: add supported_modesl * docs: add supported_models * fix: fix unreasonable code * fix: fix ci failed * feat: add linspace func rewrite * fix: fix unreasonable rewrite linspace__onnx * fix: change func name from __onnx to __default * feat: add solo test regression --- docs/en/03-benchmark/benchmark.md | 35 ++++++ docs/en/03-benchmark/supported_models.md | 1 + docs/en/04-supported-codebases/mmdet.md | 1 + docs/zh_cn/03-benchmark/benchmark.md | 35 ++++++ docs/zh_cn/03-benchmark/supported_models.md | 1 + docs/zh_cn/04-supported-codebases/mmdet.md | 43 ++++---- mmdeploy/codebase/mmdet/models/__init__.py | 1 + .../mmdet/models/dense_heads/__init__.py | 1 + .../mmdet/models/dense_heads/solo_head.py | 102 +++++++++++++++++ .../mmdet/models/detectors/__init__.py | 4 +- .../detectors/single_stage_instance_seg.py | 67 ++++++++++++ .../models/{layers.py => layers/__init__.py} | 3 + .../mmdet/models/layers/matrix_nms.py | 103 ++++++++++++++++++ mmdeploy/pytorch/functions/__init__.py | 1 + mmdeploy/pytorch/functions/linspace.py | 20 ++++ mmdeploy/pytorch/functions/triu.py | 1 + tests/regression/mmdet.yml | 7 ++ .../test_mmcls/test_mmcls_models.py | 8 +- .../test_mmdet/test_mmdet_models.py | 93 ++++++++++++++++ tests/test_pytorch/test_pytorch_functions.py | 35 ++++++ 20 files changed, 533 insertions(+), 29 deletions(-) create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/solo_head.py create mode 100644 mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py rename mmdeploy/codebase/mmdet/models/{layers.py => layers/__init__.py} (64%) create mode 100644 mmdeploy/codebase/mmdet/models/layers/matrix_nms.py create mode 100644 mmdeploy/pytorch/functions/linspace.py diff --git a/docs/en/03-benchmark/benchmark.md b/docs/en/03-benchmark/benchmark.md index 46d51cfde..ed7604f72 100644 --- a/docs/en/03-benchmark/benchmark.md +++ b/docs/en/03-benchmark/benchmark.md @@ -638,6 +638,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ TensorRT PPLNN Ascend + OpenVINO @@ -654,6 +655,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ int8 fp16 fp32 + fp32 YOLOV3 @@ -668,6 +670,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 33.5 - - + - SSD @@ -682,6 +685,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - RetinaNet @@ -696,6 +700,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 36.3 36.5 36.4 + - FCOS @@ -710,6 +715,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - FSAF @@ -724,6 +730,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 37.2 37.4 - + - CenterNet @@ -738,6 +745,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - YOLOX @@ -752,6 +760,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 29.3 - - + - Faster R-CNN @@ -766,6 +775,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 37.1 37.3 37.2 + - ATSS @@ -780,6 +790,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - Cascade R-CNN @@ -794,6 +805,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - 40.4 - + - GFL @@ -808,6 +820,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - RepPoints @@ -822,6 +835,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - DETR @@ -835,6 +849,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ 40.1 - - + - + - Mask R-CNN @@ -849,6 +865,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - 38.0 - + - mask AP @@ -860,6 +877,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - Swin-Transformer @@ -874,6 +892,7 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - mask AP @@ -885,6 +904,22 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ - - - + - + + + SOLO + Instance Segmentation + COCO2017 + mask AP + 33.1 + - + - + - + - + - + - + - + 32.7 diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index 140da9a38..ec459dc3e 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -22,6 +22,7 @@ The table below lists the models that are guaranteed to be exportable to other b | [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N | | [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N | | [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N | +| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | MMDetection | N | N | N | N | N | Y | N | N | | [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | | [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | | [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index c2356ef89..d8525af1d 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -210,3 +210,4 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter | [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | N | N | N | Y | | [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y | | [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | N | +| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | N | N | N | N | Y | diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md index 7afb43359..04d3e5e06 100644 --- a/docs/zh_cn/03-benchmark/benchmark.md +++ b/docs/zh_cn/03-benchmark/benchmark.md @@ -633,6 +633,7 @@ GPU: ncnn, TensorRT, PPLNN TensorRT PPLNN Ascend + OpenVINO @@ -649,6 +650,7 @@ GPU: ncnn, TensorRT, PPLNN int8 fp16 fp32 + fp32 YOLOV3 @@ -663,6 +665,7 @@ GPU: ncnn, TensorRT, PPLNN 33.5 - - + - SSD @@ -677,6 +680,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - RetinaNet @@ -691,6 +695,7 @@ GPU: ncnn, TensorRT, PPLNN 36.3 36.5 36.4 + - FCOS @@ -705,6 +710,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - FSAF @@ -719,6 +725,7 @@ GPU: ncnn, TensorRT, PPLNN 37.2 37.4 - + - CenterNet @@ -733,6 +740,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - YOLOX @@ -747,6 +755,7 @@ GPU: ncnn, TensorRT, PPLNN 29.3 - - + - Faster R-CNN @@ -761,6 +770,7 @@ GPU: ncnn, TensorRT, PPLNN 37.1 37.3 37.2 + - ATSS @@ -775,6 +785,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - Cascade R-CNN @@ -789,6 +800,7 @@ GPU: ncnn, TensorRT, PPLNN - 40.4 - + - GFL @@ -803,6 +815,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - RepPoints @@ -817,6 +830,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - DETR @@ -830,6 +844,8 @@ GPU: ncnn, TensorRT, PPLNN 40.1 - - + - + - Mask R-CNN @@ -844,6 +860,7 @@ GPU: ncnn, TensorRT, PPLNN - 38.0 - + - mask AP @@ -855,6 +872,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - Swin-Transformer @@ -869,6 +887,7 @@ GPU: ncnn, TensorRT, PPLNN - - - + - mask AP @@ -880,6 +899,22 @@ GPU: ncnn, TensorRT, PPLNN - - - + - + + + SOLO + Instance Segmentation + COCO2017 + mask AP + 33.1 + - + - + - + - + - + - + - + 32.7 diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index bcbfaf2f6..a58071df5 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -22,6 +22,7 @@ | [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N | | [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N | | [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N | +| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | MMDetection | N | N | N | N | N | Y | N | N | | [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | | [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | | [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y | diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index 600e0f274..5886a420f 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -192,24 +192,25 @@ cv2.imwrite('output_detection.png', img) ## 模型支持列表 -| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | -| :-------------------------------------------------------------------------------------------: | :------------------: | :---------: | :------: | :--: | :---: | :------: | -| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | ObjectDetection | Y | Y | N | N | Y | -| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | ObjectDetection | Y | Y | Y | N | Y | -| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | ObjectDetection | Y | N | N | N | Y | -| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | ObjectDetection | Y | Y | Y | Y | Y | -| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | ObjectDetection | Y | Y | Y | Y | Y | -| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | ObjectDetection | Y | Y | Y | N | Y | -| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | ObjectDetection | N | N | N | N | Y | -| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | ObjectDetection | Y | Y | Y | N | Y | -| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | ObjectDetection | Y | Y | Y | N | Y | -| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | ObjectDetection | Y | Y | N | Y | Y | -| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | ObjectDetection | Y | Y | Y | Y | Y | -| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | ObjectDetection | Y | Y | Y | Y | Y | -| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | ObjectDetection | Y | Y | N | ? | Y | -| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | ObjectDetection | N | Y | N | ? | Y | -| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | ObjectDetection | Y | Y | N | ? | Y | -| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? | -| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | InstanceSegmentation | Y | N | N | N | Y | -| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | InstanceSegmentation | Y | Y | N | N | Y | -| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | InstanceSegmentation | Y | Y | N | N | N | +| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | +| :-------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: | +| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | ObjectDetection | Y | Y | N | N | Y | +| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | ObjectDetection | Y | Y | Y | N | Y | +| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | ObjectDetection | Y | N | N | N | Y | +| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | ObjectDetection | Y | Y | Y | Y | Y | +| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | ObjectDetection | Y | Y | Y | Y | Y | +| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | ObjectDetection | Y | Y | Y | N | Y | +| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | ObjectDetection | N | N | N | N | Y | +| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | ObjectDetection | Y | Y | Y | N | Y | +| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | ObjectDetection | Y | Y | Y | N | Y | +| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | ObjectDetection | Y | Y | N | Y | Y | +| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | ObjectDetection | Y | Y | Y | Y | Y | +| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | ObjectDetection | Y | Y | Y | Y | Y | +| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | ObjectDetection | Y | Y | N | ? | Y | +| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | ObjectDetection | N | Y | N | ? | Y | +| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | ObjectDetection | Y | Y | N | ? | Y | +| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? | +| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | InstanceSegmentation | Y | N | N | N | Y | +| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | InstanceSegmentation | Y | Y | N | N | Y | +| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | InstanceSegmentation | Y | Y | N | N | N | +| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | N | N | N | N | Y | diff --git a/mmdeploy/codebase/mmdet/models/__init__.py b/mmdeploy/codebase/mmdet/models/__init__.py index 53291d6bf..38b7e336d 100644 --- a/mmdeploy/codebase/mmdet/models/__init__.py +++ b/mmdeploy/codebase/mmdet/models/__init__.py @@ -2,6 +2,7 @@ from . import backbones # noqa: F401, F403 from . import dense_heads # noqa: F401,F403 from . import detectors # noqa: F401,F403 +from . import layers # noqa: F401,F403 from . import necks # noqa: F401,F403 from . import roi_heads # noqa: F401,F403 from . import task_modules # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 78178fa89..ecbb3bc8b 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -7,5 +7,6 @@ from . import gfl_head # noqa: F401,F403 from . import reppoints_head # noqa: F401,F403 from . import rpn_head # noqa: F401,F403 from . import rtmdet_head # noqa: F401,F403 +from . import solo_head # noqa: F401,F403 from . import yolo_head # noqa: F401,F403 from . import yolox_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/solo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/solo_head.py new file mode 100644 index 000000000..fb937b089 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/solo_head.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +from mmdet.models.layers import mask_matrix_nms +from mmdet.utils import OptConfigType +from torch import Tensor +from torch.nn import functional as F + +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.dense_heads.SOLOHead.predict_by_feat', backend='openvino') +def solohead__predict_by_feat__openvino(self, + mlvl_mask_preds: List[Tensor], + mlvl_cls_scores: List[Tensor], + batch_img_metas: List[Dict], + cfg: OptConfigType = None, + **kwargs): + """Rewrite `predict_by_feat` of `SOLOHead` for openvino backend.""" + + ctx = FUNCTION_REWRITER.get_context() + batch_size = mlvl_cls_scores[0].size(0) + cfg = self.test_cfg + mlvl_cls_scores = [ + item.permute(0, 2, 3, 1).view(item.size(0), -1, self.cls_out_channels) + for item in mlvl_cls_scores + ] + + # avoid setting items + lvl_strides = [ + torch.ones_like(mlvl_cls_scores[lvl][0, :, 0]) * self.strides[lvl] + for lvl in range(len(mlvl_cls_scores)) + ] + strides = torch.cat(lvl_strides, 0) + assert len(mlvl_mask_preds) == len(mlvl_cls_scores) + batch_mlvl_cls_scores = torch.cat(mlvl_cls_scores, dim=1) + batch_mlvl_mask_preds = torch.cat(mlvl_mask_preds, dim=1) + featmap_size = batch_mlvl_mask_preds.size()[-2:] + batch_mlvl_cls_scores, cls_labels = torch.max(batch_mlvl_cls_scores, -1) + + score_mask = (batch_mlvl_cls_scores > cfg.score_thr) + # pad zero to filter items + batch_mlvl_cls_scores = batch_mlvl_cls_scores.where( + score_mask, batch_mlvl_cls_scores.new_zeros(1)).view(-1) + + cls_labels = cls_labels.view(-1) + + mask_preds = batch_mlvl_mask_preds.view(-1, featmap_size[0], + featmap_size[1]) + + masks = (mask_preds > cfg.mask_thr) + sum_masks = masks.sum((1, 2)) + keep = sum_masks > strides + # pad zero to filter items + cls_scores = batch_mlvl_cls_scores.where( + keep, batch_mlvl_cls_scores.new_zeros(1)) + sum_masks = sum_masks.where(keep, sum_masks.new_ones(1)) + + # maskness + mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + sum_masks = sum_masks.where(keep, sum_masks.new_zeros(1)) + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + + h, w = batch_img_metas[0]['img_shape'][:2] + mask_preds = mask_preds[keep_inds].unsqueeze(0) + + mmdet_params = get_post_processing_params(ctx.cfg) + export_postprocess_mask = mmdet_params.get('export_postprocess_mask', True) + if export_postprocess_mask: + upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) + mask_preds = F.interpolate( + mask_preds, size=upsampled_size, mode='bilinear') + bboxes = scores.new_zeros(batch_size, scores.shape[-1], 4) + else: + + bboxes = scores.new_zeros(batch_size, scores.shape[-1], 2) + # full screen box so we can postprocess mask outside the model + bboxes = torch.cat([ + bboxes, + bboxes.new_full((*bboxes.shape[:2], 1), w), + bboxes.new_full((*bboxes.shape[:2], 1), h) + ], + dim=-1) + + labels = labels.reshape(batch_size, -1) + dets = torch.cat([bboxes, scores.reshape(batch_size, -1, 1)], dim=-1) + + return dets, labels, mask_preds diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py index 638036332..5b9df70a0 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py +++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import single_stage, two_stage +from . import single_stage, single_stage_instance_seg, two_stage -__all__ = ['single_stage', 'two_stage'] +__all__ = ['single_stage', 'single_stage_instance_seg', 'two_stage'] diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py new file mode 100644 index 000000000..bdff6e636 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmdet.models.detectors.base import ForwardResults +from mmdet.structures.det_data_sample import OptSampleList + +from mmdeploy.core import FUNCTION_REWRITER, mark +from mmdeploy.utils import is_dynamic_shape +from .single_stage import _set_metainfo + + +@mark( + 'instance_segmentor_forward', + inputs=['input'], + outputs=['dets', 'labels', 'masks']) +def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs): + """Rewrite and adding mark for `forward`. + + Encapsulate this function for rewriting `forward` of BaseDetector. + 1. Add mark for BaseDetector. + 2. Support both dynamic and static export to onnx. + """ + x = self.extract_feat(batch_inputs) + mask_outs = self.mask_head.predict(x, data_samples, rescale=False) + return mask_outs + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.single_stage_instance_seg.' + 'SingleStageInstanceSegmentor.forward') +def single_stage_instance_segmentor__forward( + self, + batch_inputs: torch.Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor', + **kwargs) -> ForwardResults: + """Rewrite `forward` for default backend. + Support configured dynamic/static shape for model input and return + detection result as Tensor instead of numpy array. + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + Returns: + tuple[Tensor]: Detection results of the + input images. + - dets (Tensor): Classification bboxes and scores. + Has a shape (num_instances, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + ctx = FUNCTION_REWRITER.get_context() + deploy_cfg = ctx.cfg + + # get origin input shape as tensor to support onnx dynamic shape + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + img_shape = torch._shape_as_tensor(batch_inputs)[2:] + if not is_dynamic_flag: + img_shape = [int(val) for val in img_shape] + + # set the metainfo + data_samples = _set_metainfo(data_samples, img_shape) + + return __forward_impl_instance_seg( + self, batch_inputs, data_samples=data_samples, **kwargs) diff --git a/mmdeploy/codebase/mmdet/models/layers.py b/mmdeploy/codebase/mmdet/models/layers/__init__.py similarity index 64% rename from mmdeploy/codebase/mmdet/models/layers.py rename to mmdeploy/codebase/mmdet/models/layers/__init__.py index 6a7457e7e..0559af920 100644 --- a/mmdeploy/codebase/mmdet/models/layers.py +++ b/mmdeploy/codebase/mmdet/models/layers/__init__.py @@ -1,3 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. # recovery for mmyolo from mmdeploy.mmcv.ops import multiclass_nms # noqa: F401, F403 +from . import matrix_nms # noqa: F401, F403 + +__all__ = ['multiclass_nms'] diff --git a/mmdeploy/codebase/mmdet/models/layers/matrix_nms.py b/mmdeploy/codebase/mmdet/models/layers/matrix_nms.py new file mode 100644 index 000000000..7bdf2e978 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/layers/matrix_nms.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.layers.matrix_nms.mask_matrix_nms') +def mask_matrix_nms__default(masks, + labels, + scores, + filter_thr=-1, + nms_pre=-1, + max_num=-1, + kernel='gaussian', + sigma=2.0, + mask_area=None): + """Matrix NMS for multi-class masks. + + Args: + masks (Tensor): Has shape (num_instances, h, w) + labels (Tensor): Labels of corresponding masks, + has shape (num_instances,). + scores (Tensor): Mask scores of corresponding masks, + has shape (num_instances). + filter_thr (float): Score threshold to filter the masks + after matrix nms. Default: -1, which means do not + use filter_thr. + nms_pre (int): The max number of instances to do the matrix nms. + Default: -1, which means do not use nms_pre. + max_num (int, optional): If there are more than max_num masks after + matrix, only top max_num will be kept. Default: -1, which means + do not use max_num. + kernel (str): 'linear' or 'gaussian'. + sigma (float): std in gaussian method. + mask_area (Tensor): The sum of seg_masks. + + Returns: + tuple(Tensor): Processed mask results. + + - scores (Tensor): Updated scores, has shape (n,). + - labels (Tensor): Remained labels, has shape (n,). + - masks (Tensor): Remained masks, has shape (n, w, h). + - keep_inds (Tensor): The indices number of + the remaining mask in the input mask, has shape (n,). + """ + assert len(labels) == len(masks) == len(scores) + assert len(masks) == len(mask_area) + # sort and keep top nms_pre + nms_pre = max(0, nms_pre) + if nms_pre <= 0: + nms_pre = scores.shape[0] + scores, sort_inds = torch.topk(scores, nms_pre) + + keep_inds = sort_inds + masks = masks[sort_inds] + mask_area = mask_area[sort_inds] + labels = labels[sort_inds] + num_masks = labels.size(0) + flatten_masks = masks.reshape(num_masks, -1).float() + # inter. + inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0)) + expanded_mask_area = mask_area.unsqueeze(1) + total_area = expanded_mask_area + expanded_mask_area.transpose( + 1, 0) - inter_matrix + total_mask = total_area > 0 + total_area = total_area.where(total_mask, total_area.new_ones(1)) + iou_matrix = (inter_matrix / total_area).triu(diagonal=1) + expanded_labels = labels.unsqueeze(1) + label_matrix = expanded_labels == expanded_labels.transpose(1, 0) + + # iou decay + decay_iou = iou_matrix.where(label_matrix, iou_matrix.new_zeros(1)) + + # iou compensation + compensate_iou, _ = decay_iou.max(0) + compensate_iou = compensate_iou.expand(num_masks, + num_masks).transpose(1, 0) + + # calculate the decay_coefficient + if kernel == 'gaussian': + decay_matrix = torch.exp(-1 * sigma * (decay_iou**2)) + compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2)) + decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) + elif kernel == 'linear': + decay_matrix = (1 - decay_iou) / (1 - compensate_iou) + decay_coefficient, _ = decay_matrix.min(0) + else: + raise NotImplementedError( + f'{kernel} kernel is not supported in matrix nms!') + # update the score. + scores = scores * decay_coefficient + + keep = scores >= filter_thr + scores = scores.where(keep, scores.new_zeros(1)) + + # sort and keep top max_num + scores, sort_inds = torch.topk(scores, max(max_num, 0)) + keep_inds = keep_inds[sort_inds] + masks = masks[sort_inds] + labels = labels[sort_inds] + + return scores, labels, masks, keep_inds diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index f3e145b4d..a07a2aab0 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -9,6 +9,7 @@ from . import getattribute # noqa: F401,F403 from . import group_norm # noqa: F401,F403 from . import interpolate # noqa: F401,F403 from . import linear # noqa: F401,F403 +from . import linspace # noqa: F401,F403 from . import masked_fill # noqa: F401,F403 from . import mod # noqa: F401,F403 from . import multi_head_attention_forward # noqa: F401,F403 diff --git a/mmdeploy/pytorch/functions/linspace.py b/mmdeploy/pytorch/functions/linspace.py new file mode 100644 index 000000000..7385b79f6 --- /dev/null +++ b/mmdeploy/pytorch/functions/linspace.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.types import Number + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter(func_name='torch.linspace') +def linspace__default(start: Number, end: Number, steps: int = None, **kwargs): + """Rewrite `linspace` for onnxruntime.""" + steps = 100 if steps is None else steps + dtype = kwargs.pop('dtype', torch.float32) + dtype = dtype if dtype else torch.float32 + if steps == 1: + output = torch.arange(start, end + 1, dtype=dtype, **kwargs)[:steps] + else: + output = torch.arange( + start, end + 1, (end - start) / (steps - 1), dtype=dtype, + **kwargs)[:steps] + return output diff --git a/mmdeploy/pytorch/functions/triu.py b/mmdeploy/pytorch/functions/triu.py index e7e8e501e..2d393eba6 100644 --- a/mmdeploy/pytorch/functions/triu.py +++ b/mmdeploy/pytorch/functions/triu.py @@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter(func_name='torch.triu') +@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.triu') def triu__default(input: torch.Tensor, diagonal: int = 0, *args, diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 266f3fe67..4f2a938f1 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -339,3 +339,10 @@ models: convert_image: *convert_image backend_test: *default_backend_test sdk_config: *sdk_dynamic + + - name: SOLO + metafile: configs/solo/metafile.yml + model_configs: + - configs/solo/solo_r50_fpn_1x_coco.py + pipelines: + - *pipeline_seg_openvino_dynamic_fp32 diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index d1420a80b..b75a9cb10 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -71,16 +71,12 @@ def test_baseclassifier_forward(): def __init__(self, backbone): super().__init__(backbone=backbone) + self.head = lambda x: x + self.predict = lambda x, data_samples: x def extract_feat(self, batch_inputs: torch.Tensor): return batch_inputs - def head(self, x): - return x - - def predict(self, x, data_samples): - return x - backbone_cfg = dict( type='ResNet', depth=18, diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 2ad56bda9..cfa912838 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -2155,3 +2155,96 @@ def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str): run_with_backend=False) assert rewrite_outputs is not None + + +def get_solo_head_model(): + test_cfg = Config( + dict( + nms_pre=500, + score_thr=0.1, + mask_thr=0.5, + filter_thr=0.05, + kernel='gaussian', # gaussian/linear + sigma=2.0, + max_per_img=100)) + from mmdet.models.dense_heads import SOLOHead + model = SOLOHead(4, 32, feat_channels=32, test_cfg=test_cfg) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.OPENVINO]) +def test_solo_head_predict_by_feat(backend_type: Backend): + """Test predict_by_feat rewrite of solo head.""" + check_backend(backend_type) + solo_head = get_solo_head_model() + s = 128 + solo_head.cpu().eval() + batch_img_metas = [{'img_shape': (s, s, 3), 'ori_shape': (s, s, 3)}] + + output_names = ['dets', 'labels', 'masks'] + deploy_cfg = Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=20, + pre_top_k=-1, + keep_top_k=10, + background_label_id=-1, + export_postprocess_mask=True)))) + seed_everything(1234) + num_grids = [24, 20, 16, 12, 8] + mask_preds = [ + torch.rand(1, num_grid**2, s // 4, s // 4) for num_grid in num_grids + ] + seed_everything(5678) + cls_scores = [ + torch.rand(1, solo_head.num_classes, num_grid, num_grid) + for num_grid in num_grids + ] + + # to get outputs of pytorch model + model_inputs = { + 'mlvl_mask_preds': mask_preds, + 'mlvl_cls_scores': cls_scores, + 'batch_img_metas': batch_img_metas, + } + model_outputs = get_model_outputs(solo_head, 'predict_by_feat', + model_inputs) + + wrapped_model = WrapModel( + solo_head, 'predict_by_feat', batch_img_metas=batch_img_metas) + rewrite_inputs = { + 'mlvl_mask_preds': mask_preds, + 'mlvl_cls_scores': cls_scores, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + # hard code to make two tensors with the same shape + # rewrite and original codes applied different nms strategy + min_shape = min(model_outputs[0].bboxes.shape[0], + rewrite_outputs[0].shape[1], 5) + for i in range(len(model_outputs)): + assert np.allclose( + model_outputs[i].scores[:min_shape], + rewrite_outputs[0][i, :min_shape, 4], + rtol=1e-03, + atol=1e-05) + assert np.allclose( + model_outputs[i].labels[:min_shape], + rewrite_outputs[1][i, :min_shape], + rtol=1e-03, + atol=1e-05) + else: + assert rewrite_outputs is not None diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index b63830b22..6b296c919 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -555,3 +555,38 @@ def test_prepare_onnx_paddings__tensorrt(): run_with_backend=True) assert torch.allclose( pytorch_output, rewrite_output[0], rtol=1e-3, atol=1e-5) + + +@backend_checker(Backend.ONNXRUNTIME) +def test_linspace__default(): + import random + + deploy_cfg_ort = Config( + dict( + onnx_config=dict(input_shape=None), + backend_config=dict(type='onnxruntime'))) + + def linspace_caller(*arg, **kwargs): + return torch.linspace(*arg, **kwargs) + + steps_list = [1, random.randint(1, 1000)] + for steps in steps_list: + start = random.random() * 100 + end = random.random() * 100 + start + + model_output = linspace_caller(start=start, end=end, steps=steps) + + wrapped_func = WrapFunction( + linspace_caller, start=start, end=end, steps=steps) + + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_func, + model_inputs={}, + deploy_cfg=deploy_cfg_ort, + run_with_backend=True) + + if is_backend_output: + rewrite_outputs = rewrite_outputs[0] + + assert np.allclose( + model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)