[Feature] Support RTMDet-R deployment (#1553)
* support rtmdet-r * add comments * fix dep config * add ut * fix ut * fix ut * fix url * fix en url
pull/1583/head
parent
85be66f7a6
commit
6288141bd5
|
@ -0,0 +1,16 @@
|
|||
# Static-shape rotated-detection deployment config for TensorRT (FP16).
# Inherits the static rotated-detection settings and the TensorRT FP16
# backend settings, then pins the input resolution to 1024x1024.
_base_ = [
    './rotated-detection_static.py',
    '../_base_/backends/tensorrt-fp16.py',
]

# Exported ONNX graph: two named outputs and a fixed input size.
onnx_config = {
    'output_names': ['dets', 'labels'],
    'input_shape': (1024, 1024),
}

# TensorRT engine settings. min/opt/max shapes are all the same, so the
# engine is built for exactly one input shape.
backend_config = {
    'common_config': {
        # 2**30 == 1 << 30
        'max_workspace_size': 2**30,
    },
    'model_inputs': [
        {
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 1024, 1024],
                    'opt_shape': [1, 3, 1024, 1024],
                    'max_shape': [1, 3, 1024, 1024],
                },
            },
        },
    ],
}
|
|
@ -182,3 +182,4 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
|
|||
| [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) | Y | Y |
|
||||
| [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn) | Y | Y |
|
||||
| [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex) | Y | Y |
|
||||
| [RTMDet-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet) | Y | Y |
|
||||
|
|
|
@ -186,3 +186,4 @@ det = detector(img)
|
|||
| [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) | Y | Y |
|
||||
| [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn) | Y | Y |
|
||||
| [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex) | Y | Y |
|
||||
| [RTMDet-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet) | Y | Y |
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import oriented_rpn_head # noqa: F401, F403
|
||||
from . import rotated_rtmdet_head # noqa: F401, F403
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.config import ConfigDict
|
||||
from mmrotate.structures import norm_angle
|
||||
from torch import Tensor
|
||||
|
||||
from mmdeploy.codebase.mmdet import get_post_processing_params
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.mmcv.ops.nms_rotated import multiclass_nms_rotated
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmrotate.models.dense_heads.rotated_rtmdet_head.'
    'RotatedRTMDetHead.predict_by_feat')
def rotated_rtmdet_head__predict_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        angle_preds: List[Tensor],
        batch_img_metas: Optional[List[dict]] = None,
        cfg: Optional[ConfigDict] = None,
        rescale: bool = False,
        with_nms: bool = True) -> Tuple[Tensor, Tensor]:
    """Rewrite `predict_by_feat` of `Rotated RTMDet` for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 4, H, W).
        angle_preds (list[Tensor]): Box angle for each scale level
            with shape (batch_size, num_priors * angle_dim, H, W)
        batch_img_metas (list[dict], Optional): Batch image meta info.
            Not read by this rewrite. Defaults to None.
        cfg (ConfigDict, optional): Test / postprocessing
            configuration, if None, test_cfg would be used.
            Defaults to None.
        rescale (bool): If True, return boxes in original image space.
            Not read by this rewrite. Defaults to False.
        with_nms (bool): If True, do nms before return boxes.
            Defaults to True.

    Returns:
        tuple[Tensor, Tensor]: When ``with_nms`` is True, the first item
            is an (N, num_box, 6) tensor, where 6 represent
            (x, y, w, h, angle, score), N is batch size and the score
            between 0 and 1. The shape of the second tensor in the tuple
            is (N, num_box), and each element represents the class label
            of the corresponding box. When ``with_nms`` is False, the
            decoded boxes and the thresholded class scores are returned
            instead, without running NMS.
    """
    ctx = FUNCTION_REWRITER.get_context()
    # Every pyramid level must provide all three prediction maps.
    assert len(cls_scores) == len(bbox_preds) == len(angle_preds)
    device = cls_scores[0].device
    cfg = self.test_cfg if cfg is None else cfg
    batch_size = bbox_preds[0].shape[0]
    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, device=device)

    # Move channels last and flatten the spatial dims of every level so the
    # levels can be concatenated along the prior axis.
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                              self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_angle_preds = [
        angle_pred.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                               self.angle_coder.encode_size)
        for angle_pred in angle_preds
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
    flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
    priors = torch.cat(mlvl_priors)

    angle = self.angle_coder.decode(flatten_angle_preds, keepdim=True)
    # NOTE(review): the four channels are presumably
    # (left, top, right, bottom) distances from the prior point - confirm
    # against the bbox coder.
    distance = flatten_bbox_preds
    cos_angle, sin_angle = torch.cos(angle), torch.sin(angle)

    # Per-box 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    rot_matrix = torch.cat([cos_angle, -sin_angle, sin_angle, cos_angle],
                           dim=-1)
    rot_matrix = rot_matrix.reshape(*rot_matrix.shape[:-1], 2, 2)

    # Box size is the sum of opposite distances; the box centre is the prior
    # point shifted by the rotated half-difference of opposite distances.
    wh = distance[..., :2] + distance[..., 2:]
    offset_t = (distance[..., 2:] - distance[..., :2]) / 2
    offset_t = offset_t.unsqueeze(-1)
    offset = torch.matmul(rot_matrix, offset_t).squeeze(-1)
    ctr = priors[..., :2] + offset

    angle_regular = norm_angle(angle, self.angle_version)
    bboxes = torch.cat([ctr, wh, angle_regular], dim=-1)

    # Zero out every class whose best score over all priors of an image is
    # below the threshold. `keepdim=True` keeps the reduced prior axis so the
    # (batch, 1, classes) mask broadcasts against the (batch, priors, classes)
    # scores for any batch size; without it the mask only lines up when the
    # batch size is 1.
    max_scores, _ = torch.max(flatten_cls_scores, 1, keepdim=True)
    mask = max_scores >= cfg.score_thr
    scores = flatten_cls_scores.where(mask, flatten_cls_scores.new_zeros(1))
    if not with_nms:
        return bboxes, scores

    # Values from the model's test cfg take precedence; the deploy config's
    # post-processing parameters are the fallback.
    deploy_cfg = ctx.cfg
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    return multiclass_nms_rotated(bboxes, scores, max_output_boxes_per_class,
                                  iou_threshold, score_threshold, pre_top_k,
                                  keep_top_k)
|
|
@ -335,3 +335,126 @@ def test_gvfixcoder__decode(backend_type: Backend):
|
|||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def get_rotated_rtmdet_head_model():
    """Build a ``RotatedRTMDetHead`` with gradients disabled for the tests.

    The head is created with 4 classes, a single input channel and a
    three-level point generator (strides 8, 16 and 32).
    """
    # Imported lazily so that importing this module does not need mmrotate.
    from mmrotate.models.dense_heads import RotatedRTMDetHead

    # Post-processing settings that the head reads as ``self.test_cfg``.
    test_cfg = Config({
        'deploy_nms_pre': 0,
        'min_bbox_size': 0,
        'score_thr': 0.05,
        'nms': {
            'type': 'nms_rotated',
            'iou_threshold': 0.1
        },
        'max_per_img': 2000,
    })

    head_kwargs = {
        'num_classes': 4,
        'in_channels': 1,
        'anchor_generator': {
            'type': 'mmdet.MlvlPointGenerator',
            'offset': 0,
            'strides': [8, 16, 32],
        },
        'bbox_coder': {
            'type': 'DistanceAnglePointCoder',
            'angle_version': 'le90',
        },
        'loss_cls': {
            'type': 'mmdet.QualityFocalLoss',
            'use_sigmoid': True,
            'beta': 2.0,
            'loss_weight': 1.0,
        },
        'loss_bbox': {
            'type': 'RotatedIoULoss',
            'mode': 'linear',
            'loss_weight': 2.0,
        },
        'test_cfg': test_cfg,
    }
    model = RotatedRTMDetHead(**head_kwargs)

    model.requires_grad_(False)
    return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotated_rtmdet_head_predict_by_feat(backend_type: Backend):
    """Check the rewritten ``predict_by_feat`` of RTMDet-R.

    Runs the original PyTorch head and the rewritten, exported head on the
    same seeded random feature maps and compares their leading detections.
    """
    check_backend(backend_type)
    head = get_rotated_rtmdet_head_model()
    head.cpu().eval()

    img_size = 128
    batch_img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (img_size, img_size, 3),
        'img_shape': (img_size, img_size, 3)
    }]

    post_processing = dict(
        score_threshold=0.05,
        iou_threshold=0.1,
        pre_top_k=3000,
        keep_top_k=2000,
        max_output_boxes_per_class=2000)
    codebase_config = dict(
        type='mmrotate',
        task='RotatedDetection',
        post_processing=post_processing)
    deploy_cfg = Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=codebase_config))

    # Feature-map side length per level, largest first: 16, 8, 4.
    level_sides = [2 * pow(2, i) for i in range(3, 0, -1)]

    def _seeded_feats(seed, channels):
        # Reseed before each group so every input is reproducible on its own.
        seed_everything(seed)
        return [torch.rand(1, channels, side, side) for side in level_sides]

    cls_scores = _seeded_feats(1234, head.num_classes)
    bbox_preds = _seeded_feats(5678, 4)
    angle_preds = _seeded_feats(9101, head.angle_coder.encode_size)

    # to get outputs of pytorch model
    model_outputs = get_model_outputs(
        head, 'predict_by_feat', {
            'cls_scores': cls_scores,
            'bbox_preds': bbox_preds,
            'angle_preds': angle_preds,
            'batch_img_metas': batch_img_metas,
            'with_nms': True
        })

    # to get outputs of onnx model after rewrite
    wrapped_model = WrapModel(
        head,
        'predict_by_feat',
        batch_img_metas=batch_img_metas,
        with_nms=True)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs={
            'cls_scores': cls_scores,
            'bbox_preds': bbox_preds,
            'angle_preds': angle_preds,
        },
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    # hard code to make two tensors with the same shape
    # rewrite and original codes applied different nms strategy
    min_shape = min(model_outputs[0].bboxes.shape[0],
                    rewrite_outputs[0].shape[1], 5)
    dets, labels = rewrite_outputs[0], rewrite_outputs[1]
    for i, expected in enumerate(model_outputs):
        assert np.allclose(
            expected.bboxes.tensor[:min_shape],
            dets[i, :min_shape, :5],
            rtol=1e-03,
            atol=1e-05)
        assert np.allclose(
            expected.scores[:min_shape],
            dets[i, :min_shape, 5],
            rtol=1e-03,
            atol=1e-05)
        assert np.allclose(
            expected.labels[:min_shape],
            labels[i, :min_shape],
            rtol=1e-03,
            atol=1e-05)
|
||||
|
|
Loading…
Reference in New Issue