[Fix] Fix errors about deploying MMYOLO-OpenVINO, DETR, ConvFormer and RTMDet (#1919)

* fix reg test yolox

* fix detr

* fix rtmdet-sdk reg

* fix conformer precision

* add conformer_cls sdk

* add mmcls ut

* fix detr ut

* fix detr ut

* fix lint

* fix yapf

* fix cls sdk

* fix detr_head rewriter

* fix interpolate

* complement the mmdet ut

* fix regression DETR"

* fix ut

* fix ut version

* fix lint
dev-1.x
hanrui1sensetime 2023-03-31 13:45:15 +08:00 committed by GitHub
parent 502692bfd3
commit d76c7b61a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 298 additions and 36 deletions

View File

@ -0,0 +1,6 @@
_base_ = ['./base_dynamic.py', '../../_base_/backends/openvino.py']
onnx_config = dict(input_shape=None)
backend_config = dict(
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])

View File

@ -0,0 +1 @@
_base_ = ['../_base_/base_openvino_dynamic-640x640.py']

View File

@ -85,6 +85,8 @@ class LinearClsHead : public MMClassification {
};
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, LinearClsHead);
using ConformerHead = LinearClsHead;
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, ConformerHead);
class CropBox {
public:

View File

@ -8,31 +8,6 @@ from torch.nn import functional as F
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.forward_single')
def detrhead__forward_single__default(self, x, img_metas):
"""forward_single of DETRHead.
Ease the mask computation
"""
batch_size = x.size(0)
x = self.input_proj(x)
# interpolate masks to have the same spatial shape with x
masks = x.new_zeros((batch_size, x.size(-2), x.size(-1))).to(torch.bool)
# position encoding
pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
# outs_dec: [nb_dec, bs, num_query, embed_dim]
outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
pos_embed)
all_cls_scores = self.fc_cls(outs_dec)
all_bbox_preds = self.fc_reg(self.activate(
self.reg_ffn(outs_dec))).sigmoid()
return all_cls_scores, all_bbox_preds
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
def detrhead__predict_by_feat__default(self,
@ -42,8 +17,8 @@ def detrhead__predict_by_feat__default(self,
rescale: bool = True):
"""Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
cls_scores = all_cls_scores_list[-1][-1]
bbox_preds = all_bbox_preds_list[-1][-1]
cls_scores = all_cls_scores_list[-1]
bbox_preds = all_bbox_preds_list[-1]
img_shape = batch_img_metas[0]['img_shape']
max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import single_stage, single_stage_instance_seg, two_stage
from . import base_detr, single_stage, single_stage_instance_seg, two_stage
__all__ = ['single_stage', 'single_stage_instance_seg', 'two_stage']
__all__ = [
'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage'
]

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
from mmdet.models.detectors.base import ForwardResults
from mmdet.structures import DetDataSample
from mmdet.structures.det_data_sample import OptSampleList
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import is_dynamic_shape
@mark('detr_predict', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __predict_impl(self, batch_inputs, data_samples, rescale):
"""Rewrite and adding mark for `predict`.
Encapsulate this function for rewriting `predict` of DetectionTransformer.
1. Add mark for DetectionTransformer.
2. Support both dynamic and static export to onnx.
"""
img_feats = self.extract_feat(batch_inputs)
head_inputs_dict = self.forward_transformer(img_feats, data_samples)
results_list = self.bbox_head.predict(
**head_inputs_dict, rescale=rescale, batch_data_samples=data_samples)
return results_list
@torch.fx.wrap
def _set_metainfo(data_samples, img_shape):
"""Set the metainfo.
Code in this function cannot be traced by fx.
"""
# fx can not trace deepcopy correctly
data_samples = copy.deepcopy(data_samples)
if data_samples is None:
data_samples = [DetDataSample()]
# note that we can not use `set_metainfo`, deepcopy would crash the
# onnx trace.
for data_sample in data_samples:
data_sample.set_field(
name='img_shape', value=img_shape, field_type='metainfo')
return data_samples
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
def detection_transformer__predict(self,
batch_inputs: torch.Tensor,
data_samples: OptSampleList = None,
rescale: bool = True,
**kwargs) -> ForwardResults:
"""Rewrite `predict` for default backend.
Support configured dynamic/static shape for model input and return
detection result as Tensor instead of numpy array.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (Boolean): rescale result or not.
Returns:
tuple[Tensor]: Detection results of the
input images.
- dets (Tensor): Classification bboxes and scores.
Has a shape (num_instances, 5)
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
# get origin input shape as tensor to support onnx dynamic shape
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
img_shape = torch._shape_as_tensor(batch_inputs)[2:]
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
# set the metainfo
data_samples = _set_metainfo(data_samples, img_shape)
return __predict_impl(self, batch_inputs, data_samples, rescale)

View File

@ -81,7 +81,7 @@ def interpolate__tensorrt(
size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
int]]] = None,
scale_factor: Optional[Union[float, Tuple[float]]] = None,
mode: str = 'bilinear',
mode: str = 'nearest',
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
):

View File

@ -250,7 +250,9 @@ models:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32
- *pipeline_openvino_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_openvino_dynamic-640x640.py
convert_image: *convert_image
backend_test: False
- name: Faster R-CNN
metafile: configs/faster_rcnn/metafile.yml
@ -298,7 +300,10 @@ models:
- configs/detr/detr_r50_8xb2-150e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16
- deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
- name: CenterNet
metafile: configs/centernet/metafile.yml
@ -335,7 +340,7 @@ models:
- configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_tensorrt_static-640x640.py
- deploy_config: configs/mmdet/detection/detection_tensorrt_dynamic-64x64-800x800.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic

View File

@ -29,6 +29,14 @@ def get_invertedresidual_model():
return model
def get_fcuup_model():
from mmcls.models.backbones.conformer import FCUUp
model = FCUUp(16, 16, 16)
model.requires_grad_(False)
return model
def get_vit_backbone():
from mmcls.models.classifiers.image import ImageClassifier
model = ImageClassifier(

View File

@ -0,0 +1,129 @@
{
"type": "DETR",
"num_queries": 100,
"data_preprocessor": {
"type": "DetDataPreprocessor",
"mean": [123.675, 116.28, 103.53],
"std": [58.395, 57.12, 57.375],
"bgr_to_rgb": true,
"pad_size_divisor": 1
},
"backbone": {
"type": "ResNet",
"depth": 50,
"num_stages": 4,
"out_indices": [3],
"frozen_stages": 1,
"norm_cfg": {
"type": "BN",
"requires_grad": false
},
"norm_eval": true,
"style": "pytorch",
"init_cfg": {
"type": "Pretrained",
"checkpoint": "torchvision://resnet50"
}
},
"neck": {
"type": "ChannelMapper",
"in_channels": [2048],
"kernel_size": 1,
"out_channels": 256,
"num_outs": 1
},
"encoder": {
"num_layers": 6,
"layer_cfg": {
"self_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"ffn_cfg": {
"embed_dims": 256,
"feedforward_channels": 2048,
"num_fcs": 2,
"ffn_drop": 0.1,
"act_cfg": {
"type": "ReLU",
"inplace": true
}
}
}
},
"decoder": {
"num_layers": 6,
"layer_cfg": {
"self_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"cross_attn_cfg": {
"embed_dims": 256,
"num_heads": 8,
"dropout": 0.1,
"batch_first": true
},
"ffn_cfg": {
"embed_dims": 256,
"feedforward_channels": 2048,
"num_fcs": 2,
"ffn_drop": 0.1,
"act_cfg": {
"type": "ReLU",
"inplace": true
}
}
},
"return_intermediate": true
},
"positional_encoding": {
"num_feats": 128,
"normalize": true
},
"bbox_head": {
"type": "DETRHead",
"num_classes": 80,
"embed_dims": 256,
"loss_cls": {
"type": "CrossEntropyLoss",
"bg_cls_weight": 0.1,
"use_sigmoid": false,
"loss_weight": 1.0,
"class_weight": 1.0
},
"loss_bbox": {
"type": "L1Loss",
"loss_weight": 5.0
},
"loss_iou": {
"type": "GIoULoss",
"loss_weight": 2.0
}
},
"train_cfg": {
"assigner": {
"type":
"HungarianAssigner",
"match_costs": [{
"type": "ClassificationCost",
"weight": 1.0
}, {
"type": "BBoxL1Cost",
"weight": 5.0,
"box_format": "xywh"
}, {
"type": "IoUCost",
"iou_mode": "giou",
"weight": 2.0
}]
}
},
"test_cfg": {
"max_per_img": 100
}
}

View File

@ -9,6 +9,7 @@ import mmengine
import numpy as np
import pytest
import torch
from packaging import version
try:
from torch.testing import assert_close as torch_assert_close
@ -691,6 +692,50 @@ def test_forward_of_base_detector(model_cfg_path, backend):
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.skipif(
reason='mha only support torch greater than 1.10.0',
condition=version.parse(torch.__version__) < version.parse('1.10.0'))
@pytest.mark.parametrize(
'model_cfg_path', ['tests/test_codebase/test_mmdet/data/detr_model.json'])
def test_predict_of_detr_detector(model_cfg_path, backend):
# Skip test when torch.__version__ < 1.10.0
# See https://github.com/open-mmlab/mmdeploy/discussions/1434
check_backend(backend)
deploy_cfg = Config(
dict(
backend_config=dict(type=backend.value),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
export_postprocess_mask=False,
))))
model_cfg = Config(dict(model=mmengine.load(model_cfg_path)))
from mmdet.apis import init_detector
model = init_detector(model_cfg, None, device='cpu', palette='coco')
img = torch.randn(1, 3, 64, 64)
from mmdet.structures import DetDataSample
data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64)))
rewrite_inputs = {'batch_inputs': img}
wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample])
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type',
[Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_single_roi_extractor(backend_type: Backend):
@ -1995,7 +2040,7 @@ def test_mlvl_point_generator__single_level_grid_priors__tensorrt(
@pytest.mark.parametrize('backend_type, ir_type',
[(Backend.ONNXRUNTIME, 'onnx')])
def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str):
"""Test predict_by_feat rewrite of base dense head."""
"""Test predict_by_feat rewrite of detr head."""
check_backend(backend_type)
dense_head = get_detrhead_model()
dense_head.cpu().eval()
@ -2009,9 +2054,9 @@ def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str):
deploy_cfg = get_deploy_cfg(backend_type, ir_type)
seed_everything(1234)
cls_score = [[torch.rand(1, 100, 5) for i in range(5, 0, -1)]]
cls_score = [torch.rand(1, 100, 5) for i in range(5, 0, -1)]
seed_everything(5678)
bboxes = [[torch.rand(1, 100, 4) for i in range(5, 0, -1)]]
bboxes = [torch.rand(1, 100, 4) for i in range(5, 0, -1)]
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])