[Enhancement]: Support VFNet from MMDetection for OpenVINO and ONNX Runtime. (#195)

* Add deform_conv_openvino.

* Add get_bboxes_of_vfnet_head.

* Fix vfnet and add test_get_bboxes_of_vfnet_head.

* Update docs.

* Fix test_shufflenetv2_backbone__forward for openvino.

* Fixes.
This commit is contained in:
Semyon Bevzyuk 2021-11-16 05:59:59 +03:00 committed by GitHub
parent e142c72663
commit 49dd1cf678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 350 additions and 29 deletions

View File

@ -24,7 +24,6 @@ python tools/deploy.py \
### List of supported models exportable to OpenVINO from MMDetection
The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection.
| Model name | Config | Dynamic Shape |
| :----------------: | :-----------------------------------------------------------------------: | :-----------: |
| ATSS | `configs/atss/atss_r50_fpn_1x_coco.py` | Y |
@ -39,7 +38,10 @@ The table below lists the models that are guaranteed to be exportable to OpenVIN
| SSD | `configs/ssd/ssd300_coco.py` | Y |
| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y |
| YOLOX | `configs/yolox/yolox_tiny_8x8_300e_coco.py` | Y |
| Faster R-CNN + DCN | `configs/dcn/faster_rcnn_r50_fpn_dconv_c3-c5_1x_coco.py` | Y |
| VFNet | `configs/vfnet/vfnet_r50_fpn_1x_coco.py` | Y |
Notes:
- To speed up inference in OpenVINO, in the Faster-RCNN, Mask-RCNN, Cascade-RCNN and Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.

View File

@ -100,6 +100,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py |
| Cascade R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py |
| Cascade Mask R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py |
| VFNet | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/vfnet/vfnet_r50_fpn_1x_coco.py |
| ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py |
| ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py |
| SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py |

View File

@ -1,4 +1,5 @@
from .deform_conv import deform_conv_openvino
from .nms import * # noqa: F401,F403
from .roi_align import roi_align_default
__all__ = ['roi_align_default']
__all__ = ['roi_align_default', 'deform_conv_openvino']

View File

@ -0,0 +1,34 @@
from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
    'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino')
def deform_conv_openvino(ctx,
                         g,
                         input,
                         offset,
                         weight,
                         stride,
                         padding,
                         dilation,
                         groups,
                         deform_groups,
                         bias=False,
                         im2col_step=32):
    """Rewrite symbolic function for OpenVINO backend.

    Maps mmcv's ``DeformConv2dFunction`` onto the OpenVINO custom
    ``DeformableConv2D`` ONNX operation.
    """
    assert not bias, 'The "bias" parameter should be False.'
    assert groups == 1, 'The "groups" parameter should be 1.'
    # Duplicate each padding value: (ph, pw) -> [ph, ph, pw, pw] —
    # presumably the pads layout the custom op consumes; mirrors the
    # symmetric begin/end padding of the conv.
    pads = []
    for pad in padding:
        pads.extend((pad, pad))
    # Static kernel spatial dims are read off the traced weight tensor.
    kernel_dims = list(weight.type().sizes()[2:])
    return g.op(
        'org.openvinotoolkit::DeformableConv2D',
        input,
        offset,
        weight,
        strides_i=stride,
        pads_i=pads,
        dilations_i=dilation,
        groups_i=groups,
        deformable_groups_i=deform_groups,
        kernel_shape_i=kernel_dims)

View File

@ -3,12 +3,13 @@ from .atss_head import atss_head__get_bboxes
from .fcos_head import fcos_head__get_bboxes
from .fovea_head import fovea_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes
from .vfnet_head import vfnet_head__get_bboxes
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
from .yolox_head import yolox_head__get_bboxes
__all__ = [
'anchor_head__get_bboxes', 'atss_head__get_bboxes',
'fcos_head__get_bboxes', 'fovea_head__get_bboxes', 'rpn_head__get_bboxes',
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
'yolox_head__get_bboxes'
'vfnet_head__get_bboxes', 'yolov3_head__get_bboxes',
'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes'
]

View File

@ -0,0 +1,113 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import distance2bbox, multiclass_nms
from mmdeploy.utils import get_mmdet_params
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.VFNetHead.get_bboxes')
def vfnet_head__get_bboxes(ctx,
                           self,
                           cls_scores,
                           bbox_preds,
                           bbox_preds_refine,
                           img_metas,
                           cfg=None,
                           rescale=None,
                           with_nms=True):
    """Rewrite `get_bboxes` of VFNetHead for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        ctx: Rewriter context; carries the deploy config in ``ctx.cfg``.
        self: The ``VFNetHead`` instance being rewritten.
        cls_scores (list[Tensor]): Box iou-aware scores for each scale
            level with shape (N, num_points * num_classes, H, W).
        bbox_preds (list[Tensor]): Box offsets for each scale
            level with shape (N, num_points * 4, H, W).
        bbox_preds_refine (list[Tensor]): Refined Box offsets for
            each scale level with shape (N, num_points * 4, H, W).
        img_metas (dict): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        cfg (mmcv.Config): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before returning boxes.
            Default: True.

    Returns:
        If with_nms == True:
            tuple[Tensor, Tensor]: (dets, labels),
            `dets` of shape [N, num_det, 5] and `labels` of shape
            [N, num_det].
        Else:
            tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
    """
    assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine), \
        'The lengths of lists "cls_scores", "bbox_preds", "bbox_preds_refine"'\
        ' should be the same.'
    num_levels = len(cls_scores)
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    # Per-level point grids matching each feature-map resolution.
    points_list = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                  bbox_preds[0].device)
    # detach() keeps the exported graph free of autograd bookkeeping.
    cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
    # Only the *refined* offsets are decoded; raw bbox_preds are unused
    # beyond the length check above.
    bbox_pred_list = [bbox_preds_refine[i].detach() for i in range(num_levels)]
    cfg = self.test_cfg if cfg is None else cfg
    batch_size = cls_scores[0].shape[0]
    pre_topk = cfg.get('nms_pre', -1)
    # loop over features, decode boxes
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_points = []
    for cls_score, bbox_pred, points, in zip(cls_score_list, bbox_pred_list,
                                             points_list):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:], \
            'The Height and Width should be the same.'
        # (N, C, H, W) -> (N, H*W, num_classes); sigmoid gives per-class
        # iou-aware scores.
        scores = cls_score.permute(0, 2, 3,
                                   1).reshape(batch_size, -1,
                                              self.cls_out_channels).sigmoid()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        points = points.expand(batch_size, -1, 2)
        if pre_topk > 0:
            # Keep only the top-k candidates per level (ranked by the best
            # class score) to bound tensor sizes before NMS.
            max_scores, _ = scores.max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            batch_inds = torch.arange(batch_size).view(-1,
                                                       1).expand_as(topk_inds)
            points = points[batch_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
        mlvl_bboxes.append(bbox_pred)
        mlvl_scores.append(scores)
        mlvl_points.append(points)
    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_points = torch.cat(mlvl_points, dim=1)
    # Decode distance offsets into absolute boxes, clipped to the image.
    batch_mlvl_bboxes = distance2bbox(
        batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])
    if rescale:
        batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
            img_metas['scale_factor'])
    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores
    # NMS parameters: values from the model's test cfg take precedence over
    # the deploy-config post-processing defaults.
    deploy_cfg = ctx.cfg
    post_params = get_mmdet_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = cfg.get('nms_pre', post_params.pre_top_k)
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
                          max_output_boxes_per_class, iou_threshold,
                          score_threshold, pre_top_k, keep_top_k)

View File

@ -94,7 +94,7 @@ def test_shufflenetv2_backbone__forward(backend_type):
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(input_shape=None),
onnx_config=dict(input_shape=None, output_names=['output']),
codebase_config=dict(type='mmcls', task='Classification')))
imgs = torch.rand((1, 16, 28, 28))

View File

@ -46,3 +46,48 @@ def test_ONNXNMSop(iou_threshold, score_threshold, max_output_boxes_per_class):
opset_version=11)
model = onnx.load(onnx_file_path)
assert model.graph.node[3].op_type == 'NonMaxSuppression'
def test_deform_conv_openvino():
    """Check the OpenVINO rewrite of mmcv's DeformConv2dFunction.

    Verifies that (a) the wrapped function still computes the expected
    deformable-convolution result in eager mode, and (b) the ONNX export
    under the 'openvino' rewriter emits the custom 'DeformableConv2D'
    node in the 'org.openvinotoolkit' domain.
    """
    pytest.importorskip('openvino', reason='requires openvino')
    import os

    input = torch.Tensor([[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]])
    offset = torch.Tensor([[[[1.7000, 2.9000], [3.4000, 4.8000]],
                            [[1.1000, 2.0000], [2.1000, 1.9000]],
                            [[3.1000, 5.1000], [5.9000, 4.9000]],
                            [[2.0000, 4.1000], [4.0000, 6.6000]],
                            [[1.6000, 2.7000], [3.8000, 3.1000]],
                            [[2.5000, 4.3000], [4.2000, 5.3000]],
                            [[1.7000, 3.3000], [3.6000, 4.5000]],
                            [[1.7000, 3.4000], [5.2000, 6.1000]]]])
    # Hand-computed reference for the fixed input/offset/weight below.
    expected_output = torch.Tensor([[[[1.6500, 0.0000], [0.0000, 0.0000]]]])
    from mmcv.ops.deform_conv import DeformConv2dFunction

    def wrapped_function(input, offset):
        # Fixed 1x1x2x2 weight and unit conv hyper-parameters.
        weight = torch.Tensor([[[[0.4000, 0.2000], [0.1000, 0.9000]]]])
        stride = (1, 1)
        padding = (0, 0)
        dilation = (1, 1)
        groups = 1
        deform_groups = 1
        return DeformConv2dFunction.apply(input, offset, weight, stride,
                                          padding, dilation, groups,
                                          deform_groups)

    wrapped_model = WrapFunction(wrapped_function).eval()
    model_output = wrapped_model(input, offset)
    assert torch.allclose(expected_output, model_output)
    # Export into a self-cleaning directory. The previous
    # `tempfile.NamedTemporaryFile().name` pattern left the file handle
    # open and never deleted the file, and fails on Windows where an open
    # NamedTemporaryFile cannot be reopened by name.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = os.path.join(tmp_dir, 'deform_conv.onnx')
        with RewriterContext({}, backend='openvino'), torch.no_grad():
            torch.onnx.export(
                wrapped_model, (input, offset),
                onnx_file_path,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input', 'offset'],
                output_names=['result'],
                opset_version=11)
        model = onnx.load(onnx_file_path)
    assert model.graph.node[1].op_type == 'DeformableConv2D'
    assert model.graph.node[1].domain == 'org.openvinotoolkit'

View File

@ -3,6 +3,7 @@ import importlib
import os
import random
import tempfile
from typing import Dict, List
import mmcv
import numpy as np
@ -27,6 +28,18 @@ def seed_everything(seed=1029):
torch.backends.cudnn.enabled = False
def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
    """Convert a dictionary of outputs to a list.

    Only the values whose names appear in `output_names` are kept; they are
    returned in the dictionary's iteration order.
    """
    selected = []
    for name, value in rewrite_output.items():
        if name in output_names:
            selected.append(value)
    return selected
def get_anchor_head_model():
"""AnchorHead Config."""
test_cfg = mmcv.Config(
@ -152,10 +165,7 @@ def test_anchor_head_get_bboxes(backend_type):
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
@ -507,10 +517,8 @@ def test_cascade_roi_head(backend_type):
deploy_cfg=deploy_cfg)
processed_backend_outputs = []
if isinstance(backend_outputs, dict):
processed_backend_outputs = [
backend_outputs[name] for name in output_names
if name in backend_outputs
]
processed_backend_outputs = convert_to_list(backend_outputs,
output_names)
elif isinstance(backend_outputs, (list, tuple)) and \
backend_outputs[0].shape == (1, 0, 5):
processed_backend_outputs = np.zeros((1, 80, 5))
@ -601,10 +609,7 @@ def test_get_bboxes_of_fovea_head(backend_type):
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
@ -716,10 +721,7 @@ def test_get_bboxes_of_atss_head(backend_type):
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
@ -863,10 +865,7 @@ def test_yolov3_head_get_bboxes(backend_type):
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
@ -965,10 +964,7 @@ def test_yolox_head_get_bboxes(backend_type):
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = [
value for name, value in rewrite_outputs.items()
if name in output_names
]
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
@ -983,3 +979,131 @@ def test_yolox_head_get_bboxes(backend_type):
atol=1e-05)
else:
assert rewrite_outputs is not None
def get_vfnet_head_model():
    """VFNet Head Config.

    Builds a small CPU-only `VFNetHead` in eval mode for testing.
    """
    from mmdet.models import VFNetHead
    cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))
    head = VFNetHead(num_classes=4, in_channels=1, test_cfg=cfg)
    head.requires_grad_(False)
    head.cpu().eval()
    return head
@pytest.mark.parametrize('backend_type', ['openvino'])
def test_get_bboxes_of_vfnet_head(backend_type):
    """Test get_bboxes rewrite of VFNet head.

    Compares the rewritten (backend) outputs of `VFNetHead.get_bboxes`
    against the original PyTorch outputs on fixed, seeded random inputs.
    """
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')

    class TestModel(torch.nn.Module):
        """Stub for VFNetHead with fake bbox_preds operations.

        Then bbox_preds will be one of the inputs to the ONNX graph.
        The add/subtract round-trip below is numerically a no-op, but it
        makes `bbox_preds` participate in the traced computation.
        """

        def __init__(self, vfnet_head):
            super().__init__()
            self.vfnet_head = vfnet_head

        def get_bboxes(self,
                       cls_scores,
                       bbox_preds,
                       bbox_preds_refine,
                       img_metas,
                       cfg=None,
                       rescale=None,
                       with_nms=True):
            tmp_bbox_pred_refine = []
            for bbox_pred, bbox_pred_refine in zip(bbox_preds,
                                                   bbox_preds_refine):
                # tmp equals bbox_pred_refine but now depends on bbox_pred.
                tmp = bbox_pred_refine + bbox_pred
                tmp = tmp - bbox_pred
                tmp_bbox_pred_refine.append(tmp)
            bbox_preds_refine = tmp_bbox_pred_refine
            return self.vfnet_head.get_bboxes(cls_scores, bbox_preds,
                                              bbox_preds_refine, img_metas,
                                              cfg, rescale, with_nms)

    test_model = TestModel(get_vfnet_head_model())
    test_model.requires_grad_(False)
    test_model.cpu().eval()
    s = 16
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]
    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=100,
                    background_label_id=-1,
                ))))
    # Re-seed before each fixture so every random tensor is reproducible
    # independently; the seeding order must not change.
    seed_everything(1234)
    cls_score = [
        torch.rand(1, test_model.vfnet_head.num_classes, pow(2, i), pow(2, i))
        for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
    seed_everything(9101)
    bbox_preds_refine = [
        torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'bbox_preds_refine': bbox_preds_refine,
        'img_metas': img_metas
    }
    # Reference outputs from the unmodified PyTorch implementation.
    model_outputs = get_model_outputs(test_model, 'get_bboxes', model_inputs)
    # NOTE(review): img_shape is swapped to a tensor for the rewritten
    # path — presumably so it traces correctly during export; confirm
    # against the vfnet rewrite's use of `max_shape`.
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        test_model, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'bbox_preds_refine': bbox_preds_refine
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        if isinstance(rewrite_outputs, dict):
            rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
        for model_output, rewrite_output in zip(model_outputs[0],
                                                rewrite_outputs):
            model_output = model_output.squeeze().cpu().numpy()
            rewrite_output = rewrite_output.squeeze()
            # The two paths may keep different numbers of detections;
            # compare only the overlapping prefix.
            min_shape = min(model_output.shape[0], rewrite_output.shape[0])
            assert np.allclose(
                model_output[:min_shape],
                rewrite_output[:min_shape],
                rtol=1e-03,
                atol=1e-05)
    else:
        # No host-side inference for this backend: only check that the
        # export/rewrite produced something.
        assert rewrite_outputs is not None