[Enhancement]: Added static config and CI tests for OpenVINO. (#218)

* Add openvino_static.

* Add openvino-dev.

* Fix skipping ORT tests in test_mmocr_models.

* Updated docs.

* Fix print.

* Fix

* Fix

* Fix other backends

* Fix is_available

* fix ncnn

* Add constrict for get rewrite output

* add not

* Fix

* fix

* Fix

* Fix

* Improve tests

* Remove rebundant `cuda`

* Prevent None object and rename variable

* Fix multi-line string

* rename get_backend_checker

* Add Troubleshooting to doc.

* Fix postprocessing_masks with empty masks.

* Fix tests

* lint

* Update docs.

Co-authored-by: SingleZombie <singlezombie@163.com>
pull/12/head
Semyon Bevzyuk 2021-12-01 09:03:48 +03:00 committed by GitHub
parent f4b3db188e
commit bb9b0a98de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 166 additions and 52 deletions

View File

@ -0,0 +1,3 @@
_base_ = ['./base_static.py', '../../_base_/backends/openvino.py']
onnx_config = dict(input_shape=(1344, 800))

View File

@ -0,0 +1 @@
_base_ = ['../_base_/base_openvino_static-800x1344.py']

View File

@ -3,9 +3,30 @@
This tutorial is based on Linux systems like Ubuntu-18.04.
### Installation
It is recommended to create a virtual environment for the project.
1. Install [OpenVINO](https://docs.openvino.ai/2021.4/get_started.html). For example, you can install OpenVINO with [pip](https://pypi.org/project/openvino-dev/).
```bash
pip install openvino-dev
```
2. Install [PyTorch](https://pytorch.org/get-started/locally/).
```bash
pip install torch torchvision
```
3. Install [MMCV](https://mmcv.readthedocs.io/en/latest/get_started/installation.html). It is advisable to install the latest version `mmcv-full`.
```bash
pip install mmcv-full
```
4. Install MMDeploy following the [instructions](../build.md).
1. Install [OpenVINO](https://docs.openvinotoolkit.org/latest/installation_guides.html).
2. Install MMDeploy following the [instructions](../build.md).
To work with models from [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), you may need to install it additionally.
### Troubleshooting
#### ImportError: libpython3.7m.so.1.0: cannot open shared object file: No such file or directory
To resolve missing external dependency on Ubuntu*, execute the following command:
```bash
sudo apt-get install libpython3.7
```
### Usage
@ -41,7 +62,13 @@ The table below lists the models that are guaranteed to be exportable to OpenVIN
| Faster R-CNN + DCN | `configs/dcn/faster_rcnn_r50_fpn_dconv_c3-c5_1x_coco.py` | Y |
| VFNet | `configs/vfnet/vfnet_r50_fpn_1x_coco.py` | Y |
Notes:
- Custom operations from OpenVINO use the domain `org.openvinotoolkit`.
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.
### FAQs
- None

View File

@ -50,3 +50,6 @@ Build the inference engine extension libraries you need.
cd ${MMDEPLOY_DIR} # To mmdeploy root directory
pip install -e .
```
Some dependencies are optional. Simply running `pip install -e .` will only install the minimum runtime requirements.
To use optional dependencies install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -e .[optional]`).
Valid keys for the extras field are: `all`, `tests`, `build`, `optional`.

View File

@ -8,14 +8,22 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
### List of MMDetection models supported by MMDeploy
| model | task | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | model config file(example) |
|:-------------|:-------------|:-----------:|:--------:|:----:|:---:|:--------:|:------------------------------------------------------------------|
| RetinaNet | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/retinanet/retinanet_r50_fpn_1x_coco.py |
| YOLOv3 | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py |
| FCOS | single-stage | Y | Y | Y | N | Y | $MMDET_DIR/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py |
| FSAF | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/fsaf/fsaf_r50_fpn_1x_coco.py |
| Faster R-CNN | two-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py |
| Mask R-CNN | two-stage | Y | Y | N | Y | Y | $MMDET_DIR/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py |
| model | task | OnnxRuntime | TensorRT | NCNN | PPL | OpenVINO | model config file(example) |
|:-------------------|:-------------|:-----------:|:--------:|:----:|:---:|:--------:|:---------------------------------------------------------------------|
| ATSS | single-stage | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py |
| FCOS | single-stage | Y | Y | Y | N | Y | $MMDET_DIR/configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py |
| FoveaBox | single-stage | Y | ? | ? | ? | Y | $MMDET_DIR/configs/foveabox/fovea_r50_fpn_4x4_1x_coco.py |
| FSAF | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/fsaf/fsaf_r50_fpn_1x_coco.py |
| RetinaNet | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/retinanet/retinanet_r50_fpn_1x_coco.py |
| SSD | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/ssd/ssd300_coco.py |
| VFNet | single-stage | Y | ? | ? | ? | Y | $MMDET_DIR/configs/vfnet/vfnet_r50_fpn_1x_coco.py |
| YOLOv3 | single-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py |
| YOLOX | single-stage | Y | ? | ? | ? | Y | $MMDET_DIR/configs/yolox/yolox_tiny_8x8_300e_coco.py |
| Cascade R-CNN | two-stage | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py |
| Faster R-CNN | two-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py |
| Faster R-CNN + DCN | two-stage | Y | Y | Y | Y | Y | $MMDET_DIR/configs/dcn/faster_rcnn_r50_fpn_dconv_c3-c5_1x_coco.py` |
| Mask Cascade R-CNN | two-stage | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py |
| Mask R-CNN | two-stage | Y | Y | N | Y | Y | $MMDET_DIR/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py |
### Reminder

View File

@ -11,7 +11,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
| model | OnnxRuntime | TensorRT | NCNN | PPL | OpenVino | model config file(example) |
|:------------------------------|:-----------:|:--------:|:----:|:---:|:--------:|:-----------------------------------------------------------------------------------|
| FCN | Y | Y | Y | Y | ? | ${MMSEG_DIR}/configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py |
| PSPNet[*](#pspnet) | Y | Y | N | Y | ? | ${MMSEG_DIR}/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py |
| PSPNet[*](#pspnet) | Y | Y | N | Y | ? | ${MMSEG_DIR}/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py |
| DeepLabV3 | Y | Y | Y | Y | ? | ${MMSEG_DIR}/configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py |
| DeepLabV3+ | Y | Y | Y | Y | ? | ${MMSEG_DIR}/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py |

View File

@ -3,21 +3,38 @@ from typing import List
import mmcv
from mmdeploy.utils import get_input_shape
def get_input_shape_from_cfg(config: mmcv.Config) -> List[int]:
"""Get the input shape from the model config for OpenVINO Model Optimizer.
def get_input_shape_from_cfg(deploy_cfg: mmcv.Config,
model_cfg: mmcv.Config) -> List[int]:
"""Get the input shape from the configs for OpenVINO Model Optimizer. The
value from config 'deploy_cfg' has the highest priority, then 'model_cfg'.
If there is no input shape in configs, then the default value will be used.
Args:
config (mmcv.Config): Model config.
deploy_cfg (mmcv.Config): Deployment config.
model_cfg (mmcv.Config): Model config.
Returns:
List[int]: The input shape in [1, 3, H, W] format from config
or [1, 3, 800, 1344].
"""
shape = []
test_pipeline = config.get('test_pipeline', None)
if test_pipeline is not None:
img_scale = test_pipeline[1]['img_scale']
shape = [1, 3, img_scale[1], img_scale[0]]
shape = [1, 3]
is_use_deploy_cfg = False
try:
input_shape = get_input_shape(deploy_cfg)
if input_shape is not None:
is_use_deploy_cfg = True
except KeyError:
is_use_deploy_cfg = False
if is_use_deploy_cfg:
shape += [input_shape[1], input_shape[0]]
else:
shape = [1, 3, 800, 1344]
test_pipeline = model_cfg.get('test_pipeline', None)
if test_pipeline is not None:
img_scale = test_pipeline[1]['img_scale']
shape += [img_scale[1], img_scale[0]]
else:
shape += [800, 1344]
return shape

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import subprocess
from subprocess import CalledProcessError, run
@ -16,7 +17,7 @@ def is_mo_available() -> bool:
"""
is_available = True
try:
run('mo.py -h',
run('mo -h',
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
shell=True,
@ -68,11 +69,11 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
f'--input="{input_names}" ' \
f'--input_shape="{input_shapes}" ' \
f'--disable_fusing '
command = f'mo.py {mo_args}'
print(f'Args for mo.py: {command}')
command = f'mo {mo_args}'
logging.info(f'Args for mo: {command}')
mo_output = run(command, capture_output=True, shell=True, check=True)
print(mo_output.stdout.decode())
print(mo_output.stderr.decode())
logging.info(mo_output.stdout.decode())
logging.debug(mo_output.stderr.decode())
model_xml = get_output_model_file(onnx_path, work_dir)
print(f'Successfully exported OpenVINO model: {model_xml}')
logging.info(f'Successfully exported OpenVINO model: {model_xml}')

View File

@ -91,6 +91,28 @@ class OpenVINOWrapper(BaseWrapper):
device_name=self.device.upper(),
num_requests=1)
def __process_outputs(
self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Converts tensors from 'torch' to 'numpy' and fixes the names of the
outputs.
Args:
outputs Dict[str, torch.Tensor]: The output name and tensor pairs.
Returns:
Dict[str, torch.Tensor]: The output name and tensor pairs
after processing.
"""
outputs = {
name: torch.from_numpy(tensor)
for name, tensor in outputs.items()
}
for output_name in outputs.keys():
if '.' in output_name:
new_output_name = output_name.split('.')[0]
outputs[new_output_name] = outputs.pop(output_name)
return outputs
def forward(self, inputs: Dict[str,
torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Run forward inference.
@ -104,8 +126,7 @@ class OpenVINOWrapper(BaseWrapper):
inputs = self.__update_device(inputs)
self.__reshape(inputs)
outputs = self.__openvino_execute(inputs)
for output_name, numpy_tensor in outputs.items():
outputs[output_name] = torch.from_numpy(numpy_tensor)
outputs = self.__process_outputs(outputs)
return outputs
@TimeCounter.count_time()

View File

@ -96,28 +96,33 @@ class End2EndModel(BaseBackendModel):
return outputs
@staticmethod
def __postprocessing_masks(det_bboxes: np.ndarray,
det_masks: np.ndarray,
img_w: int,
img_h: int,
mask_thr_binary: float = 0.5) -> np.ndarray:
def postprocessing_masks(det_bboxes: np.ndarray,
det_masks: np.ndarray,
img_w: int,
img_h: int,
mask_thr_binary: float = 0.5) -> np.ndarray:
"""Additional processing of masks. Resizes masks from [num_det, 28, 28]
to [num_det, img_w, img_h]. Analog of the 'mmdeploy.codebase.mmdet.
models.roi_heads.fcn_mask_head._do_paste_mask' function.
Args:
det_bboxes (np.ndarray): Bbox of shape [num_det, 5]
det_bboxes (np.ndarray): Bbox of shape [num_det, 4]
det_masks (np.ndarray): Masks of shape [num_det, 28, 28].
img_w (int): Width of the original image.
img_h (int): Height of the original image.
mask_thr_binary (float): The threshold for the mask.
Returns:
np.ndarray: masks of shape [N, num_det, img_w, img_h].
np.ndarray: masks of shape [N, num_det, img_h, img_w].
"""
masks = det_masks
bboxes = det_bboxes
num_det = bboxes.shape[0]
# Skip postprocessing if no detections are found.
if num_det == 0:
return np.zeros((0, img_h, img_w))
if isinstance(masks, np.ndarray):
masks = torch.tensor(masks)
bboxes = torch.tensor(bboxes)
@ -214,7 +219,7 @@ class End2EndModel(BaseBackendModel):
export_postprocess_mask = mmdet_deploy_cfg.get(
'export_postprocess_mask', True)
if not export_postprocess_mask:
masks = End2EndModel.__postprocessing_masks(
masks = End2EndModel.postprocessing_masks(
dets[:, :4], masks, ori_w, ori_h)
else:
masks = masks[:, :img_h, :img_w]
@ -225,7 +230,7 @@ class End2EndModel(BaseBackendModel):
masks = torch.nn.functional.interpolate(
masks.unsqueeze(0), size=(ori_h, ori_w))
masks = masks.squeeze(0).detach().numpy()
if masks.dtype != np.bool:
if masks.dtype != bool:
masks = masks >= 0.5
segms_results = [[] for _ in range(len(self.CLASSES))]
for j in range(len(dets)):

View File

@ -5,4 +5,5 @@ mmocr>=0.3.0
mmsegmentation
ncnn
onnxruntime>=1.8.0
openvino-dev[onnx,pytorch]
tensorrt

View File

@ -3,6 +3,7 @@ import os
import os.path as osp
import tempfile
import mmcv
import numpy as np
import pytest
import torch
@ -104,14 +105,25 @@ def test_get_input_shape_from_cfg():
from mmdeploy.apis.openvino import get_input_shape_from_cfg
# Test with default value
model_cfg = {}
input_shape = get_input_shape_from_cfg(model_cfg)
deploy_cfg = mmcv.Config()
model_cfg = mmcv.Config()
input_shape = get_input_shape_from_cfg(deploy_cfg, model_cfg)
assert input_shape == [1, 3, 800, 1344], \
'The function returned a different default shape.'
# Test with config that contains the required data.
# Test with model_cfg that contains the required data.
height, width = 800, 1200
model_cfg = {'test_pipeline': [{}, {'img_scale': (width, height)}]}
input_shape = get_input_shape_from_cfg(model_cfg)
model_cfg = mmcv.Config(
{'test_pipeline': [{}, {
'img_scale': (width, height)
}]})
input_shape = get_input_shape_from_cfg(deploy_cfg, model_cfg)
assert input_shape == [1, 3, height, width], \
'The shape in the config does not match the output shape.'
'The shape in the model_cfg does not match the output shape.'
# Test with deploy_cfg that contains the required data.
height, width = 600, 1000
deploy_cfg = mmcv.Config({'onnx_config': {'input_shape': (width, height)}})
input_shape = get_input_shape_from_cfg(deploy_cfg, model_cfg)
assert input_shape == [1, 3, height, width], \
'The shape in the deploy_cfg does not match the output shape.'

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pytest

View File

@ -561,11 +561,8 @@ def test_cascade_roi_head(backend_type: Backend):
wrapped_model=wrapped_model,
model_inputs=model_inputs,
deploy_cfg=deploy_cfg)
processed_backend_outputs = []
if isinstance(backend_outputs, dict):
processed_backend_outputs = convert_to_list(backend_outputs,
output_names)
elif isinstance(backend_outputs, (list, tuple)) and \
if isinstance(backend_outputs, (list, tuple)) and \
backend_outputs[0].shape == (1, 0, 5):
processed_backend_outputs = torch.zeros((1, 80, 5))
else:

View File

@ -72,6 +72,23 @@ def test_init_backend_model(backend_model):
assert isinstance(backend_model, End2EndModel)
def test_can_postprocess_masks():
from mmdeploy.codebase.mmdet.deploy.object_detection_model \
import End2EndModel
num_dets = [0, 1, 5]
for num_det in num_dets:
det_bboxes = np.random.randn(num_det, 4)
det_masks = np.random.randn(num_det, 28, 28)
img_w, img_h = (30, 40)
masks = End2EndModel.postprocessing_masks(det_bboxes, det_masks, img_w,
img_h)
expected_shape = (num_det, img_h, img_w)
actual_shape = masks.shape
assert actual_shape == expected_shape, \
f'The expected shape of masks {expected_shape} ' \
f'did not match actual shape {actual_shape}.'
@pytest.mark.parametrize('device', ['cpu', 'cuda:0'])
def test_create_input(device):
if device == 'cuda:0' and not torch.cuda.is_available():

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch

View File

@ -7,11 +7,15 @@ import pytest
import torch
from mmocr.models.textdet.necks import FPNC
from mmdeploy.apis.onnxruntime import is_available as ort_available
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)
onnxruntime_skip = not ort_available()
cuda_skip = not torch.cuda.is_available()
class FPNCNeckModel(FPNC):

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pytest

View File

@ -234,7 +234,7 @@ def main():
for onnx_path in onnx_files:
model_xml_path = get_output_model_file(onnx_path, args.work_dir)
input_name = deploy_cfg.onnx_config.input_names
input_shape = [get_input_shape_from_cfg(model_cfg)]
input_shape = [get_input_shape_from_cfg(deploy_cfg, model_cfg)]
input_info = dict(zip(input_name, input_shape))
output_names = deploy_cfg.onnx_config.output_names
create_process(