[Fix] Fix errors in deploying MMYOLO-OpenVINO, DETR, Conformer and RTMDet (#1919)
* fix reg test yolox
* fix detr
* fix rtmdet-sdk reg
* fix conformer precision
* add conformer_cls sdk
* add mmcls ut
* fix detr ut
* fix detr ut
* fix lint
* fix yapf
* fix cls sdk
* fix detr_head rewriter
* fix interpolate
* complement the mmdet ut
* fix regression DETR
* fix ut
* fix ut version
* fix lint

Branch: dev-1.x
parent 502692bfd3
commit d76c7b61a5
@@ -0,0 +1,6 @@
+_base_ = ['./base_dynamic.py', '../../_base_/backends/openvino.py']
+
+onnx_config = dict(input_shape=None)
+
+backend_config = dict(
+    model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
@@ -0,0 +1 @@
+_base_ = ['../_base_/base_openvino_dynamic-640x640.py']
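The two new config files above rely on MMEngine-style `_base_` inheritance: the first fragment pins the OpenVINO backend to a dynamic ONNX export with a 640x640 optimal input shape, and the second simply re-exports it. Below is a minimal sketch of how such a deploy config is typically loaded; the file path is hypothetical (the real paths are not visible in this diff), and whatever `base_dynamic.py` and `openvino.py` contribute is omitted.

from mmengine import Config

# Hypothetical path; substitute the actual config file added by this PR.
deploy_cfg = Config.fromfile('configs/mmyolo/base_openvino_dynamic-640x640.py')

# After `_base_` expansion, the OpenVINO optimal shape from the fragment
# above is available on the merged config object.
print(deploy_cfg.backend_config['model_inputs'][0]['opt_shapes']['input'])
# expected: [1, 3, 640, 640]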
@@ -85,6 +85,8 @@ class LinearClsHead : public MMClassification {
 };
 
 MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, LinearClsHead);
+using ConformerHead = LinearClsHead;
+MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, ConformerHead);
 
 class CropBox {
  public:
@@ -8,31 +8,6 @@ from torch.nn import functional as F
 from mmdeploy.core import FUNCTION_REWRITER
 
 
-@FUNCTION_REWRITER.register_rewriter(
-    'mmdet.models.dense_heads.DETRHead.forward_single')
-def detrhead__forward_single__default(self, x, img_metas):
-    """forward_single of DETRHead.
-
-    Ease the mask computation
-    """
-
-    batch_size = x.size(0)
-
-    x = self.input_proj(x)
-    # interpolate masks to have the same spatial shape with x
-    masks = x.new_zeros((batch_size, x.size(-2), x.size(-1))).to(torch.bool)
-
-    # position encoding
-    pos_embed = self.positional_encoding(masks)  # [bs, embed_dim, h, w]
-    # outs_dec: [nb_dec, bs, num_query, embed_dim]
-    outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
-                                   pos_embed)
-    all_cls_scores = self.fc_cls(outs_dec)
-    all_bbox_preds = self.fc_reg(self.activate(
-        self.reg_ffn(outs_dec))).sigmoid()
-    return all_cls_scores, all_bbox_preds
-
-
 @FUNCTION_REWRITER.register_rewriter(
     'mmdet.models.dense_heads.DETRHead.predict_by_feat')
 def detrhead__predict_by_feat__default(self,
@@ -42,8 +17,8 @@ def detrhead__predict_by_feat__default(self,
                                        rescale: bool = True):
     """Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
     from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
-    cls_scores = all_cls_scores_list[-1][-1]
-    bbox_preds = all_bbox_preds_list[-1][-1]
+    cls_scores = all_cls_scores_list[-1]
+    bbox_preds = all_bbox_preds_list[-1]
 
     img_shape = batch_img_metas[0]['img_shape']
     max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))
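The indexing fix above matters because `predict_by_feat` now receives one output per decoder layer and only the last layer should be decoded: with a flat list (or stacked tensor) of per-layer outputs, a single `[-1]` selects the last layer, while the old second `[-1]` would instead drop the batch dimension. A small illustration that mirrors the updated unit-test fixture at the end of this diff:

import torch

# One [bs, num_queries, C] tensor per decoder layer, last layer at the end,
# exactly like the fixture in test_detrhead__predict_by_feat.
all_cls_scores_list = [torch.rand(1, 100, 5) for _ in range(5)]

cls_scores = all_cls_scores_list[-1]     # shape [1, 100, 5]: last layer, batch kept
wrong = all_cls_scores_list[-1][-1]      # shape [100, 5]: the extra [-1] loses the batch dim
print(cls_scores.shape, wrong.shape)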
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import single_stage, single_stage_instance_seg, two_stage
+from . import base_detr, single_stage, single_stage_instance_seg, two_stage
 
-__all__ = ['single_stage', 'single_stage_instance_seg', 'two_stage']
+__all__ = [
+    'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage'
+]
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+from mmdet.models.detectors.base import ForwardResults
+from mmdet.structures import DetDataSample
+from mmdet.structures.det_data_sample import OptSampleList
+
+from mmdeploy.core import FUNCTION_REWRITER, mark
+from mmdeploy.utils import is_dynamic_shape
+
+
+@mark('detr_predict', inputs=['input'], outputs=['dets', 'labels', 'masks'])
+def __predict_impl(self, batch_inputs, data_samples, rescale):
+    """Rewrite and adding mark for `predict`.
+
+    Encapsulate this function for rewriting `predict` of DetectionTransformer.
+    1. Add mark for DetectionTransformer.
+    2. Support both dynamic and static export to onnx.
+    """
+    img_feats = self.extract_feat(batch_inputs)
+    head_inputs_dict = self.forward_transformer(img_feats, data_samples)
+    results_list = self.bbox_head.predict(
+        **head_inputs_dict, rescale=rescale, batch_data_samples=data_samples)
+    return results_list
+
+
+@torch.fx.wrap
+def _set_metainfo(data_samples, img_shape):
+    """Set the metainfo.
+
+    Code in this function cannot be traced by fx.
+    """
+
+    # fx can not trace deepcopy correctly
+    data_samples = copy.deepcopy(data_samples)
+    if data_samples is None:
+        data_samples = [DetDataSample()]
+
+    # note that we can not use `set_metainfo`, deepcopy would crash the
+    # onnx trace.
+    for data_sample in data_samples:
+        data_sample.set_field(
+            name='img_shape', value=img_shape, field_type='metainfo')
+
+    return data_samples
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
+def detection_transformer__predict(self,
+                                   batch_inputs: torch.Tensor,
+                                   data_samples: OptSampleList = None,
+                                   rescale: bool = True,
+                                   **kwargs) -> ForwardResults:
+    """Rewrite `predict` for default backend.
+
+    Support configured dynamic/static shape for model input and return
+    detection result as Tensor instead of numpy array.
+
+    Args:
+        batch_inputs (Tensor): Inputs with shape (N, C, H, W).
+        data_samples (List[:obj:`DetDataSample`]): The Data
+            Samples. It usually includes information such as
+            `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
+        rescale (Boolean): rescale result or not.
+
+    Returns:
+        tuple[Tensor]: Detection results of the
+        input images.
+            - dets (Tensor): Classification bboxes and scores.
+                Has a shape (num_instances, 5)
+            - labels (Tensor): Labels of bboxes, has a shape
+                (num_instances, ).
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+
+    deploy_cfg = ctx.cfg
+
+    # get origin input shape as tensor to support onnx dynamic shape
+    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
+    img_shape = torch._shape_as_tensor(batch_inputs)[2:]
+    if not is_dynamic_flag:
+        img_shape = [int(val) for val in img_shape]
+
+    # set the metainfo
+    data_samples = _set_metainfo(data_samples, img_shape)
+
+    return __predict_impl(self, batch_inputs, data_samples, rescale)
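The new rewriter file above replaces the removed `forward_single` rewrite: instead of patching the head, it patches `DetectionTransformer.predict`, marks the whole call for the SDK, and writes `img_shape` into the data samples in a trace-friendly way. The sketch below (plain PyTorch, nothing beyond what the rewriter itself uses) illustrates the dynamic/static distinction it makes: for dynamic export the shape stays a traced tensor, for static export it is collapsed to Python ints and baked into the graph as constants.

import torch

batch_inputs = torch.randn(1, 3, 640, 640)

# As in the rewriter: take H, W as a tensor so ONNX tracing can keep it symbolic.
img_shape = torch._shape_as_tensor(batch_inputs)[2:]   # tensor([640, 640])

# What `is_dynamic_shape(deploy_cfg)` would return for a static deploy config.
is_dynamic_flag = False
if not is_dynamic_flag:
    # Static export: plain ints become constants in the exported graph.
    img_shape = [int(val) for val in img_shape]

print(img_shape)  # [640, 640]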
@@ -81,7 +81,7 @@ def interpolate__tensorrt(
         size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
                                                                      int]]] = None,
         scale_factor: Optional[Union[float, Tuple[float]]] = None,
-        mode: str = 'bilinear',
+        mode: str = 'nearest',
         align_corners: Optional[bool] = None,
         recompute_scale_factor: Optional[bool] = None,
 ):
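The one-line change above aligns the rewriter's default `mode` with `torch.nn.functional.interpolate`, whose own default is 'nearest'. With the old 'bilinear' default, any model that calls `F.interpolate` without an explicit mode would be exported with a different interpolation than it uses in PyTorch, which is presumably the precision gap behind the Conformer fixes elsewhere in this PR. A quick self-contained check:

import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 2, 2)
default_out = F.interpolate(x, scale_factor=2)                  # no mode given
nearest_out = F.interpolate(x, scale_factor=2, mode='nearest')
assert torch.equal(default_out, nearest_out)  # PyTorch's default mode is 'nearest'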
@@ -250,7 +250,9 @@ models:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp32
       - *pipeline_ncnn_static_fp32
-      - *pipeline_openvino_dynamic_fp32
+      - deploy_config: configs/mmdet/detection/detection_openvino_dynamic-640x640.py
+        convert_image: *convert_image
+        backend_test: False
 
   - name: Faster R-CNN
     metafile: configs/faster_rcnn/metafile.yml
@@ -298,7 +300,10 @@ models:
       - configs/detr/detr_r50_8xb2-150e_coco.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
-      - *pipeline_trt_dynamic_fp16
+      - deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
+        convert_image: *convert_image
+        backend_test: *default_backend_test
+        sdk_config: *sdk_dynamic
 
   - name: CenterNet
     metafile: configs/centernet/metafile.yml
@@ -335,7 +340,7 @@ models:
       - configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
-      - deploy_config: configs/mmdet/detection/detection_tensorrt_static-640x640.py
+      - deploy_config: configs/mmdet/detection/detection_tensorrt_dynamic-64x64-800x800.py
        convert_image: *convert_image
        backend_test: *default_backend_test
        sdk_config: *sdk_dynamic
@@ -29,6 +29,14 @@ def get_invertedresidual_model():
     return model
 
 
+def get_fcuup_model():
+    from mmcls.models.backbones.conformer import FCUUp
+    model = FCUUp(16, 16, 16)
+
+    model.requires_grad_(False)
+    return model
+
+
 def get_vit_backbone():
     from mmcls.models.classifiers.image import ImageClassifier
     model = ImageClassifier(
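The new `get_fcuup_model` helper builds the Conformer FCUUp block, which up-samples transformer tokens back to a CNN feature map; it is presumably consumed by a new mmcls unit test (not shown in this hunk) that covers the interpolate default fix above. Below is a hedged, standalone sketch of how the helper might be exercised; the forward signature `(x, H, W)` and the token layout are assumptions based on the mmcls Conformer implementation and are not shown in this diff.

import torch

model = get_fcuup_model()            # FCUUp(16, 16, 16), gradients disabled
x = torch.rand(1, 1 + 4 * 4, 16)     # assumed layout: class token + 4x4 patch tokens
with torch.no_grad():
    out = model(x, 4, 4)             # assumed signature: forward(tokens, H, W)
print(out.shape)                     # roughly [1, 16, 64, 64] expected with up_stride=16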
@@ -0,0 +1,129 @@
+{
+    "type": "DETR",
+    "num_queries": 100,
+    "data_preprocessor": {
+        "type": "DetDataPreprocessor",
+        "mean": [123.675, 116.28, 103.53],
+        "std": [58.395, 57.12, 57.375],
+        "bgr_to_rgb": true,
+        "pad_size_divisor": 1
+    },
+    "backbone": {
+        "type": "ResNet",
+        "depth": 50,
+        "num_stages": 4,
+        "out_indices": [3],
+        "frozen_stages": 1,
+        "norm_cfg": {
+            "type": "BN",
+            "requires_grad": false
+        },
+        "norm_eval": true,
+        "style": "pytorch",
+        "init_cfg": {
+            "type": "Pretrained",
+            "checkpoint": "torchvision://resnet50"
+        }
+    },
+    "neck": {
+        "type": "ChannelMapper",
+        "in_channels": [2048],
+        "kernel_size": 1,
+        "out_channels": 256,
+        "num_outs": 1
+    },
+    "encoder": {
+        "num_layers": 6,
+        "layer_cfg": {
+            "self_attn_cfg": {
+                "embed_dims": 256,
+                "num_heads": 8,
+                "dropout": 0.1,
+                "batch_first": true
+            },
+            "ffn_cfg": {
+                "embed_dims": 256,
+                "feedforward_channels": 2048,
+                "num_fcs": 2,
+                "ffn_drop": 0.1,
+                "act_cfg": {
+                    "type": "ReLU",
+                    "inplace": true
+                }
+            }
+        }
+    },
+    "decoder": {
+        "num_layers": 6,
+        "layer_cfg": {
+            "self_attn_cfg": {
+                "embed_dims": 256,
+                "num_heads": 8,
+                "dropout": 0.1,
+                "batch_first": true
+            },
+            "cross_attn_cfg": {
+                "embed_dims": 256,
+                "num_heads": 8,
+                "dropout": 0.1,
+                "batch_first": true
+            },
+            "ffn_cfg": {
+                "embed_dims": 256,
+                "feedforward_channels": 2048,
+                "num_fcs": 2,
+                "ffn_drop": 0.1,
+                "act_cfg": {
+                    "type": "ReLU",
+                    "inplace": true
+                }
+            }
+        },
+        "return_intermediate": true
+    },
+    "positional_encoding": {
+        "num_feats": 128,
+        "normalize": true
+    },
+    "bbox_head": {
+        "type": "DETRHead",
+        "num_classes": 80,
+        "embed_dims": 256,
+        "loss_cls": {
+            "type": "CrossEntropyLoss",
+            "bg_cls_weight": 0.1,
+            "use_sigmoid": false,
+            "loss_weight": 1.0,
+            "class_weight": 1.0
+        },
+        "loss_bbox": {
+            "type": "L1Loss",
+            "loss_weight": 5.0
+        },
+        "loss_iou": {
+            "type": "GIoULoss",
+            "loss_weight": 2.0
+        }
+    },
+    "train_cfg": {
+        "assigner": {
+            "type":
+            "HungarianAssigner",
+            "match_costs": [{
+                "type": "ClassificationCost",
+                "weight": 1.0
+            }, {
+                "type": "BBoxL1Cost",
+                "weight": 5.0,
+                "box_format": "xywh"
+            }, {
+                "type": "IoUCost",
+                "iou_mode": "giou",
+                "weight": 2.0
+            }]
+        }
+    },
+    "test_cfg": {
+        "max_per_img": 100
+    }
+}
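The JSON above is a plain dump of the mmdet 3.x DETR model config used as test data. It is consumed by the new unit test further down in this diff roughly as follows (paths and calls taken from that test):

import mmengine
from mmengine import Config
from mmdet.apis import init_detector

model_cfg_path = 'tests/test_codebase/test_mmdet/data/detr_model.json'
model_cfg = Config(dict(model=mmengine.load(model_cfg_path)))
model = init_detector(model_cfg, None, device='cpu', palette='coco')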
@@ -9,6 +9,7 @@ import mmengine
 import numpy as np
 import pytest
 import torch
+from packaging import version
 
 try:
     from torch.testing import assert_close as torch_assert_close
@@ -691,6 +692,50 @@ def test_forward_of_base_detector(model_cfg_path, backend):
     assert rewrite_outputs is not None
 
 
+@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
+@pytest.mark.skipif(
+    reason='mha only support torch greater than 1.10.0',
+    condition=version.parse(torch.__version__) < version.parse('1.10.0'))
+@pytest.mark.parametrize(
+    'model_cfg_path', ['tests/test_codebase/test_mmdet/data/detr_model.json'])
+def test_predict_of_detr_detector(model_cfg_path, backend):
+    # Skip test when torch.__version__ < 1.10.0
+    # See https://github.com/open-mmlab/mmdeploy/discussions/1434
+    check_backend(backend)
+    deploy_cfg = Config(
+        dict(
+            backend_config=dict(type=backend.value),
+            onnx_config=dict(
+                output_names=['dets', 'labels'], input_shape=None),
+            codebase_config=dict(
+                type='mmdet',
+                task='ObjectDetection',
+                post_processing=dict(
+                    score_threshold=0.05,
+                    iou_threshold=0.5,
+                    max_output_boxes_per_class=200,
+                    pre_top_k=-1,
+                    keep_top_k=100,
+                    background_label_id=-1,
+                    export_postprocess_mask=False,
+                ))))
+    model_cfg = Config(dict(model=mmengine.load(model_cfg_path)))
+    from mmdet.apis import init_detector
+    model = init_detector(model_cfg, None, device='cpu', palette='coco')
+
+    img = torch.randn(1, 3, 64, 64)
+    from mmdet.structures import DetDataSample
+    data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64)))
+    rewrite_inputs = {'batch_inputs': img}
+    wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample])
+    rewrite_outputs, _ = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    assert rewrite_outputs is not None
+
+
 @pytest.mark.parametrize('backend_type',
                          [Backend.ONNXRUNTIME, Backend.OPENVINO])
 def test_single_roi_extractor(backend_type: Backend):
@@ -1995,7 +2040,7 @@ def test_mlvl_point_generator__single_level_grid_priors__tensorrt(
 @pytest.mark.parametrize('backend_type, ir_type',
                          [(Backend.ONNXRUNTIME, 'onnx')])
 def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str):
-    """Test predict_by_feat rewrite of base dense head."""
+    """Test predict_by_feat rewrite of detr head."""
     check_backend(backend_type)
     dense_head = get_detrhead_model()
     dense_head.cpu().eval()
@@ -2009,9 +2054,9 @@ def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str):
     deploy_cfg = get_deploy_cfg(backend_type, ir_type)
 
     seed_everything(1234)
-    cls_score = [[torch.rand(1, 100, 5) for i in range(5, 0, -1)]]
+    cls_score = [torch.rand(1, 100, 5) for i in range(5, 0, -1)]
     seed_everything(5678)
-    bboxes = [[torch.rand(1, 100, 4) for i in range(5, 0, -1)]]
+    bboxes = [torch.rand(1, 100, 4) for i in range(5, 0, -1)]
 
     # to get outputs of onnx model after rewrite
     img_metas[0]['img_shape'] = torch.Tensor([s, s])