From c52b24c67fa1125ddc1106c66c7927c9c91fdba5 Mon Sep 17 00:00:00 2001 From: Semyon Bevzyuk Date: Wed, 3 Nov 2021 05:27:48 +0300 Subject: [PATCH] [Enhancement]: Added support for masks in OpenVINO. (#148) * Fix include and lib paths for onnxruntime. * Fixes for SSD export test * Add onnx2openvino and OpenVINODetector. Test models: ssd, retinanet, fcos, fsaf. * Add support for two-stage models: faster_rcnn, cascade_rcnn * Add doc * Add strip_doc_string for openvino. * Fix openvino preprocess. * Add OpenVINO to test_wrapper.py. * Fix * Add openvino_execute. * Removed preprocessing. * Fix onnxruntime cmake. * Rewrote postprocessing and forward, added docstrings and fixes. * Added device type change to OpenVINOWrapper. * Update forward_of_single_roi_extractor_dynamic_openvino and fix doc. * Update docs. * Add support for masks (Mask RCNN). * Add masks to CascadeRoIHead.simple_test. * Added masks to test_OpenVINODetector. * Added test_cascade_roi_head_with_mask. * Update docs. * Fix segm_results shape. * Fix TopK in NMS and add test_multiclass_nms_with_keep_top_k. * Removed unnecessary functions. * Fix. * Fix test_multiclass_nms_with_keep_top_k. * Updated test_OpenVINODetector. --- README.md | 1 + docs/backends/openvino.md | 20 ++-- docs/tutorials/how_to_convert_model.md | 51 +++++----- mmdeploy/mmdet/apis/inference.py | 12 ++- .../mmdet/core/post_processing/bbox_nms.py | 4 +- .../models/roi_heads/cascade_roi_head.py | 22 ++++- mmdeploy/utils/test.py | 83 +++++++++++++---- tests/test_mmdet/test_mmdet_apis.py | 15 ++- tests/test_mmdet/test_mmdet_core.py | 71 +++++++++++++- tests/test_mmdet/test_mmdet_models.py | 93 +++++++++++++++++-- 10 files changed, 305 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 740ab6e39..6d2952aeb 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ Supported backend: - [x] TensorRT - [x] PPL - [x] ncnn +- [x] OpenVINO ## Installation diff --git a/docs/backends/openvino.md b/docs/backends/openvino.md index 8ae0092d1..ff0e0e4e1 100644 --- a/docs/backends/openvino.md +++ b/docs/backends/openvino.md @@ -25,15 +25,17 @@ python tools/deploy.py \ The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection. -| Model | Config | Dynamic Shape | -| :----------: | :-----------------------------------------------------------------: | :-----------: | -| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y | -| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | -| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | -| SSD | `configs/ssd/ssd300_coco.py` | Y | -| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | -| Cascade R-CNN| `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y | +| Model name | Config | Dynamic Shape | +| :----------------: | :-----------------------------------------------------------------: | :-----------: | +| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y | +| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | +| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | +| SSD | `configs/ssd/ssd300_coco.py` | Y | +| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | +| Cascade R-CNN | `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y | +| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | +| Cascade Mask R-CNN | `configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py` | Y | Notes: -- For faster work in OpenVINO in the Faster-RCNN, Cascade-RCNN +- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph. diff --git a/docs/tutorials/how_to_convert_model.md b/docs/tutorials/how_to_convert_model.md index a3669726e..83b86a1ab 100644 --- a/docs/tutorials/how_to_convert_model.md +++ b/docs/tutorials/how_to_convert_model.md @@ -16,14 +16,14 @@ This tutorial briefly introduces how to export an OpenMMlab model to a specific backend using MMDeploy tools. Notes: -- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md). +- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md), [OpenVINO](../backends/openvino.md). - Supported codebases are [MMClassification](../codebases/mmcls.md), [MMDetection](../codebases/mmdet.md), [MMSegmentation](../codebases/mmseg.md), [MMOCR](../codebases/mmocr.md), [MMEditing](../codebases/mmedit.md). ### How to convert models from Pytorch to other backends #### Prerequisite -1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md) for more information. +1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md), [OpenVINO-install](../backends/openvino.md) for more information. 2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/master/docs/install.md), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/main/docs/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/master/docs/install.md). #### Usage @@ -78,28 +78,31 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e The table below lists the models that are guaranteed to be exportable to other backend. -| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | -| :----------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | -| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | -| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | -| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | -| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | -| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | -| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | -| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | -| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | -| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | -| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | -| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | -| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | -| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | -| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | -| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | -| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | -| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | -| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | -| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | -| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | +| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | +| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: | +| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | +| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | +| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N | +| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y | +| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y | +| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y | +| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y | +| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y | +| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y | +| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N | +| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N | +| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N | +| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N | +| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N | +| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N | +| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | +| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N | +| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | +| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N | +| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N | +| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N | +| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N | +| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N | ### Reminders diff --git a/mmdeploy/mmdet/apis/inference.py b/mmdeploy/mmdet/apis/inference.py index 5ce00d432..dd1387a42 100644 --- a/mmdeploy/mmdet/apis/inference.py +++ b/mmdeploy/mmdet/apis/inference.py @@ -336,15 +336,21 @@ class OpenVINODetector(DeployBaseDetector): from mmdeploy.apis.openvino import OpenVINOWrapper self.model = OpenVINOWrapper(model_file) - def forward_test(self, imgs: torch.Tensor, *args, **kwargs): + def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> Tuple: """Implement forward test. Args: imgs (torch.Tensor): Input image(s) in [N x C x H x W] format. Returns: - tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5] - and class labels of shape [N, num_det]. + If there are no masks in the output: + tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5] + and class labels of shape [N, num_det]. + If the output contains masks: + tuple[np.ndarray, np.ndarray, np.ndarray]: + dets of shape [N, num_det, 5], + class labels of shape [N, num_det] and + masks of shape [N, num_det, H, W]. """ openvino_outputs = self.model({'input': imgs}) output_keys = ['dets', 'labels'] diff --git a/mmdeploy/mmdet/core/post_processing/bbox_nms.py b/mmdeploy/mmdet/core/post_processing/bbox_nms.py index d86bc7aba..c338c3178 100644 --- a/mmdeploy/mmdet/core/post_processing/bbox_nms.py +++ b/mmdeploy/mmdet/core/post_processing/bbox_nms.py @@ -55,7 +55,9 @@ def select_nms_index(scores: torch.Tensor, (N, 1))), 1) # sort - if keep_top_k > 0 and keep_top_k < batched_dets.shape[1]: + is_use_topk = keep_top_k > 0 and \ + (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1]) + if is_use_topk: _, topk_inds = batched_dets[:, :, -1].topk(keep_top_k, dim=1) else: _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True) diff --git a/mmdeploy/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/mmdet/models/roi_heads/cascade_roi_head.py index 11de39ce1..5c5e7cedf 100644 --- a/mmdeploy/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/mmdet/models/roi_heads/cascade_roi_head.py @@ -53,5 +53,23 @@ def simple_test_of_cascade_roi_head(ctx, self, x, proposals, img_metas, if not self.with_mask: return det_bboxes, det_labels else: - raise NotImplementedError( - 'Masks are not yet supported in export CascadeRoIHead.simple_test') + batch_index = torch.arange( + det_bboxes.size(0), device=det_bboxes.device).float().view( + -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1) + rois = det_bboxes[..., :4] + mask_rois = torch.cat([batch_index, rois], dim=-1) + mask_rois = mask_rois.view(-1, 5) + aug_masks = [] + for i in range(self.num_stages): + mask_results = self._mask_forward(i, x, mask_rois) + mask_pred = mask_results['mask_pred'] + aug_masks.append(mask_pred) + # Calculate the mean of masks from several stage + mask_pred = sum(aug_masks) / len(aug_masks) + segm_results = self.mask_head[-1].get_seg_masks( + mask_pred, rois.reshape(-1, 4), det_labels.reshape(-1), + self.test_cfg, max_shape) + segm_results = segm_results.reshape(batch_size, det_bboxes.shape[1], + segm_results.shape[-2], + segm_results.shape[-1]) + return det_bboxes, det_labels, segm_results diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 8093d0097..50455368f 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -212,22 +212,18 @@ def get_flatten_inputs( return flatten_inputs -def get_rewrite_outputs(wrapped_model: nn.Module, - model_inputs: Dict[str, Union[Tuple, List, - torch.Tensor]], - deploy_cfg: mmcv.Config) -> Tuple[Any, bool]: - """To get outputs of generated onnx model after rewrite. +def get_onnx_model(wrapped_model: nn.Module, + model_inputs: Dict[str, Union[Tuple, List, torch.Tensor]], + deploy_cfg: mmcv.Config) -> str: + """To get path to onnx model after export. Args: wrap_model (nn.Module): The input model. - func_name (str): The function of model. model_inputs (dict): Inputs for model. + deploy_cfg (mmcv.Config): Deployment config. Returns: - Any: The outputs of model, decided by the backend wrapper. - bool: A flag indicate the type of outputs. If the flag is True, then - the outputs are backend output, otherwise they are outputs of wrapped - pytorch model. + str: The path to the ONNX model file. """ onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name pytorch2onnx_cfg = get_onnx_config(deploy_cfg) @@ -239,7 +235,6 @@ def get_rewrite_outputs(wrapped_model: nn.Module, output_names = pytorch2onnx_cfg.get('output_names', None) with RewriterContext( cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad(): - ctx_outputs = wrapped_model(**model_inputs) torch.onnx.export( patched_model, tuple([v for k, v in model_inputs.items()]), @@ -250,12 +245,33 @@ def get_rewrite_outputs(wrapped_model: nn.Module, opset_version=11, dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None), keep_initializers_as_inputs=False) + return onnx_file_path + + +def get_backend_outputs(onnx_file_path: str, + model_inputs: Dict[str, Union[Tuple, List, + torch.Tensor]], + deploy_cfg: mmcv.Config) -> Any: + """To get backend outputs of model. + + Args: + onnx_file_path (str): The path to the ONNX file. + model_inputs (dict): Inputs for model. + deploy_cfg (mmcv.Config): Deployment config. + + Returns: + Any: The outputs of model, decided by the backend wrapper. + """ + backend = get_backend(deploy_cfg) + flatten_model_inputs = get_flatten_inputs(model_inputs) + input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx'] + output_names = get_onnx_config(deploy_cfg).get('output_names', None) # prepare backend model and input features if backend == Backend.TENSORRT: # convert to engine import mmdeploy.apis.tensorrt as trt_apis if not trt_apis.is_available(): - return ctx_outputs, False + return None trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name trt_apis.onnx2tensorrt( '', @@ -271,7 +287,7 @@ def get_rewrite_outputs(wrapped_model: nn.Module, elif backend == Backend.ONNXRUNTIME: import mmdeploy.apis.onnxruntime as ort_apis if not ort_apis.is_available(): - return ctx_outputs, False + return None backend_model = ort_apis.ORTWrapper(onnx_file_path, 0, None) feature_list = [] backend_feats = {} @@ -296,11 +312,11 @@ def get_rewrite_outputs(wrapped_model: nn.Module, else: backend_feats[str(i)] = feature_list[i] elif backend == Backend.NCNN: - return ctx_outputs, False + return None elif backend == Backend.OPENVINO: import mmdeploy.apis.openvino as openvino_apis if not openvino_apis.is_available(): - return ctx_outputs, False + return None openvino_work_dir = tempfile.TemporaryDirectory().name openvino_file_path = openvino_apis.get_output_model_file( onnx_file_path, openvino_work_dir) @@ -314,11 +330,44 @@ def get_rewrite_outputs(wrapped_model: nn.Module, backend_feats = flatten_model_inputs elif backend == Backend.DEFAULT: - return ctx_outputs, False + return None else: raise NotImplementedError( f'Unimplemented backend type: {backend.value}') with torch.no_grad(): backend_outputs = backend_model.forward(backend_feats) - return backend_outputs, True + return backend_outputs + + +def get_rewrite_outputs(wrapped_model: nn.Module, + model_inputs: Dict[str, Union[Tuple, List, + torch.Tensor]], + deploy_cfg: mmcv.Config) -> Tuple[Any, bool]: + """To get outputs of generated onnx model after rewrite. + + Args: + wrap_model (nn.Module): The input model. + model_inputs (dict): Inputs for model. + deploy_cfg (mmcv.Config): Deployment config. + + Returns: + Any: The outputs of model, decided by the backend wrapper. + bool: A flag indicate the type of outputs. If the flag is True, then + the outputs are backend output, otherwise they are outputs of wrapped + pytorch model. + """ + backend = get_backend(deploy_cfg) + with RewriterContext( + cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad(): + ctx_outputs = wrapped_model(**model_inputs) + + onnx_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg) + + backend_outputs = get_backend_outputs(onnx_file_path, model_inputs, + deploy_cfg) + + if backend_outputs is None: + return ctx_outputs, False + else: + return backend_outputs, True diff --git a/tests/test_mmdet/test_mmdet_apis.py b/tests/test_mmdet/test_mmdet_apis.py index 4ba12915f..4e29af5eb 100644 --- a/tests/test_mmdet/test_mmdet_apis.py +++ b/tests/test_mmdet/test_mmdet_apis.py @@ -107,12 +107,23 @@ def test_OpenVINODetector(): openvino_apis.__dict__.update({'OpenVINOWrapper': OpenVINOWrapper}) # simplify backend inference - outputs = {'dets': torch.rand(1, 100, 5), 'labels': torch.rand(1, 100)} + num_classes = 80 + num_dets = 10 + outputs = { + 'dets': torch.rand(1, num_dets, 5), + 'labels': torch.randint(num_classes, (1, num_dets)), + 'masks': np.random.rand(1, num_dets, 28, 28) + } + deploy_cfg = mmcv.Config( + dict( + codebase_config=dict( + post_processing=dict(export_postprocess_mask=False)))) with SwitchBackendWrapper(OpenVINOWrapper) as wrapper: wrapper.set(outputs=outputs) from mmdeploy.mmdet.apis.inference import OpenVINODetector - openvino_detector = OpenVINODetector('', ['' for i in range(80)], 0) + openvino_detector = OpenVINODetector( + '', ['' for i in range(80)], 0, deploy_cfg=deploy_cfg) imgs = [torch.rand(1, 3, 64, 64)] img_metas = [[{ 'ori_shape': [64, 64, 3], diff --git a/tests/test_mmdet/test_mmdet_core.py b/tests/test_mmdet/test_mmdet_core.py index 9042146be..8dab90fb4 100644 --- a/tests/test_mmdet/test_mmdet_core.py +++ b/tests/test_mmdet/test_mmdet_core.py @@ -5,7 +5,8 @@ import numpy as np import pytest import torch -from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs +from mmdeploy.utils.test import (WrapFunction, get_onnx_model, + get_rewrite_outputs) @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') @@ -150,3 +151,71 @@ def test_distance2bbox(): distance = torch.rand(3, 4) bbox = distance2bbox(points, distance) assert bbox.shape == torch.Size([3, 4]) + + +@pytest.mark.skipif( + not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime') +def test_multiclass_nms_with_keep_top_k(): + backend_type = 'onnxruntime' + + from mmdeploy.mmdet.core import multiclass_nms + max_output_boxes_per_class = 20 + keep_top_k = 15 + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict( + output_names=None, + input_shape=None, + dynamic_axes=dict( + boxes={ + 0: 'batch_size', + 1: 'num_boxes' + }, + scores={ + 0: 'batch_size', + 1: 'num_boxes', + 2: 'num_classes' + }, + ), + ), + backend_config=dict(type=backend_type), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=max_output_boxes_per_class, + pre_top_k=-1, + keep_top_k=keep_top_k, + background_label_id=-1, + )))) + + num_classes = 5 + num_boxes = 2 + batch_size = 1 + export_boxes = torch.rand(batch_size, num_boxes, 4) + export_scores = torch.ones(batch_size, num_boxes, num_classes) + model_inputs = {'boxes': export_boxes, 'scores': export_scores} + + wrapped_func = WrapFunction( + multiclass_nms, + max_output_boxes_per_class=max_output_boxes_per_class, + keep_top_k=keep_top_k) + + onnx_model_path = get_onnx_model( + wrapped_func, model_inputs=model_inputs, deploy_cfg=deploy_cfg) + + num_boxes = 100 + test_boxes = torch.rand(batch_size, num_boxes, 4) + test_scores = torch.ones(batch_size, num_boxes, num_classes) + model_inputs = {'boxes': test_boxes, 'scores': test_scores} + + import mmdeploy.apis.onnxruntime as ort_apis + backend_model = ort_apis.ORTWrapper(onnx_model_path, 0, None) + dets, _ = backend_model.forward(model_inputs) + + assert dets.shape[1] < keep_top_k, \ + 'multiclass_nms returned more values than "keep_top_k"\n' \ + f'dets.shape: {dets.shape}\n' \ + f'keep_top_k: {keep_top_k}' diff --git a/tests/test_mmdet/test_mmdet_models.py b/tests/test_mmdet/test_mmdet_models.py index 981a93509..415e26774 100644 --- a/tests/test_mmdet/test_mmdet_models.py +++ b/tests/test_mmdet/test_mmdet_models.py @@ -368,7 +368,7 @@ def test_single_roi_extractor(backend_type): model_output, backend_output, rtol=1e-03, atol=1e-05) -def get_cascade_roi_head(): +def get_cascade_roi_head(is_with_masks=False): """CascadeRoIHead Config.""" num_stages = 3 stage_loss_weights = [1, 0.5, 0.25] @@ -408,19 +408,43 @@ def get_cascade_roi_head(): } } for target_stds in all_target_stds] + mask_roi_extractor = { + 'type': 'SingleRoIExtractor', + 'roi_layer': { + 'type': 'RoIAlign', + 'output_size': 14, + 'sampling_ratio': 0 + }, + 'out_channels': 64, + 'featmap_strides': [4, 8, 16, 32] + } + mask_head = { + 'type': 'FCNMaskHead', + 'num_convs': 4, + 'in_channels': 64, + 'conv_out_channels': 64, + 'num_classes': 80, + 'loss_mask': { + 'type': 'CrossEntropyLoss', + 'use_mask': True, + 'loss_weight': 1.0 + } + } + test_cfg = mmcv.Config( dict( score_thr=0.05, nms=mmcv.Config(dict(type='nms', iou_threshold=0.5)), - max_per_img=100)) + max_per_img=100, + mask_thr_binary=0.5)) + + args = [num_stages, stage_loss_weights, bbox_roi_extractor, bbox_head] + kwargs = {'test_cfg': test_cfg} + if is_with_masks: + args += [mask_roi_extractor, mask_head] from mmdet.models import CascadeRoIHead - model = CascadeRoIHead( - num_stages, - stage_loss_weights, - bbox_roi_extractor, - bbox_head, - test_cfg=test_cfg).eval() + model = CascadeRoIHead(*args, **kwargs).eval() return model @@ -497,3 +521,56 @@ def test_cascade_roi_head(backend_type): processed_backend_outputs, rtol=1e-03, atol=1e-05) + + +@pytest.mark.parametrize('backend_type', ['openvino']) +def test_cascade_roi_head_with_mask(backend_type): + pytest.importorskip(backend_type, reason=f'requires {backend_type}') + + cascade_roi_head = get_cascade_roi_head(is_with_masks=True) + seed_everything(1234) + x = [ + torch.rand((1, 64, 200, 304)), + torch.rand((1, 64, 100, 152)), + torch.rand((1, 64, 50, 76)), + torch.rand((1, 64, 25, 38)), + ] + proposals = torch.tensor([[587.8285, 52.1405, 886.2484, 341.5644, 0.5]]) + img_metas = mmcv.Config({ + 'img_shape': torch.tensor([800, 1216]), + 'ori_shape': torch.tensor([800, 1216]), + 'scale_factor': torch.tensor([1, 1, 1, 1]) + }) + + output_names = ['bbox_results', 'segm_results'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=-1, + keep_top_k=100, + background_label_id=-1)))) + model_inputs = {'x': x, 'proposals': proposals.unsqueeze(0)} + wrapped_model = WrapModel( + cascade_roi_head, 'simple_test', img_metas=img_metas) + backend_outputs, _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=model_inputs, + deploy_cfg=deploy_cfg) + bbox_results = backend_outputs['bbox_results'] + segm_results = backend_outputs['segm_results'] + expected_bbox_results = np.zeros((1, 80, 5)) + expected_segm_results = -np.ones((1, 80)) + assert np.allclose( + expected_bbox_results, bbox_results, rtol=1e-03, + atol=1e-05), 'bbox_results do not match.' + assert np.allclose( + expected_segm_results, segm_results, rtol=1e-03, + atol=1e-05), 'segm_results do not match.'