[Feature] Support RTMDet-R deployment (#1553)

* support rtmdet-r

* add comments

* fix dep config

* add ut

* fix ut

* fix ut

* fix url

* fix en url
pull/1583/head
Yanyi Liu 2022-12-26 09:52:44 +08:00 committed by GitHub
parent 85be66f7a6
commit 6288141bd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 261 additions and 0 deletions

View File

@ -0,0 +1,16 @@
_base_ = [
'./rotated-detection_static.py', '../_base_/backends/tensorrt-fp16.py'
]
onnx_config = dict(output_names=['dets', 'labels'], input_shape=(1024, 1024))
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 1024, 1024],
opt_shape=[1, 3, 1024, 1024],
max_shape=[1, 3, 1024, 1024])))
])

View File

@ -182,3 +182,4 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) | Y | Y |
| [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn) | Y | Y |
| [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex) | Y | Y |
| [RTMDET-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet) | Y | Y |

View File

@ -186,3 +186,4 @@ det = detector(img)
| [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) | Y | Y |
| [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn) | Y | Y |
| [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex) | Y | Y |
| [RTMDET-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet) | Y | Y |

View File

@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import oriented_rpn_head # noqa: F401, F403
from . import rotated_rtmdet_head # noqa: F401, F403

View File

@ -0,0 +1,119 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
from mmengine.config import ConfigDict
from mmrotate.structures import norm_angle
from torch import Tensor
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms_rotated import multiclass_nms_rotated
@FUNCTION_REWRITER.register_rewriter(
func_name='mmrotate.models.dense_heads.rotated_rtmdet_head.'
'RotatedRTMDetHead.predict_by_feat')
def rotated_rtmdet_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
angle_preds: List[Tensor],
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> Tuple[Tensor]:
"""Rewrite `predict_by_feat` of `Rotated RTMDet` for default backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
angle_preds (list[Tensor]): Box angle for each scale level
with shape (batch_size, num_priors * angle_dim, H, W)
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 6) tensor,
where 5 represent (x, y, w, h, angle, score), N is batch
size and the score between 0 and 1. The shape of the second
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
device = cls_scores[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=device)
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
for bbox_pred in bbox_preds
]
flatten_angle_preds = [
angle_pred.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.angle_coder.encode_size)
for angle_pred in angle_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
priors = torch.cat(mlvl_priors)
angle = self.angle_coder.decode(flatten_angle_preds, keepdim=True)
distance = flatten_bbox_preds
cos_angle, sin_angle = torch.cos(angle), torch.sin(angle)
rot_matrix = torch.cat([cos_angle, -sin_angle, sin_angle, cos_angle],
dim=-1)
rot_matrix = rot_matrix.reshape(*rot_matrix.shape[:-1], 2, 2)
wh = distance[..., :2] + distance[..., 2:]
offset_t = (distance[..., 2:] - distance[..., :2]) / 2
offset_t = offset_t.unsqueeze(-1)
offset = torch.matmul(rot_matrix, offset_t).squeeze(-1)
ctr = priors[..., :2] + offset
angle_regular = norm_angle(angle, self.angle_version)
bboxes = torch.cat([ctr, wh, angle_regular], dim=-1)
# directly multiply score factor and feed to nms
max_scores, _ = torch.max(flatten_cls_scores, 1)
mask = max_scores >= cfg.score_thr
scores = flatten_cls_scores.where(mask, flatten_cls_scores.new_zeros(1))
if not with_nms:
return bboxes, scores
deploy_cfg = ctx.cfg
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms_rotated(bboxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold, pre_top_k,
keep_top_k)

View File

@ -335,3 +335,126 @@ def test_gvfixcoder__decode(backend_type: Backend):
run_with_backend=False)
assert rewrite_outputs is not None
def get_rotated_rtmdet_head_model():
"""RTMDet-R Head Config."""
test_cfg = Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms_rotated', iou_threshold=0.1),
max_per_img=2000))
from mmrotate.models.dense_heads import RotatedRTMDetHead
model = RotatedRTMDetHead(
num_classes=4,
in_channels=1,
anchor_generator=dict(
type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
bbox_coder=dict(type='DistanceAnglePointCoder', angle_version='le90'),
loss_cls=dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox=dict(type='RotatedIoULoss', mode='linear', loss_weight=2.0),
test_cfg=test_cfg)
model.requires_grad_(False)
return model
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotated_rtmdet_head_predict_by_feat(backend_type: Backend):
"""Test predict_by_feat rewrite of RTMDet-R."""
check_backend(backend_type)
rtm_r_head = get_rotated_rtmdet_head_model()
rtm_r_head.cpu().eval()
s = 128
batch_img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['dets', 'labels']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmrotate',
task='RotatedDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.1,
pre_top_k=3000,
keep_top_k=2000,
max_output_boxes_per_class=2000))))
seed_everything(1234)
cls_scores = [
torch.rand(1, rtm_r_head.num_classes, 2 * pow(2, i), 2 * pow(2, i))
for i in range(3, 0, -1)
]
seed_everything(5678)
bbox_preds = [
torch.rand(1, 4, 2 * pow(2, i), 2 * pow(2, i))
for i in range(3, 0, -1)
]
seed_everything(9101)
angle_preds = [
torch.rand(1, rtm_r_head.angle_coder.encode_size, 2 * pow(2, i),
2 * pow(2, i)) for i in range(3, 0, -1)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_scores,
'bbox_preds': bbox_preds,
'angle_preds': angle_preds,
'batch_img_metas': batch_img_metas,
'with_nms': True
}
model_outputs = get_model_outputs(rtm_r_head, 'predict_by_feat',
model_inputs)
# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(
rtm_r_head,
'predict_by_feat',
batch_img_metas=batch_img_metas,
with_nms=True)
rewrite_inputs = {
'cls_scores': cls_scores,
'bbox_preds': bbox_preds,
'angle_preds': angle_preds,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
min_shape = min(model_outputs[0].bboxes.shape[0],
rewrite_outputs[0].shape[1], 5)
for i in range(len(model_outputs)):
assert np.allclose(
model_outputs[i].bboxes.tensor[:min_shape],
rewrite_outputs[0][i, :min_shape, :5],
rtol=1e-03,
atol=1e-05)
assert np.allclose(
model_outputs[i].scores[:min_shape],
rewrite_outputs[0][i, :min_shape, 5],
rtol=1e-03,
atol=1e-05)
assert np.allclose(
model_outputs[i].labels[:min_shape],
rewrite_outputs[1][i, :min_shape],
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None