diff --git a/docs/en/benchmark.md b/docs/en/benchmark.md
index 703c77e3a..03931e9a4 100644
--- a/docs/en/benchmark.md
+++ b/docs/en/benchmark.md
@@ -1954,9 +1954,9 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
model config file |
- RotatedRetinaNet |
- Rotated Detection |
- DOTA-v1.0 |
+ RotatedRetinaNet |
+ Rotated Detection |
+ DOTA-v1.0 |
mAP |
0.698 |
0.698 |
@@ -1964,7 +1964,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
0.697 |
- |
- |
- $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py |
+ $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py |
+
+
+ Oriented RCNN |
+ Rotated Detection |
+ DOTA-v1.0 |
+ mAP |
+ 0.756 |
+ 0.756 |
+ - |
+ - |
+ - |
+ - |
+ $MMROTATE_DIR/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py |
diff --git a/docs/en/codebases/mmrotate.md b/docs/en/codebases/mmrotate.md
index aa6705875..efa675902 100644
--- a/docs/en/codebases/mmrotate.md
+++ b/docs/en/codebases/mmrotate.md
@@ -11,6 +11,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en
| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
|:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
| RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
+| Oriented RCNN | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
### Example
diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md
index 40ef70bfa..653eb90c4 100644
--- a/docs/zh_cn/03-benchmark/benchmark.md
+++ b/docs/zh_cn/03-benchmark/benchmark.md
@@ -1557,9 +1557,9 @@ GPU: ncnn, TensorRT, PPLNN
fp32 |
- RotatedRetinaNet |
- Rotated Detection |
- DOTA-v1.0 |
+ RotatedRetinaNet |
+ Rotated Detection |
+ DOTA-v1.0 |
mAP |
0.698 |
0.698 |
@@ -1568,6 +1568,18 @@ GPU: ncnn, TensorRT, PPLNN
- |
- |
+
+ Oriented RCNN |
+ Rotated Detection |
+ DOTA-v1.0 |
+ mAP |
+ 0.756 |
+ 0.756 |
+ - |
+ - |
+ - |
+ - |
+
diff --git a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py
index a42a3cf2d..2933ca8be 100644
--- a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py
+++ b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py
@@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403
from .delta_xywha_rbbox_coder import * # noqa: F401,F403
diff --git a/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py b/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py
new file mode 100644
index 000000000..92098c4ba
--- /dev/null
+++ b/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from mmrotate.core import poly2obb
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ func_name='mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder'
+ '.delta2bbox',
+ backend='default')
+def delta2bbox(ctx,
+ rois,
+ deltas,
+ means=(0., 0., 0., 0., 0., 0.),
+ stds=(1., 1., 1., 1., 1., 1.),
+ wh_ratio_clip=16 / 1000,
+ version='oc'):
+ """Apply deltas to shift/scale base boxes.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas
+ are network outputs used to shift/scale those boxes. This is the inverse
+ function of :func:`bbox2delta`.
+
+
+ Args:
+ rois (torch.Tensor): Boxes to be transformed. Has shape (N, 4).
+ deltas (torch.Tensor): Encoded offsets relative to each roi.
+ Has shape (N, num_classes * 6) or (N, 6). Note
+ N = num_base_anchors * W * H, when rois is a grid of
+ anchors.
+ means (Sequence[float]): Denormalizing means for delta coordinates.
+ Default (0., 0., 0., 0., 0., 0.).
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates. Default (1., 1., 1., 1., 1., 1.).
+ wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
+ 16 / 1000.
+ version (str, optional): Angle representations. Defaults to 'oc'.
+
+ Returns:
+ Tensor: Boxes with shape (N, num_classes * 5) or (N, 5), where 5
+ represent cx, cy, w, h, a.
+ """
+ means = deltas.new_tensor(means).view(1, -1)
+ stds = deltas.new_tensor(stds).view(1, -1)
+ delta_shape = deltas.shape
+ reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 6))
+ denorm_deltas = reshaped_deltas * stds + means
+
+ dx = denorm_deltas[..., 0::6]
+ dy = denorm_deltas[..., 1::6]
+ dw = denorm_deltas[..., 2::6]
+ dh = denorm_deltas[..., 3::6]
+ da = denorm_deltas[..., 4::6]
+ db = denorm_deltas[..., 5::6]
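+ # In the midpoint-offset encoding used by Oriented R-CNN, (dx, dy, dw, dh)
+ # refine the horizontal proposal as in the ordinary delta coder, while
+ # (da, db) are the normalized offsets of the midpoints of the top and right
+ # edges of the oriented box relative to that horizontal box.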
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+ # Compute center of each roi
+ px = ((rois[..., None, None, 0] + rois[..., None, None, 2]) * 0.5)
+ py = ((rois[..., None, None, 1] + rois[..., None, None, 3]) * 0.5)
+ # Compute width/height of each roi
+ pw = (rois[..., None, None, 2] - rois[..., None, None, 0])
+ ph = (rois[..., None, None, 3] - rois[..., None, None, 1])
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+ # Use network energy to shift the center of each roi
+ gx = px + pw * dx
+ gy = py + ph * dy
+
+ x1 = gx - gw * 0.5
+ y1 = gy - gh * 0.5
+ x2 = gx + gw * 0.5
+ y2 = gy + gh * 0.5
+
+ da = da.clamp(min=-0.5, max=0.5)
+ db = db.clamp(min=-0.5, max=0.5)
+ ga = gx + da * gw
+ _ga = gx - da * gw
+ gb = gy + db * gh
+ _gb = gy - db * gh
+ polys = torch.stack([ga, y1, x2, gb, _ga, y2, x1, _gb], dim=-1)
+
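+ # (ga, y1), (x2, gb), (_ga, y2), (x1, _gb) are the vertices of a
+ # parallelogram whose opposite corners are symmetric about (gx, gy).
+ # Stretching both diagonals to the longer one below turns it into a
+ # rectangle, which poly2obb then converts to (cx, cy, w, h, a).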
+ center = torch.stack([gx, gy, gx, gy, gx, gy, gx, gy], dim=-1)
+ center_polys = polys - center
+ diag_len = torch.sqrt(center_polys[..., 0::2] * center_polys[..., 0::2] +
+ center_polys[..., 1::2] * center_polys[..., 1::2])
+ max_diag_len, _ = torch.max(diag_len, dim=-1, keepdim=True)
+ diag_scale_factor = max_diag_len / diag_len
+ center_polys_shape = center_polys.shape
+ center_polys = center_polys.view(*center_polys_shape[:3], 4,
+ -1) * diag_scale_factor.view(
+ *center_polys_shape[:3], 4, 1)
+ center_polys = center_polys.view(center_polys_shape)
+ rectpolys = center_polys + center
+ obboxes = poly2obb(rectpolys, version).view(delta_shape[:-1] + (5, ))
+
+ return obboxes
diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py
index 01c3b72a4..f3fc975fb 100644
--- a/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py
+++ b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .bbox_nms import multiclass_nms_rotated
+from .bbox_nms import fake_multiclass_nms_rotated, multiclass_nms_rotated
-__all__ = ['multiclass_nms_rotated']
+__all__ = ['multiclass_nms_rotated', 'fake_multiclass_nms_rotated']
diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py
index 5d9cf6efa..7c884937a 100644
--- a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py
+++ b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py
@@ -1,17 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
+from mmrotate.core import obb2xyxy
from torch import Tensor
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.mmcv.ops import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop
+from mmdeploy.mmcv.ops import (ONNXNMSop, ONNXNMSRotatedOp,
+ TRTBatchedRotatedNMSop)
-def select_nms_index(scores: torch.Tensor,
- boxes: torch.Tensor,
- nms_index: torch.Tensor,
- batch_size: int,
- keep_top_k: int = -1):
+def select_rnms_index(scores: torch.Tensor,
+ boxes: torch.Tensor,
+ nms_index: torch.Tensor,
+ batch_size: int,
+ keep_top_k: int = -1):
"""Transform NMSRotated output.
Args:
@@ -85,6 +87,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
op. It only supports class-agnostic detection results. That is, the scores
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
(N, num_boxes, 5).
+
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
scores (Tensor): The detection scores of shape
@@ -114,7 +117,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
selected_indices = ONNXNMSRotatedOp.apply(boxes, scores, iou_threshold,
score_threshold)
- dets, labels = select_nms_index(
+ dets, labels = select_rnms_index(
scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
return dets, labels
@@ -173,3 +176,47 @@ def multiclass_nms_rotated(*args, **kwargs):
"""Wrapper function for `_multiclass_nms`."""
return mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.\
_multiclass_nms_rotated(*args, **kwargs)
+
+
+@mark(
+ 'fake_multiclass_nms_rotated',
+ inputs=['boxes', 'scores'],
+ outputs=['dets', 'labels'])
+def fake_multiclass_nms_rotated(boxes: Tensor,
+ scores: Tensor,
+ max_output_boxes_per_class: int = 1000,
+ iou_threshold: float = 0.5,
+ score_threshold: float = 0.0,
+ pre_top_k: int = -1,
+ keep_top_k: int = -1,
+ version: str = 'le90'):
+ """Fake NMSRotated for multi-class bboxes which use horizontal bboxes for
+ NMS, but return the rotated bboxes result.
+
+ This function helps exporting to onnx with batch and multiclass NMS op. It
+ only supports class-agnostic detection results. That is, the scores is of
+ shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
+ 5).
+ """
+ max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+ iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
+ score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
+ batch_size = scores.shape[0]
+
+ if pre_top_k > 0:
+ max_scores, _ = scores.max(-1)
+ _, topk_inds = max_scores.topk(pre_top_k)
+ batch_inds = torch.arange(batch_size).view(-1, 1).long()
+ boxes = boxes[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+
+ scores = scores.permute(0, 2, 1)
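+ # NMS runs on the horizontal (xyxy) enclosing boxes via the plain ONNX NMS
+ # op; the selected indices are then used to gather the original rotated
+ # boxes, so the returned dets keep the rotated representation.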
+ hboxes = obb2xyxy(boxes, version)
+ selected_indices = ONNXNMSop.apply(hboxes, scores,
+ max_output_boxes_per_class,
+ iou_threshold, score_threshold)
+
+ dets, labels = select_rnms_index(
+ scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
+
+ return dets, labels
diff --git a/mmdeploy/codebase/mmrotate/models/__init__.py b/mmdeploy/codebase/mmrotate/models/__init__.py
index 6fe59fd52..32a7d21e7 100644
--- a/mmdeploy/codebase/mmrotate/models/__init__.py
+++ b/mmdeploy/codebase/mmrotate/models/__init__.py
@@ -1,9 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .oriented_standard_roi_head import (
+ oriented_standard_roi_head__simple_test,
+ oriented_standard_roi_head__simple_test_bboxes)
from .rotated_anchor_head import rotated_anchor_head__get_bbox
+from .rotated_bbox_head import rotated_bbox_head__get_bboxes
+from .rotated_rpn_head import rotated_rpn_head__get_bboxes
from .single_stage_rotated_detector import \
single_stage_rotated_detector__simple_test
__all__ = [
'single_stage_rotated_detector__simple_test',
- 'rotated_anchor_head__get_bbox'
+ 'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes',
+ 'oriented_standard_roi_head__simple_test',
+ 'oriented_standard_roi_head__simple_test_bboxes',
+ 'rotated_bbox_head__get_bboxes'
]
diff --git a/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py b/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py
new file mode 100644
index 000000000..f977e20d6
--- /dev/null
+++ b/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ 'mmrotate.models.roi_heads.oriented_standard_roi_head'
+ '.OrientedStandardRoIHead.simple_test')
+def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
+ **kwargs):
+ """Rewrite `simple_test` of `StandardRoIHead` for default backend.
+
+ This function returns detection result as Tensor instead of numpy
+ array.
+
+ Args:
+ ctx (ContextCaller): The context with additional information.
+ self: The instance of the original class.
+ x (tuple[Tensor]): Features from upstream network. Each
+ has shape (batch_size, c, h, w).
+ proposals (list(Tensor)): Proposals from rpn head.
+ Each has shape (num_proposals, 6), last dimension
+ 6 represent (x, y, w, h, theta, score).
+ img_metas (list[dict]): Meta information of images.
+
+ Returns:
+ tuple[Tensor, Tensor]: (det_bboxes, det_labels),
+ `det_bboxes` of shape [N, num_det, 6] and `det_labels`
+ of shape [N, num_det].
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposals, self.test_cfg, rescale=False)
+
+ return det_bboxes, det_labels
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ 'mmrotate.models.roi_heads.oriented_standard_roi_head'
+ '.OrientedStandardRoIHead.simple_test_bboxes')
+def oriented_standard_roi_head__simple_test_bboxes(ctx,
+ self,
+ x,
+ img_metas,
+ proposals,
+ rcnn_test_cfg,
+ rescale=False):
+ """Test only det bboxes without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+ proposals (List[Tensor]): Region proposals.
+ rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: The first list contains \
+ the boxes of the corresponding image in a batch, each \
+ tensor has the shape (num_boxes, 6) and last dimension \
+ 6 represent (x, y, w, h, theta, score). Each Tensor \
+ in the second list is the labels with shape (num_boxes, ). \
+ The length of both lists should be equal to batch_size.
+ """
+
+ rois, labels = proposals
+ batch_index = torch.arange(
+ rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
+ rois.size(0), rois.size(1), 1)
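+ # prepend the batch index and drop the proposal score: each roi becomes
+ # (batch_ind, cx, cy, w, h, theta), the rotated-roi layout expected by the
+ # bbox roi extractor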
+ rois = torch.cat([batch_index, rois[..., :5]], dim=-1)
+ batch_size = rois.shape[0]
+ num_proposals_per_img = rois.shape[1]
+
+ # Eliminate the batch dimension
+ rois = rois.view(-1, 6)
+ bbox_results = self._bbox_forward(x, rois)
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+
+ # Recover the batch dimension
+ rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
+ cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
+ cls_score.size(-1))
+
+ bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
+ bbox_pred.size(-1))
+ det_bboxes, det_labels = self.bbox_head.get_bboxes(
+ rois,
+ cls_score,
+ bbox_pred,
+ img_metas[0]['img_shape'],
+ None,
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ return det_bboxes, det_labels
diff --git a/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py b/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py
new file mode 100644
index 000000000..b7cfe91a2
--- /dev/null
+++ b/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn.functional as F
+
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.codebase.mmrotate.core.post_processing import \
+ multiclass_nms_rotated
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ 'mmrotate.models.roi_heads.bbox_heads.RotatedBBoxHead.get_bboxes')
+def rotated_bbox_head__get_bboxes(ctx,
+ self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ rois (torch.Tensor): Boxes to be transformed. Has shape
+ (num_boxes, 6). The last dimension 6 is arranged as
+ (batch_index, x, y, w, h, theta).
+ cls_score (torch.Tensor): Box scores, has shape
+ (num_boxes, num_classes + 1).
+ bbox_pred (Tensor, optional): Box energies / deltas.
+ Has shape (num_boxes, num_classes * 5).
+ img_shape (Sequence[int], optional): Maximum bounds for boxes,
+ specifies (H, W, C) or (H, W).
+ scale_factor (ndarray): Scale factor of the
+ image arrange as (w_scale, h_scale, w_scale, h_scale).
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
+
+ Returns:
+ tuple[Tensor, Tensor]:
+ First tensor is `det_bboxes`, has the shape
+ (num_boxes, 6) and last
+ dimension 6 represent (cx, cy, w, h, theta, score).
+ Second tensor is the labels with shape (num_boxes, ).
+ """
+ assert rois.ndim == 3, 'Only support exporting two-stage ' \
+ 'models to ONNX ' \
+ 'with the batch dimension. '
+
+ if self.custom_cls_channels:
+ scores = self.loss_cls.get_activation(cls_score)
+ else:
+ scores = F.softmax(
+ cls_score, dim=-1) if cls_score is not None else None
+
+ assert bbox_pred is not None
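+ # rois[..., 1:] strips the batch index that was prepended in
+ # simple_test_bboxes, leaving (cx, cy, w, h, theta) for the delta decoder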
+ bboxes = self.bbox_coder.decode(
+ rois[..., 1:], bbox_pred, max_shape=img_shape)
+
+ # ignore background class
+ scores = scores[..., :self.num_classes]
+
+ post_params = get_post_processing_params(ctx.cfg)
+ iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+ score_threshold = cfg.get('score_thr', post_params.score_threshold)
+ pre_top_k = post_params.pre_top_k
+ keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+
+ return multiclass_nms_rotated(
+ bboxes,
+ scores,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pre_top_k=pre_top_k,
+ keep_top_k=keep_top_k)
diff --git a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py
new file mode 100644
index 000000000..389c67215
--- /dev/null
+++ b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py
@@ -0,0 +1,142 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdeploy.codebase.mmdet import (get_post_processing_params,
+ pad_with_value_if_necessary)
+from mmdeploy.codebase.mmrotate.core.post_processing import \
+ fake_multiclass_nms_rotated
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import is_dynamic_shape
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ 'mmrotate.models.dense_heads.RotatedRPNHead.get_bboxes')
+def rotated_rpn_head__get_bboxes(ctx,
+ self,
+ cls_scores,
+ bbox_preds,
+ score_factors=None,
+ img_metas=None,
+ cfg=None,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Rewrite `get_bboxes` of `RPNHead` for default backend.
+
+ Rewrite this function to deploy model, transform network output for a
+ batch into bbox predictions.
+
+ Args:
+ ctx (ContextCaller): The context with additional information.
+ self (RotatedRPNHead): The instance of the class RotatedRPNHead.
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 6, H, W).
+ score_factors (list[Tensor], Optional): Score factor for
+ all scale level, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 1, H, W). Default None.
+ img_metas (list[dict]): Meta information of the image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ Returns:
+ If with_nms == True:
+ tuple[Tensor, Tensor]: (dets, labels),
+ `dets` of shape [N, num_det, 6] and `labels` of shape
+ [N, num_det].
+ Else:
+ tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ deploy_cfg = ctx.cfg
+ is_dynamic_flag = is_dynamic_shape(deploy_cfg)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+
+ mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+ mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+ assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
+
+ cfg = self.test_cfg if cfg is None else cfg
+ batch_size = mlvl_cls_scores[0].shape[0]
+ pre_topk = cfg.get('nms_pre', -1)
+
+ # loop over features, decode boxes
+ mlvl_valid_bboxes = []
+ mlvl_scores = []
+ mlvl_valid_anchors = []
+ for level_id, cls_score, bbox_pred, anchors in zip(
+ range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(0, 2, 3, 1)
+ if self.use_sigmoid_cls:
+ cls_score = cls_score.reshape(batch_size, -1)
+ scores = cls_score.sigmoid()
+ else:
+ cls_score = cls_score.reshape(batch_size, -1, 2)
+ # We set FG labels to [0, num_class-1] and BG label to
+ # num_class in RPN head since mmdet v2.5, which is unified to
+ # be consistent with other head since mmdet v2.0. In mmdet v2.0
+ # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+ scores = cls_score.softmax(-1)[..., 0]
+ scores = scores.reshape(batch_size, -1, 1)
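+ # each anchor regresses 6 values (dx, dy, dw, dh, da, db) of the
+ # midpoint-offset encoding, hence the trailing dimension of 6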
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)
+
+ # use static anchor if input shape is static
+ if not is_dynamic_flag:
+ anchors = anchors.data
+
+ # anchors = anchors.expand_as(bbox_pred)
+ anchors = anchors.expand(batch_size, -1, anchors.size(-1))
+
+ # topk in tensorrt does not support shape < k
+ # pad with dummy entries so that topk is always valid
+ scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
+ bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
+ anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)
+
+ if pre_topk > 0:
+ _, topk_inds = scores.squeeze(2).topk(pre_topk)
+ batch_inds = torch.arange(
+ batch_size, device=device).view(-1, 1).expand_as(topk_inds)
+ anchors = anchors[batch_inds, topk_inds, :]
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+ mlvl_valid_bboxes.append(bbox_pred)
+ mlvl_scores.append(scores)
+ mlvl_valid_anchors.append(anchors)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
+ batch_mlvl_bboxes = self.bbox_coder.decode(
+ batch_mlvl_anchors,
+ batch_mlvl_bboxes,
+ max_shape=img_metas[0]['img_shape'])
+ # ignore background class
+ if not self.use_sigmoid_cls:
+ batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
+ if not with_nms:
+ return batch_mlvl_bboxes, batch_mlvl_scores
+
+ post_params = get_post_processing_params(deploy_cfg)
+ iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+ keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+ # only one class in rpn
+ max_output_boxes_per_class = keep_top_k
+ return fake_multiclass_nms_rotated(
+ batch_mlvl_bboxes,
+ batch_mlvl_scores,
+ max_output_boxes_per_class,
+ iou_threshold=iou_threshold,
+ keep_top_k=keep_top_k,
+ version=self.version)
diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py
index a3333f547..fae3a6698 100644
--- a/mmdeploy/pytorch/functions/__init__.py
+++ b/mmdeploy/pytorch/functions/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .atan2 import atan2__default
from .chunk import chunk__ncnn
from .getattribute import tensor__getattribute__ncnn
from .group_norm import group_norm__ncnn
@@ -13,5 +14,5 @@ __all__ = [
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
- 'triu'
+ 'triu', 'atan2__default'
]
diff --git a/mmdeploy/pytorch/functions/atan2.py b/mmdeploy/pytorch/functions/atan2.py
new file mode 100644
index 000000000..a09986a8f
--- /dev/null
+++ b/mmdeploy/pytorch/functions/atan2.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+ func_name='torch.atan2', backend='default')
+def atan2__default(
+ ctx,
+ input1: torch.Tensor,
+ input2: torch.Tensor,
+):
+ """Rewrite `atan2` for default backend."""
+ return torch.atan(input1 / (input2 + 1e-6))
diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py
index dad51cd1a..487cbb627 100644
--- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py
+++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py
@@ -134,8 +134,10 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k):
@pytest.mark.parametrize('max_shape,proj_xy,edge_swap',
[(None, False, False),
(torch.tensor([100, 200]), True, True)])
-def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
- max_shape: tuple, proj_xy: bool, edge_swap: bool):
+def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend,
+ add_ctr_clamp: bool,
+ max_shape: tuple, proj_xy: bool,
+ edge_swap: bool):
check_backend(backend_type)
deploy_cfg = mmcv.Config(
dict(
@@ -181,3 +183,90 @@ def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
assert rewrite_outputs is not None
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend):
+ check_backend(backend_type)
+ deploy_cfg = mmcv.Config(
+ dict(
+ onnx_config=dict(output_names=None, input_shape=None),
+ backend_config=dict(type=backend_type.value, model_inputs=None),
+ codebase_config=dict(type='mmrotate', task='RotatedDetection')))
+
+ # wrap function to enable rewrite
+ def delta2bbox(*args, **kwargs):
+ import mmrotate
+ return mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder\
+ .delta2bbox(*args, **kwargs)
+
+ rois = torch.rand(5, 4)
+ deltas = torch.rand(5, 6)
+ original_outputs = delta2bbox(rois, deltas, version='le90')
+
+ # wrap function to nn.Module, enable torch.onnx.export
+ wrapped_func = WrapFunction(delta2bbox)
+ rewrite_outputs, is_backend_output = get_rewrite_outputs(
+ wrapped_func,
+ model_inputs={
+ 'rois': rois.unsqueeze(0),
+ 'deltas': deltas.unsqueeze(0)
+ },
+ deploy_cfg=deploy_cfg)
+
+ if is_backend_output:
+ model_output = original_outputs.squeeze().cpu().numpy()
+ rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
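+ # only (cx, cy, w, h) are compared; the decoded angle is skipped,
+ # presumably because torch.atan2 (used inside poly2obb) is rewritten as an
+ # atan approximation during export and may not match exactly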
+ assert np.allclose(
+ model_output[:, :4], rewrite_output[:, :4], rtol=1e-03, atol=1e-05)
+ else:
+ assert rewrite_outputs is not None
+
+
+@backend_checker(Backend.ONNXRUNTIME)
+def test_fake_multiclass_nms_rotated():
+ from mmdeploy.codebase.mmrotate.core import fake_multiclass_nms_rotated
+ deploy_cfg = mmcv.Config(
+ dict(
+ onnx_config=dict(output_names=None, input_shape=None),
+ backend_config=dict(
+ type='onnxruntime',
+ common_config=dict(
+ fp16_mode=False, max_workspace_size=1 << 20),
+ model_inputs=[
+ dict(
+ input_shapes=dict(
+ boxes=dict(
+ min_shape=[1, 5, 5],
+ opt_shape=[1, 5, 5],
+ max_shape=[1, 5, 5]),
+ scores=dict(
+ min_shape=[1, 5, 8],
+ opt_shape=[1, 5, 8],
+ max_shape=[1, 5, 8])))
+ ]),
+ codebase_config=dict(
+ type='mmrotate',
+ task='RotatedDetection',
+ post_processing=dict(
+ score_threshold=0.05,
+ iou_threshold=0.5,
+ pre_top_k=-1,
+ keep_top_k=10,
+ ))))
+
+ boxes = torch.rand(1, 5, 5)
+ scores = torch.rand(1, 5, 8)
+ keep_top_k = 10
+ wrapped_func = WrapFunction(
+ fake_multiclass_nms_rotated, keep_top_k=keep_top_k)
+ rewrite_outputs, _ = get_rewrite_outputs(
+ wrapped_func,
+ model_inputs={
+ 'boxes': boxes,
+ 'scores': scores
+ },
+ deploy_cfg=deploy_cfg)
+
+ assert rewrite_outputs is not None, 'Got unexpected rewrite '\
+ 'outputs: {}'.format(rewrite_outputs)
diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
index 501411ce7..656a2b4e2 100644
--- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
+++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
@@ -74,47 +74,46 @@ def _replace_r50_with_r18(model):
return model
-# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
-# @pytest.mark.parametrize('model_cfg_path', [
-# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json'
-# ])
-# def test_forward_of_base_detector(model_cfg_path, backend):
-# check_backend(backend)
-# deploy_cfg = mmcv.Config(
-# dict(
-# backend_config=dict(type=backend.value),
-# onnx_config=dict(
-# output_names=['dets', 'labels'], input_shape=None),
-# codebase_config=dict(
-# type='mmrotate',
-# task='RotatedDetection',
-# post_processing=dict(
-# score_threshold=0.05,
-# iou_threshold=0.5,
-# pre_top_k=-1,
-# keep_top_k=100,
-# ))))
+@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
+@pytest.mark.parametrize(
+ 'model_cfg_path',
+ ['tests/test_codebase/test_mmrotate/data/single_stage_model.json'])
+def test_forward_of_base_detector(model_cfg_path, backend):
+ check_backend(backend)
+ deploy_cfg = mmcv.Config(
+ dict(
+ backend_config=dict(type=backend.value),
+ onnx_config=dict(
+ output_names=['dets', 'labels'], input_shape=None),
+ codebase_config=dict(
+ type='mmrotate',
+ task='RotatedDetection',
+ post_processing=dict(
+ score_threshold=0.05,
+ iou_threshold=0.5,
+ pre_top_k=-1,
+ keep_top_k=100,
+ ))))
-# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
-# model_cfg.model = _replace_r50_with_r18(model_cfg.model)
+ model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
+ model_cfg.model = _replace_r50_with_r18(model_cfg.model)
-# from mmrotate.models import build_detector
+ from mmrotate.models import build_detector
-# model_cfg.model.pretrained = None
-# model_cfg.model.train_cfg = None
-# model = build_detector(
-# model_cfg.model, test_cfg= model_cfg.get('test_cfg'))
-# model.cfg = model_cfg
-# model.to('cpu')
+ model_cfg.model.pretrained = None
+ model_cfg.model.train_cfg = None
+ model = build_detector(model_cfg.model, test_cfg=model_cfg.get('test_cfg'))
+ model.cfg = model_cfg
+ model.to('cpu')
-# img = torch.randn(1, 3, 64, 64)
-# rewrite_inputs = {'img': img}
-# rewrite_outputs, _ = get_rewrite_outputs(
-# wrapped_model=model,
-# model_inputs=rewrite_inputs,
-# deploy_cfg=deploy_cfg)
+ img = torch.randn(1, 3, 64, 64)
+ rewrite_inputs = {'img': img}
+ rewrite_outputs, _ = get_rewrite_outputs(
+ wrapped_model=model,
+ model_inputs=rewrite_inputs,
+ deploy_cfg=deploy_cfg)
-# assert rewrite_outputs is not None
+ assert rewrite_outputs is not None
def get_deploy_cfg(backend_type: Backend, ir_type: str):
@@ -155,8 +154,8 @@ def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
- # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
- # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
+ # the bboxes's size: (1, 45, 32, 32), (1, 45, 16, 16),
+ # (1, 45, 8, 8), (1, 45, 4, 4), (1, 45, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
@@ -201,3 +200,135 @@ def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
atol=1e-05)
else:
assert rewrite_outputs is not None
+
+
+def get_single_roi_extractor():
+ """SingleRoIExtractor Config."""
+ from mmrotate.models.roi_heads import RotatedSingleRoIExtractor
+ roi_layer = dict(
+ type='RoIAlignRotated', out_size=7, sample_num=2, clockwise=True)
+ out_channels = 1
+ featmap_strides = [4, 8, 16, 32]
+ model = RotatedSingleRoIExtractor(roi_layer, out_channels,
+ featmap_strides).eval()
+
+ return model
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_rotated_single_roi_extractor(backend_type: Backend):
+ check_backend(backend_type)
+
+ single_roi_extractor = get_single_roi_extractor()
+ output_names = ['roi_feat']
+ deploy_cfg = mmcv.Config(
+ dict(
+ backend_config=dict(type=backend_type.value),
+ onnx_config=dict(output_names=output_names, input_shape=None),
+ codebase_config=dict(
+ type='mmrotate',
+ task='RotatedDetection',
+ )))
+
+ seed_everything(1234)
+ out_channels = single_roi_extractor.out_channels
+ feats = [
+ torch.rand((1, out_channels, 200, 336)),
+ torch.rand((1, out_channels, 100, 168)),
+ torch.rand((1, out_channels, 50, 84)),
+ torch.rand((1, out_channels, 25, 42)),
+ ]
+ seed_everything(5678)
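+ # rotated rois are laid out as (batch_ind, cx, cy, w, h, a)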
+ rois = torch.tensor(
+ [[0.0000, 587.8285, 52.1405, 886.2484, 341.5644, 0.0000]])
+
+ model_inputs = {
+ 'feats': feats,
+ 'rois': rois,
+ }
+ model_outputs = get_model_outputs(single_roi_extractor, 'forward',
+ model_inputs)
+
+ backend_outputs, _ = get_rewrite_outputs(
+ wrapped_model=single_roi_extractor,
+ model_inputs=model_inputs,
+ deploy_cfg=deploy_cfg)
+ if isinstance(backend_outputs, dict):
+ backend_outputs = backend_outputs.values()
+ for model_output, backend_output in zip(model_outputs[0], backend_outputs):
+ model_output = model_output.squeeze().cpu().numpy()
+ backend_output = backend_output.squeeze()
+ assert np.allclose(
+ model_output, backend_output, rtol=1e-03, atol=1e-05)
+
+
+def get_oriented_rpn_head_model():
+ """Oriented RPN Head Config."""
+ test_cfg = mmcv.Config(
+ dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.1),
+ max_per_img=2000))
+ from mmrotate.models.dense_heads import OrientedRPNHead
+ model = OrientedRPNHead(
+ in_channels=1,
+ version='le90',
+ bbox_coder=dict(type='MidpointOffsetCoder', angle_range='le90'),
+ test_cfg=test_cfg)
+
+ model.requires_grad_(False)
+ return model
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend):
+ check_backend(backend_type)
+ head = get_oriented_rpn_head_model()
+ head.cpu().eval()
+ s = 128
+ img_metas = [{
+ 'scale_factor': np.ones(4),
+ 'pad_shape': (s, s, 3),
+ 'img_shape': (s, s, 3)
+ }]
+
+ output_names = ['dets', 'labels']
+ deploy_cfg = mmcv.Config(
+ dict(
+ backend_config=dict(type=backend_type.value),
+ onnx_config=dict(output_names=output_names, input_shape=None),
+ codebase_config=dict(
+ type='mmrotate',
+ task='RotatedDetection',
+ post_processing=dict(
+ score_threshold=0.05,
+ iou_threshold=0.1,
+ pre_top_k=2000,
+ keep_top_k=2000))))
+
+ # the cls_score's size: (1, 9, 32, 32), (1, 9, 16, 16),
+ # (1, 9, 8, 8), (1, 9, 4, 4), (1, 9, 2, 2).
+ # the bboxes's size: (1, 54, 32, 32), (1, 54, 16, 16),
+ # (1, 54, 8, 8), (1, 54, 4, 4), (1, 54, 2, 2)
+ seed_everything(1234)
+ cls_score = [
+ torch.rand(1, 9, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
+ ]
+ seed_everything(5678)
+ bboxes = [torch.rand(1, 54, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
+
+ # to get outputs of onnx model after rewrite
+ img_metas[0]['img_shape'] = torch.Tensor([s, s])
+ wrapped_model = WrapModel(
+ head, 'get_bboxes', img_metas=img_metas, with_nms=True)
+ rewrite_inputs = {
+ 'cls_scores': cls_score,
+ 'bbox_preds': bboxes,
+ }
+ rewrite_outputs, is_backend_output = get_rewrite_outputs(
+ wrapped_model=wrapped_model,
+ model_inputs=rewrite_inputs,
+ deploy_cfg=deploy_cfg)
+ assert rewrite_outputs is not None