[Enhancement]: Added support for masks in OpenVINO. (#148)

* Fix include and lib paths for onnxruntime.

* Fixes for SSD export test

* Add onnx2openvino and OpenVINODetector. Test models: ssd, retinanet, fcos, fsaf.

* Add support for two-stage models: faster_rcnn, cascade_rcnn

* Add doc

* Add strip_doc_string for openvino.

* Fix openvino preprocess.

* Add OpenVINO to test_wrapper.py.

* Fix

* Add openvino_execute.

* Removed preprocessing.

* Fix onnxruntime cmake.

* Rewrote postprocessing and forward, added docstrings and fixes.

* Added device type change to OpenVINOWrapper.

* Update forward_of_single_roi_extractor_dynamic_openvino and fix doc.

* Update docs.

* Add support for masks (Mask RCNN).

* Add masks to CascadeRoIHead.simple_test.

* Added masks to test_OpenVINODetector.

* Added test_cascade_roi_head_with_mask.

* Update docs.

* Fix segm_results shape.

* Fix TopK in NMS and add test_multiclass_nms_with_keep_top_k.

* Removed unnecessary functions.

* Fix.

* Fix test_multiclass_nms_with_keep_top_k.

* Updated test_OpenVINODetector.
This commit is contained in:
Semyon Bevzyuk 2021-11-03 05:27:48 +03:00 committed by GitHub
parent d3e26b68a2
commit c52b24c67f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 305 additions and 67 deletions

View File

@ -39,6 +39,7 @@ Supported backend:
- [x] TensorRT
- [x] PPL
- [x] ncnn
- [x] OpenVINO
## Installation

View File

@ -25,15 +25,17 @@ python tools/deploy.py \
The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection.
| Model | Config | Dynamic Shape |
| :----------: | :-----------------------------------------------------------------: | :-----------: |
| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y |
| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y |
| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y |
| SSD | `configs/ssd/ssd300_coco.py` | Y |
| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y |
| Cascade R-CNN| `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y |
| Model name | Config | Dynamic Shape |
| :----------------: | :-----------------------------------------------------------------: | :-----------: |
| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y |
| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y |
| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y |
| SSD | `configs/ssd/ssd300_coco.py` | Y |
| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y |
| Cascade R-CNN | `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y |
| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y |
| Cascade Mask R-CNN | `configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py` | Y |
Notes:
- For faster work in OpenVINO in the Faster-RCNN, Cascade-RCNN
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.

View File

@ -16,14 +16,14 @@
This tutorial briefly introduces how to export an OpenMMlab model to a specific backend using MMDeploy tools.
Notes:
- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md).
- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md), [OpenVINO](../backends/openvino.md).
- Supported codebases are [MMClassification](../codebases/mmcls.md), [MMDetection](../codebases/mmdet.md), [MMSegmentation](../codebases/mmseg.md), [MMOCR](../codebases/mmocr.md), [MMEditing](../codebases/mmedit.md).
### How to convert models from Pytorch to other backends
#### Prerequisite
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md) for more information.
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md), [OpenVINO-install](../backends/openvino.md) for more information.
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/master/docs/install.md), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/main/docs/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/master/docs/install.md).
#### Usage
@ -78,28 +78,31 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e
The table below lists the models that are guaranteed to be exportable to other backend.
| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL |
| :----------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: |
| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y |
| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y |
| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y |
| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N |
| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y |
| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y |
| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y |
| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y |
| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y |
| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y |
| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y |
| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y |
| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y |
| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y |
| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y |
| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y |
| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N |
| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO |
| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: |
| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N |
| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y |
| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y |
| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y |
| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N |
| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N |
| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N |
| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N |
| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N |
| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N |
| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N |
| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N |
| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N |
### Reminders

View File

@ -336,15 +336,21 @@ class OpenVINODetector(DeployBaseDetector):
from mmdeploy.apis.openvino import OpenVINOWrapper
self.model = OpenVINOWrapper(model_file)
def forward_test(self, imgs: torch.Tensor, *args, **kwargs):
def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> Tuple:
"""Implement forward test.
Args:
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
Returns:
tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
If there are no masks in the output:
tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
If the output contains masks:
tuple[np.ndarray, np.ndarray, np.ndarray]:
dets of shape [N, num_det, 5],
class labels of shape [N, num_det] and
masks of shape [N, num_det, H, W].
"""
openvino_outputs = self.model({'input': imgs})
output_keys = ['dets', 'labels']

View File

@ -55,7 +55,9 @@ def select_nms_index(scores: torch.Tensor,
(N, 1))), 1)
# sort
if keep_top_k > 0 and keep_top_k < batched_dets.shape[1]:
is_use_topk = keep_top_k > 0 and \
(torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
if is_use_topk:
_, topk_inds = batched_dets[:, :, -1].topk(keep_top_k, dim=1)
else:
_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)

View File

@ -53,5 +53,23 @@ def simple_test_of_cascade_roi_head(ctx, self, x, proposals, img_metas,
if not self.with_mask:
return det_bboxes, det_labels
else:
raise NotImplementedError(
'Masks are not yet supported in export CascadeRoIHead.simple_test')
batch_index = torch.arange(
det_bboxes.size(0), device=det_bboxes.device).float().view(
-1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
rois = det_bboxes[..., :4]
mask_rois = torch.cat([batch_index, rois], dim=-1)
mask_rois = mask_rois.view(-1, 5)
aug_masks = []
for i in range(self.num_stages):
mask_results = self._mask_forward(i, x, mask_rois)
mask_pred = mask_results['mask_pred']
aug_masks.append(mask_pred)
# Calculate the mean of masks from several stage
mask_pred = sum(aug_masks) / len(aug_masks)
segm_results = self.mask_head[-1].get_seg_masks(
mask_pred, rois.reshape(-1, 4), det_labels.reshape(-1),
self.test_cfg, max_shape)
segm_results = segm_results.reshape(batch_size, det_bboxes.shape[1],
segm_results.shape[-2],
segm_results.shape[-1])
return det_bboxes, det_labels, segm_results

View File

@ -212,22 +212,18 @@ def get_flatten_inputs(
return flatten_inputs
def get_rewrite_outputs(wrapped_model: nn.Module,
model_inputs: Dict[str, Union[Tuple, List,
torch.Tensor]],
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
"""To get outputs of generated onnx model after rewrite.
def get_onnx_model(wrapped_model: nn.Module,
model_inputs: Dict[str, Union[Tuple, List, torch.Tensor]],
deploy_cfg: mmcv.Config) -> str:
"""To get path to onnx model after export.
Args:
wrap_model (nn.Module): The input model.
func_name (str): The function of model.
model_inputs (dict): Inputs for model.
deploy_cfg (mmcv.Config): Deployment config.
Returns:
Any: The outputs of model, decided by the backend wrapper.
bool: A flag indicate the type of outputs. If the flag is True, then
the outputs are backend output, otherwise they are outputs of wrapped
pytorch model.
str: The path to the ONNX model file.
"""
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
@ -239,7 +235,6 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
output_names = pytorch2onnx_cfg.get('output_names', None)
with RewriterContext(
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
ctx_outputs = wrapped_model(**model_inputs)
torch.onnx.export(
patched_model,
tuple([v for k, v in model_inputs.items()]),
@ -250,12 +245,33 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
opset_version=11,
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
keep_initializers_as_inputs=False)
return onnx_file_path
def get_backend_outputs(onnx_file_path: str,
model_inputs: Dict[str, Union[Tuple, List,
torch.Tensor]],
deploy_cfg: mmcv.Config) -> Any:
"""To get backend outputs of model.
Args:
onnx_file_path (str): The path to the ONNX file.
model_inputs (dict): Inputs for model.
deploy_cfg (mmcv.Config): Deployment config.
Returns:
Any: The outputs of model, decided by the backend wrapper.
"""
backend = get_backend(deploy_cfg)
flatten_model_inputs = get_flatten_inputs(model_inputs)
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
output_names = get_onnx_config(deploy_cfg).get('output_names', None)
# prepare backend model and input features
if backend == Backend.TENSORRT:
# convert to engine
import mmdeploy.apis.tensorrt as trt_apis
if not trt_apis.is_available():
return ctx_outputs, False
return None
trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
trt_apis.onnx2tensorrt(
'',
@ -271,7 +287,7 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
elif backend == Backend.ONNXRUNTIME:
import mmdeploy.apis.onnxruntime as ort_apis
if not ort_apis.is_available():
return ctx_outputs, False
return None
backend_model = ort_apis.ORTWrapper(onnx_file_path, 0, None)
feature_list = []
backend_feats = {}
@ -296,11 +312,11 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
else:
backend_feats[str(i)] = feature_list[i]
elif backend == Backend.NCNN:
return ctx_outputs, False
return None
elif backend == Backend.OPENVINO:
import mmdeploy.apis.openvino as openvino_apis
if not openvino_apis.is_available():
return ctx_outputs, False
return None
openvino_work_dir = tempfile.TemporaryDirectory().name
openvino_file_path = openvino_apis.get_output_model_file(
onnx_file_path, openvino_work_dir)
@ -314,11 +330,44 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
backend_feats = flatten_model_inputs
elif backend == Backend.DEFAULT:
return ctx_outputs, False
return None
else:
raise NotImplementedError(
f'Unimplemented backend type: {backend.value}')
with torch.no_grad():
backend_outputs = backend_model.forward(backend_feats)
return backend_outputs, True
return backend_outputs
def get_rewrite_outputs(wrapped_model: nn.Module,
model_inputs: Dict[str, Union[Tuple, List,
torch.Tensor]],
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
"""To get outputs of generated onnx model after rewrite.
Args:
wrap_model (nn.Module): The input model.
model_inputs (dict): Inputs for model.
deploy_cfg (mmcv.Config): Deployment config.
Returns:
Any: The outputs of model, decided by the backend wrapper.
bool: A flag indicate the type of outputs. If the flag is True, then
the outputs are backend output, otherwise they are outputs of wrapped
pytorch model.
"""
backend = get_backend(deploy_cfg)
with RewriterContext(
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
ctx_outputs = wrapped_model(**model_inputs)
onnx_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
deploy_cfg)
if backend_outputs is None:
return ctx_outputs, False
else:
return backend_outputs, True

View File

@ -107,12 +107,23 @@ def test_OpenVINODetector():
openvino_apis.__dict__.update({'OpenVINOWrapper': OpenVINOWrapper})
# simplify backend inference
outputs = {'dets': torch.rand(1, 100, 5), 'labels': torch.rand(1, 100)}
num_classes = 80
num_dets = 10
outputs = {
'dets': torch.rand(1, num_dets, 5),
'labels': torch.randint(num_classes, (1, num_dets)),
'masks': np.random.rand(1, num_dets, 28, 28)
}
deploy_cfg = mmcv.Config(
dict(
codebase_config=dict(
post_processing=dict(export_postprocess_mask=False))))
with SwitchBackendWrapper(OpenVINOWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmdet.apis.inference import OpenVINODetector
openvino_detector = OpenVINODetector('', ['' for i in range(80)], 0)
openvino_detector = OpenVINODetector(
'', ['' for i in range(80)], 0, deploy_cfg=deploy_cfg)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],

View File

@ -5,7 +5,8 @@ import numpy as np
import pytest
import torch
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
from mmdeploy.utils.test import (WrapFunction, get_onnx_model,
get_rewrite_outputs)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@ -150,3 +151,71 @@ def test_distance2bbox():
distance = torch.rand(3, 4)
bbox = distance2bbox(points, distance)
assert bbox.shape == torch.Size([3, 4])
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_multiclass_nms_with_keep_top_k():
backend_type = 'onnxruntime'
from mmdeploy.mmdet.core import multiclass_nms
max_output_boxes_per_class = 20
keep_top_k = 15
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
output_names=None,
input_shape=None,
dynamic_axes=dict(
boxes={
0: 'batch_size',
1: 'num_boxes'
},
scores={
0: 'batch_size',
1: 'num_boxes',
2: 'num_classes'
},
),
),
backend_config=dict(type=backend_type),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=max_output_boxes_per_class,
pre_top_k=-1,
keep_top_k=keep_top_k,
background_label_id=-1,
))))
num_classes = 5
num_boxes = 2
batch_size = 1
export_boxes = torch.rand(batch_size, num_boxes, 4)
export_scores = torch.ones(batch_size, num_boxes, num_classes)
model_inputs = {'boxes': export_boxes, 'scores': export_scores}
wrapped_func = WrapFunction(
multiclass_nms,
max_output_boxes_per_class=max_output_boxes_per_class,
keep_top_k=keep_top_k)
onnx_model_path = get_onnx_model(
wrapped_func, model_inputs=model_inputs, deploy_cfg=deploy_cfg)
num_boxes = 100
test_boxes = torch.rand(batch_size, num_boxes, 4)
test_scores = torch.ones(batch_size, num_boxes, num_classes)
model_inputs = {'boxes': test_boxes, 'scores': test_scores}
import mmdeploy.apis.onnxruntime as ort_apis
backend_model = ort_apis.ORTWrapper(onnx_model_path, 0, None)
dets, _ = backend_model.forward(model_inputs)
assert dets.shape[1] < keep_top_k, \
'multiclass_nms returned more values than "keep_top_k"\n' \
f'dets.shape: {dets.shape}\n' \
f'keep_top_k: {keep_top_k}'

View File

@ -368,7 +368,7 @@ def test_single_roi_extractor(backend_type):
model_output, backend_output, rtol=1e-03, atol=1e-05)
def get_cascade_roi_head():
def get_cascade_roi_head(is_with_masks=False):
"""CascadeRoIHead Config."""
num_stages = 3
stage_loss_weights = [1, 0.5, 0.25]
@ -408,19 +408,43 @@ def get_cascade_roi_head():
}
} for target_stds in all_target_stds]
mask_roi_extractor = {
'type': 'SingleRoIExtractor',
'roi_layer': {
'type': 'RoIAlign',
'output_size': 14,
'sampling_ratio': 0
},
'out_channels': 64,
'featmap_strides': [4, 8, 16, 32]
}
mask_head = {
'type': 'FCNMaskHead',
'num_convs': 4,
'in_channels': 64,
'conv_out_channels': 64,
'num_classes': 80,
'loss_mask': {
'type': 'CrossEntropyLoss',
'use_mask': True,
'loss_weight': 1.0
}
}
test_cfg = mmcv.Config(
dict(
score_thr=0.05,
nms=mmcv.Config(dict(type='nms', iou_threshold=0.5)),
max_per_img=100))
max_per_img=100,
mask_thr_binary=0.5))
args = [num_stages, stage_loss_weights, bbox_roi_extractor, bbox_head]
kwargs = {'test_cfg': test_cfg}
if is_with_masks:
args += [mask_roi_extractor, mask_head]
from mmdet.models import CascadeRoIHead
model = CascadeRoIHead(
num_stages,
stage_loss_weights,
bbox_roi_extractor,
bbox_head,
test_cfg=test_cfg).eval()
model = CascadeRoIHead(*args, **kwargs).eval()
return model
@ -497,3 +521,56 @@ def test_cascade_roi_head(backend_type):
processed_backend_outputs,
rtol=1e-03,
atol=1e-05)
@pytest.mark.parametrize('backend_type', ['openvino'])
def test_cascade_roi_head_with_mask(backend_type):
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
cascade_roi_head = get_cascade_roi_head(is_with_masks=True)
seed_everything(1234)
x = [
torch.rand((1, 64, 200, 304)),
torch.rand((1, 64, 100, 152)),
torch.rand((1, 64, 50, 76)),
torch.rand((1, 64, 25, 38)),
]
proposals = torch.tensor([[587.8285, 52.1405, 886.2484, 341.5644, 0.5]])
img_metas = mmcv.Config({
'img_shape': torch.tensor([800, 1216]),
'ori_shape': torch.tensor([800, 1216]),
'scale_factor': torch.tensor([1, 1, 1, 1])
})
output_names = ['bbox_results', 'segm_results']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1))))
model_inputs = {'x': x, 'proposals': proposals.unsqueeze(0)}
wrapped_model = WrapModel(
cascade_roi_head, 'simple_test', img_metas=img_metas)
backend_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=model_inputs,
deploy_cfg=deploy_cfg)
bbox_results = backend_outputs['bbox_results']
segm_results = backend_outputs['segm_results']
expected_bbox_results = np.zeros((1, 80, 5))
expected_segm_results = -np.ones((1, 80))
assert np.allclose(
expected_bbox_results, bbox_results, rtol=1e-03,
atol=1e-05), 'bbox_results do not match.'
assert np.allclose(
expected_segm_results, segm_results, rtol=1e-03,
atol=1e-05), 'segm_results do not match.'