mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement]: Added support for masks in OpenVINO. (#148)
* Fix include and lib paths for onnxruntime. * Fixes for SSD export test * Add onnx2openvino and OpenVINODetector. Test models: ssd, retinanet, fcos, fsaf. * Add support for two-stage models: faster_rcnn, cascade_rcnn * Add doc * Add strip_doc_string for openvino. * Fix openvino preprocess. * Add OpenVINO to test_wrapper.py. * Fix * Add openvino_execute. * Removed preprocessing. * Fix onnxruntime cmake. * Rewrote postprocessing and forward, added docstrings and fixes. * Added device type change to OpenVINOWrapper. * Update forward_of_single_roi_extractor_dynamic_openvino and fix doc. * Update docs. * Add support for masks (Mask RCNN). * Add masks to CascadeRoIHead.simple_test. * Added masks to test_OpenVINODetector. * Added test_cascade_roi_head_with_mask. * Update docs. * Fix segm_results shape. * Fix TopK in NMS and add test_multiclass_nms_with_keep_top_k. * Removed unnecessary functions. * Fix. * Fix test_multiclass_nms_with_keep_top_k. * Updated test_OpenVINODetector.
This commit is contained in:
parent
d3e26b68a2
commit
c52b24c67f
@ -39,6 +39,7 @@ Supported backend:
|
||||
- [x] TensorRT
|
||||
- [x] PPL
|
||||
- [x] ncnn
|
||||
- [x] OpenVINO
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -25,15 +25,17 @@ python tools/deploy.py \
|
||||
|
||||
The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection.
|
||||
|
||||
| Model | Config | Dynamic Shape |
|
||||
| :----------: | :-----------------------------------------------------------------: | :-----------: |
|
||||
| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y |
|
||||
| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y |
|
||||
| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y |
|
||||
| SSD | `configs/ssd/ssd300_coco.py` | Y |
|
||||
| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
| Cascade R-CNN| `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
| Model name | Config | Dynamic Shape |
|
||||
| :----------------: | :-----------------------------------------------------------------: | :-----------: |
|
||||
| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y |
|
||||
| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y |
|
||||
| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y |
|
||||
| SSD | `configs/ssd/ssd300_coco.py` | Y |
|
||||
| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
| Cascade R-CNN | `configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
| Cascade Mask R-CNN | `configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py` | Y |
|
||||
|
||||
Notes:
|
||||
- For faster work in OpenVINO in the Faster-RCNN, Cascade-RCNN
|
||||
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
|
||||
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
|
||||
|
@ -16,14 +16,14 @@
|
||||
|
||||
This tutorial briefly introduces how to export an OpenMMlab model to a specific backend using MMDeploy tools.
|
||||
Notes:
|
||||
- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md).
|
||||
- Supported backends are [ONNXRuntime](../backends/onnxruntime.md), [TensorRT](../backends/tensorrt.md), [NCNN](../backends/ncnn.md), [PPL](../backends/ppl.md), [OpenVINO](../backends/openvino.md).
|
||||
- Supported codebases are [MMClassification](../codebases/mmcls.md), [MMDetection](../codebases/mmdet.md), [MMSegmentation](../codebases/mmseg.md), [MMOCR](../codebases/mmocr.md), [MMEditing](../codebases/mmedit.md).
|
||||
|
||||
### How to convert models from Pytorch to other backends
|
||||
|
||||
#### Prerequisite
|
||||
|
||||
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md) for more information.
|
||||
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../backends/onnxruntime.md), [TensorRT-install](../backends/tensorrt.md), [NCNN-install](../backends/ncnn.md), [PPL-install](../backends/ppl.md), [OpenVINO-install](../backends/openvino.md) for more information.
|
||||
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/master/docs/install.md), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/main/docs/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/master/docs/install.md).
|
||||
|
||||
#### Usage
|
||||
@ -78,28 +78,31 @@ You can try to evaluate model, referring to [how_to_evaluate_a_model](./how_to_e
|
||||
|
||||
The table below lists the models that are guaranteed to be exportable to other backend.
|
||||
|
||||
| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL |
|
||||
| :----------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: |
|
||||
| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y |
|
||||
| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y |
|
||||
| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y |
|
||||
| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N |
|
||||
| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y |
|
||||
| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y |
|
||||
| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y |
|
||||
| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y |
|
||||
| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y |
|
||||
| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y |
|
||||
| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y |
|
||||
| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y |
|
||||
| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
|
||||
| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y |
|
||||
| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
|
||||
| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y |
|
||||
| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y |
|
||||
| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y |
|
||||
| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y |
|
||||
| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N |
|
||||
| Model | codebase | model config file(example) | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO |
|
||||
| :----------------: | :--------------: | :---------------------------------------------------------------------------------------: | :---------: | :-----------: | :---:| :---: | :-------: |
|
||||
| RetinaNet | MMDetection | $PATH_TO_MMDET/configs/retinanet/retinanet_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
||||
| Faster R-CNN | MMDetection | $PATH_TO_MMDET/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
||||
| YOLOv3 | MMDetection | $PATH_TO_MMDET/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py | Y | Y | N | Y | N |
|
||||
| FCOS | MMDetection | $PATH_TO_MMDET/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py | Y | Y | Y | N | Y |
|
||||
| FSAF | MMDetection | $PATH_TO_MMDET/configs/fsaf/fsaf_r50_fpn_1x_coco.py | Y | Y | Y | Y | Y |
|
||||
| Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py | Y | Y | N | Y | Y |
|
||||
| SSD | MMDetection | $PATH_TO_MMDET/configs/ssd/ssd300_coco.py | Y | ? | ? | ? | Y |
|
||||
| Cascade R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
|
||||
| Cascade Mask R-CNN | MMDetection | $PATH_TO_MMDET/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | Y | ? | ? | ? | Y |
|
||||
| ResNet | MMClassification | $PATH_TO_MMCLS/configs/resnet/resnet18_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
||||
| ResNeXt | MMClassification | $PATH_TO_MMCLS/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
||||
| SE-ResNet | MMClassification | $PATH_TO_MMCLS/configs/seresnet/seresnet50_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
||||
| MobileNetV2 | MMClassification | $PATH_TO_MMCLS/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py | Y | Y | Y | Y | N |
|
||||
| ShuffleNetV1 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
|
||||
| ShuffleNetV2 | MMClassification | $PATH_TO_MMCLS/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py | Y | Y | N | Y | N |
|
||||
| FCN | MMSegmentation | $PATH_TO_MMSEG/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
||||
| PSPNet | MMSegmentation | $PATH_TO_MMSEG/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py | Y | Y | N | Y | N |
|
||||
| DeepLabV3 | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
||||
| DeepLabV3+ | MMSegmentation | $PATH_TO_MMSEG/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | Y | Y | Y | Y | N |
|
||||
| SRCNN | MMEditing | $PATH_TO_MMSEG/configs/restorers/srcnn/srcnn_x4k915_g1_1000k_div2k.py | Y | Y | N | Y | N |
|
||||
| ESRGAN | MMEditing | $PATH_TO_MMSEG/configs/restorers/esrgan/esrgan_psnr_x4c64b23g32_g1_1000k_div2k.py | Y | Y | N | Y | N |
|
||||
| DBNet | MMOCR | $PATH_TO_MMOCR/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py | Y | Y | Y | Y | N |
|
||||
| CRNN | MMOCR | $PATH_TO_MMOCR/configs/textrecog/tps/crnn_tps_academic_dataset.py | Y | Y | Y | N | N |
|
||||
|
||||
### Reminders
|
||||
|
||||
|
@ -336,15 +336,21 @@ class OpenVINODetector(DeployBaseDetector):
|
||||
from mmdeploy.apis.openvino import OpenVINOWrapper
|
||||
self.model = OpenVINOWrapper(model_file)
|
||||
|
||||
def forward_test(self, imgs: torch.Tensor, *args, **kwargs):
|
||||
def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> Tuple:
|
||||
"""Implement forward test.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5]
|
||||
and class labels of shape [N, num_det].
|
||||
If there are no masks in the output:
|
||||
tuple[np.ndarray, np.ndarray]: dets of shape [N, num_det, 5]
|
||||
and class labels of shape [N, num_det].
|
||||
If the output contains masks:
|
||||
tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
dets of shape [N, num_det, 5],
|
||||
class labels of shape [N, num_det] and
|
||||
masks of shape [N, num_det, H, W].
|
||||
"""
|
||||
openvino_outputs = self.model({'input': imgs})
|
||||
output_keys = ['dets', 'labels']
|
||||
|
@ -55,7 +55,9 @@ def select_nms_index(scores: torch.Tensor,
|
||||
(N, 1))), 1)
|
||||
|
||||
# sort
|
||||
if keep_top_k > 0 and keep_top_k < batched_dets.shape[1]:
|
||||
is_use_topk = keep_top_k > 0 and \
|
||||
(torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
|
||||
if is_use_topk:
|
||||
_, topk_inds = batched_dets[:, :, -1].topk(keep_top_k, dim=1)
|
||||
else:
|
||||
_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
|
||||
|
@ -53,5 +53,23 @@ def simple_test_of_cascade_roi_head(ctx, self, x, proposals, img_metas,
|
||||
if not self.with_mask:
|
||||
return det_bboxes, det_labels
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Masks are not yet supported in export CascadeRoIHead.simple_test')
|
||||
batch_index = torch.arange(
|
||||
det_bboxes.size(0), device=det_bboxes.device).float().view(
|
||||
-1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
|
||||
rois = det_bboxes[..., :4]
|
||||
mask_rois = torch.cat([batch_index, rois], dim=-1)
|
||||
mask_rois = mask_rois.view(-1, 5)
|
||||
aug_masks = []
|
||||
for i in range(self.num_stages):
|
||||
mask_results = self._mask_forward(i, x, mask_rois)
|
||||
mask_pred = mask_results['mask_pred']
|
||||
aug_masks.append(mask_pred)
|
||||
# Calculate the mean of masks from several stage
|
||||
mask_pred = sum(aug_masks) / len(aug_masks)
|
||||
segm_results = self.mask_head[-1].get_seg_masks(
|
||||
mask_pred, rois.reshape(-1, 4), det_labels.reshape(-1),
|
||||
self.test_cfg, max_shape)
|
||||
segm_results = segm_results.reshape(batch_size, det_bboxes.shape[1],
|
||||
segm_results.shape[-2],
|
||||
segm_results.shape[-1])
|
||||
return det_bboxes, det_labels, segm_results
|
||||
|
@ -212,22 +212,18 @@ def get_flatten_inputs(
|
||||
return flatten_inputs
|
||||
|
||||
|
||||
def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
model_inputs: Dict[str, Union[Tuple, List,
|
||||
torch.Tensor]],
|
||||
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
|
||||
"""To get outputs of generated onnx model after rewrite.
|
||||
def get_onnx_model(wrapped_model: nn.Module,
|
||||
model_inputs: Dict[str, Union[Tuple, List, torch.Tensor]],
|
||||
deploy_cfg: mmcv.Config) -> str:
|
||||
"""To get path to onnx model after export.
|
||||
|
||||
Args:
|
||||
wrap_model (nn.Module): The input model.
|
||||
func_name (str): The function of model.
|
||||
model_inputs (dict): Inputs for model.
|
||||
deploy_cfg (mmcv.Config): Deployment config.
|
||||
|
||||
Returns:
|
||||
Any: The outputs of model, decided by the backend wrapper.
|
||||
bool: A flag indicate the type of outputs. If the flag is True, then
|
||||
the outputs are backend output, otherwise they are outputs of wrapped
|
||||
pytorch model.
|
||||
str: The path to the ONNX model file.
|
||||
"""
|
||||
onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name
|
||||
pytorch2onnx_cfg = get_onnx_config(deploy_cfg)
|
||||
@ -239,7 +235,6 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
output_names = pytorch2onnx_cfg.get('output_names', None)
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
|
||||
ctx_outputs = wrapped_model(**model_inputs)
|
||||
torch.onnx.export(
|
||||
patched_model,
|
||||
tuple([v for k, v in model_inputs.items()]),
|
||||
@ -250,12 +245,33 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
opset_version=11,
|
||||
dynamic_axes=pytorch2onnx_cfg.get('dynamic_axes', None),
|
||||
keep_initializers_as_inputs=False)
|
||||
return onnx_file_path
|
||||
|
||||
|
||||
def get_backend_outputs(onnx_file_path: str,
|
||||
model_inputs: Dict[str, Union[Tuple, List,
|
||||
torch.Tensor]],
|
||||
deploy_cfg: mmcv.Config) -> Any:
|
||||
"""To get backend outputs of model.
|
||||
|
||||
Args:
|
||||
onnx_file_path (str): The path to the ONNX file.
|
||||
model_inputs (dict): Inputs for model.
|
||||
deploy_cfg (mmcv.Config): Deployment config.
|
||||
|
||||
Returns:
|
||||
Any: The outputs of model, decided by the backend wrapper.
|
||||
"""
|
||||
backend = get_backend(deploy_cfg)
|
||||
flatten_model_inputs = get_flatten_inputs(model_inputs)
|
||||
input_names = [k for k, v in flatten_model_inputs.items() if k != 'ctx']
|
||||
output_names = get_onnx_config(deploy_cfg).get('output_names', None)
|
||||
# prepare backend model and input features
|
||||
if backend == Backend.TENSORRT:
|
||||
# convert to engine
|
||||
import mmdeploy.apis.tensorrt as trt_apis
|
||||
if not trt_apis.is_available():
|
||||
return ctx_outputs, False
|
||||
return None
|
||||
trt_file_path = tempfile.NamedTemporaryFile(suffix='.engine').name
|
||||
trt_apis.onnx2tensorrt(
|
||||
'',
|
||||
@ -271,7 +287,7 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
elif backend == Backend.ONNXRUNTIME:
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
if not ort_apis.is_available():
|
||||
return ctx_outputs, False
|
||||
return None
|
||||
backend_model = ort_apis.ORTWrapper(onnx_file_path, 0, None)
|
||||
feature_list = []
|
||||
backend_feats = {}
|
||||
@ -296,11 +312,11 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
else:
|
||||
backend_feats[str(i)] = feature_list[i]
|
||||
elif backend == Backend.NCNN:
|
||||
return ctx_outputs, False
|
||||
return None
|
||||
elif backend == Backend.OPENVINO:
|
||||
import mmdeploy.apis.openvino as openvino_apis
|
||||
if not openvino_apis.is_available():
|
||||
return ctx_outputs, False
|
||||
return None
|
||||
openvino_work_dir = tempfile.TemporaryDirectory().name
|
||||
openvino_file_path = openvino_apis.get_output_model_file(
|
||||
onnx_file_path, openvino_work_dir)
|
||||
@ -314,11 +330,44 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
|
||||
backend_feats = flatten_model_inputs
|
||||
elif backend == Backend.DEFAULT:
|
||||
return ctx_outputs, False
|
||||
return None
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unimplemented backend type: {backend.value}')
|
||||
|
||||
with torch.no_grad():
|
||||
backend_outputs = backend_model.forward(backend_feats)
|
||||
return backend_outputs, True
|
||||
return backend_outputs
|
||||
|
||||
|
||||
def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
model_inputs: Dict[str, Union[Tuple, List,
|
||||
torch.Tensor]],
|
||||
deploy_cfg: mmcv.Config) -> Tuple[Any, bool]:
|
||||
"""To get outputs of generated onnx model after rewrite.
|
||||
|
||||
Args:
|
||||
wrap_model (nn.Module): The input model.
|
||||
model_inputs (dict): Inputs for model.
|
||||
deploy_cfg (mmcv.Config): Deployment config.
|
||||
|
||||
Returns:
|
||||
Any: The outputs of model, decided by the backend wrapper.
|
||||
bool: A flag indicate the type of outputs. If the flag is True, then
|
||||
the outputs are backend output, otherwise they are outputs of wrapped
|
||||
pytorch model.
|
||||
"""
|
||||
backend = get_backend(deploy_cfg)
|
||||
with RewriterContext(
|
||||
cfg=deploy_cfg, backend=backend.value, opset=11), torch.no_grad():
|
||||
ctx_outputs = wrapped_model(**model_inputs)
|
||||
|
||||
onnx_file_path = get_onnx_model(wrapped_model, model_inputs, deploy_cfg)
|
||||
|
||||
backend_outputs = get_backend_outputs(onnx_file_path, model_inputs,
|
||||
deploy_cfg)
|
||||
|
||||
if backend_outputs is None:
|
||||
return ctx_outputs, False
|
||||
else:
|
||||
return backend_outputs, True
|
||||
|
@ -107,12 +107,23 @@ def test_OpenVINODetector():
|
||||
openvino_apis.__dict__.update({'OpenVINOWrapper': OpenVINOWrapper})
|
||||
|
||||
# simplify backend inference
|
||||
outputs = {'dets': torch.rand(1, 100, 5), 'labels': torch.rand(1, 100)}
|
||||
num_classes = 80
|
||||
num_dets = 10
|
||||
outputs = {
|
||||
'dets': torch.rand(1, num_dets, 5),
|
||||
'labels': torch.randint(num_classes, (1, num_dets)),
|
||||
'masks': np.random.rand(1, num_dets, 28, 28)
|
||||
}
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
codebase_config=dict(
|
||||
post_processing=dict(export_postprocess_mask=False))))
|
||||
with SwitchBackendWrapper(OpenVINOWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
|
||||
from mmdeploy.mmdet.apis.inference import OpenVINODetector
|
||||
openvino_detector = OpenVINODetector('', ['' for i in range(80)], 0)
|
||||
openvino_detector = OpenVINODetector(
|
||||
'', ['' for i in range(80)], 0, deploy_cfg=deploy_cfg)
|
||||
imgs = [torch.rand(1, 3, 64, 64)]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
|
@ -5,7 +5,8 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.utils.test import WrapFunction, get_rewrite_outputs
|
||||
from mmdeploy.utils.test import (WrapFunction, get_onnx_model,
|
||||
get_rewrite_outputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@ -150,3 +151,71 @@ def test_distance2bbox():
|
||||
distance = torch.rand(3, 4)
|
||||
bbox = distance2bbox(points, distance)
|
||||
assert bbox.shape == torch.Size([3, 4])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||
def test_multiclass_nms_with_keep_top_k():
|
||||
backend_type = 'onnxruntime'
|
||||
|
||||
from mmdeploy.mmdet.core import multiclass_nms
|
||||
max_output_boxes_per_class = 20
|
||||
keep_top_k = 15
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(
|
||||
output_names=None,
|
||||
input_shape=None,
|
||||
dynamic_axes=dict(
|
||||
boxes={
|
||||
0: 'batch_size',
|
||||
1: 'num_boxes'
|
||||
},
|
||||
scores={
|
||||
0: 'batch_size',
|
||||
1: 'num_boxes',
|
||||
2: 'num_classes'
|
||||
},
|
||||
),
|
||||
),
|
||||
backend_config=dict(type=backend_type),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=keep_top_k,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
|
||||
num_classes = 5
|
||||
num_boxes = 2
|
||||
batch_size = 1
|
||||
export_boxes = torch.rand(batch_size, num_boxes, 4)
|
||||
export_scores = torch.ones(batch_size, num_boxes, num_classes)
|
||||
model_inputs = {'boxes': export_boxes, 'scores': export_scores}
|
||||
|
||||
wrapped_func = WrapFunction(
|
||||
multiclass_nms,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
keep_top_k=keep_top_k)
|
||||
|
||||
onnx_model_path = get_onnx_model(
|
||||
wrapped_func, model_inputs=model_inputs, deploy_cfg=deploy_cfg)
|
||||
|
||||
num_boxes = 100
|
||||
test_boxes = torch.rand(batch_size, num_boxes, 4)
|
||||
test_scores = torch.ones(batch_size, num_boxes, num_classes)
|
||||
model_inputs = {'boxes': test_boxes, 'scores': test_scores}
|
||||
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
backend_model = ort_apis.ORTWrapper(onnx_model_path, 0, None)
|
||||
dets, _ = backend_model.forward(model_inputs)
|
||||
|
||||
assert dets.shape[1] < keep_top_k, \
|
||||
'multiclass_nms returned more values than "keep_top_k"\n' \
|
||||
f'dets.shape: {dets.shape}\n' \
|
||||
f'keep_top_k: {keep_top_k}'
|
||||
|
@ -368,7 +368,7 @@ def test_single_roi_extractor(backend_type):
|
||||
model_output, backend_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def get_cascade_roi_head():
|
||||
def get_cascade_roi_head(is_with_masks=False):
|
||||
"""CascadeRoIHead Config."""
|
||||
num_stages = 3
|
||||
stage_loss_weights = [1, 0.5, 0.25]
|
||||
@ -408,19 +408,43 @@ def get_cascade_roi_head():
|
||||
}
|
||||
} for target_stds in all_target_stds]
|
||||
|
||||
mask_roi_extractor = {
|
||||
'type': 'SingleRoIExtractor',
|
||||
'roi_layer': {
|
||||
'type': 'RoIAlign',
|
||||
'output_size': 14,
|
||||
'sampling_ratio': 0
|
||||
},
|
||||
'out_channels': 64,
|
||||
'featmap_strides': [4, 8, 16, 32]
|
||||
}
|
||||
mask_head = {
|
||||
'type': 'FCNMaskHead',
|
||||
'num_convs': 4,
|
||||
'in_channels': 64,
|
||||
'conv_out_channels': 64,
|
||||
'num_classes': 80,
|
||||
'loss_mask': {
|
||||
'type': 'CrossEntropyLoss',
|
||||
'use_mask': True,
|
||||
'loss_weight': 1.0
|
||||
}
|
||||
}
|
||||
|
||||
test_cfg = mmcv.Config(
|
||||
dict(
|
||||
score_thr=0.05,
|
||||
nms=mmcv.Config(dict(type='nms', iou_threshold=0.5)),
|
||||
max_per_img=100))
|
||||
max_per_img=100,
|
||||
mask_thr_binary=0.5))
|
||||
|
||||
args = [num_stages, stage_loss_weights, bbox_roi_extractor, bbox_head]
|
||||
kwargs = {'test_cfg': test_cfg}
|
||||
if is_with_masks:
|
||||
args += [mask_roi_extractor, mask_head]
|
||||
|
||||
from mmdet.models import CascadeRoIHead
|
||||
model = CascadeRoIHead(
|
||||
num_stages,
|
||||
stage_loss_weights,
|
||||
bbox_roi_extractor,
|
||||
bbox_head,
|
||||
test_cfg=test_cfg).eval()
|
||||
model = CascadeRoIHead(*args, **kwargs).eval()
|
||||
return model
|
||||
|
||||
|
||||
@ -497,3 +521,56 @@ def test_cascade_roi_head(backend_type):
|
||||
processed_backend_outputs,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['openvino'])
|
||||
def test_cascade_roi_head_with_mask(backend_type):
|
||||
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
|
||||
|
||||
cascade_roi_head = get_cascade_roi_head(is_with_masks=True)
|
||||
seed_everything(1234)
|
||||
x = [
|
||||
torch.rand((1, 64, 200, 304)),
|
||||
torch.rand((1, 64, 100, 152)),
|
||||
torch.rand((1, 64, 50, 76)),
|
||||
torch.rand((1, 64, 25, 38)),
|
||||
]
|
||||
proposals = torch.tensor([[587.8285, 52.1405, 886.2484, 341.5644, 0.5]])
|
||||
img_metas = mmcv.Config({
|
||||
'img_shape': torch.tensor([800, 1216]),
|
||||
'ori_shape': torch.tensor([800, 1216]),
|
||||
'scale_factor': torch.tensor([1, 1, 1, 1])
|
||||
})
|
||||
|
||||
output_names = ['bbox_results', 'segm_results']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1))))
|
||||
model_inputs = {'x': x, 'proposals': proposals.unsqueeze(0)}
|
||||
wrapped_model = WrapModel(
|
||||
cascade_roi_head, 'simple_test', img_metas=img_metas)
|
||||
backend_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=model_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
bbox_results = backend_outputs['bbox_results']
|
||||
segm_results = backend_outputs['segm_results']
|
||||
expected_bbox_results = np.zeros((1, 80, 5))
|
||||
expected_segm_results = -np.ones((1, 80))
|
||||
assert np.allclose(
|
||||
expected_bbox_results, bbox_results, rtol=1e-03,
|
||||
atol=1e-05), 'bbox_results do not match.'
|
||||
assert np.allclose(
|
||||
expected_segm_results, segm_results, rtol=1e-03,
|
||||
atol=1e-05), 'segm_results do not match.'
|
||||
|
Loading…
x
Reference in New Issue
Block a user