From 89204d16cea6b261b06b891b5c138af8ec7a4ccb Mon Sep 17 00:00:00 2001 From: Yue Zhou <592267829@qq.com> Date: Wed, 25 May 2022 13:41:16 +0800 Subject: [PATCH] [Feature] Support two stage rotated detector in MMRotate (#473) * upload * add fake_multiclass_nms_rotated * delete unused code * align with pytorch * Update delta_midpointoffset_rbbox_coder.py * rewrite atan2 * Update bbox_nms.py --- docs/en/benchmark.md | 21 +- docs/en/codebases/mmrotate.md | 1 + docs/zh_cn/03-benchmark/benchmark.md | 18 +- .../codebase/mmrotate/core/bbox/__init__.py | 1 + .../bbox/delta_midpointoffset_rbbox_coder.py | 103 +++++++++ .../mmrotate/core/post_processing/__init__.py | 4 +- .../mmrotate/core/post_processing/bbox_nms.py | 61 +++++- mmdeploy/codebase/mmrotate/models/__init__.py | 10 +- .../models/oriented_standard_roi_head.py | 97 ++++++++ .../mmrotate/models/rotated_bbox_head.py | 75 +++++++ .../mmrotate/models/rotated_rpn_head.py | 142 ++++++++++++ mmdeploy/pytorch/functions/__init__.py | 3 +- mmdeploy/pytorch/functions/atan2.py | 15 ++ .../test_mmrotate/test_mmrotate_core.py | 93 +++++++- .../test_mmrotate/test_mmrotate_models.py | 207 ++++++++++++++---- 15 files changed, 793 insertions(+), 58 deletions(-) create mode 100644 mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py create mode 100644 mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py create mode 100644 mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py create mode 100644 mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py create mode 100644 mmdeploy/pytorch/functions/atan2.py diff --git a/docs/en/benchmark.md b/docs/en/benchmark.md index 703c77e3a..03931e9a4 100644 --- a/docs/en/benchmark.md +++ b/docs/en/benchmark.md @@ -1954,9 +1954,9 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut model config file - RotatedRetinaNet - Rotated Detection - DOTA-v1.0 + RotatedRetinaNet + Rotated Detection + DOTA-v1.0 mAP 0.698 0.698 @@ -1964,7 +1964,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut 0.697 - - - $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py + $MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py + + + Oriented RCNN + Rotated Detection + DOTA-v1.0 + mAP + 0.756 + 0.756 + - + - + - + - + $MMROTATE_DIR/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py diff --git a/docs/en/codebases/mmrotate.md b/docs/en/codebases/mmrotate.md index aa6705875..efa675902 100644 --- a/docs/en/codebases/mmrotate.md +++ b/docs/en/codebases/mmrotate.md @@ -11,6 +11,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en | Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config | |:----------------------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:| | RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | +| Oriented RCNN | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | ### Example diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md index 40ef70bfa..653eb90c4 100644 --- a/docs/zh_cn/03-benchmark/benchmark.md +++ b/docs/zh_cn/03-benchmark/benchmark.md @@ -1557,9 +1557,9 @@ GPU: ncnn, 
TensorRT, PPLNN fp32 - RotatedRetinaNet - Rotated Detection - DOTA-v1.0 + RotatedRetinaNet + Rotated Detection + DOTA-v1.0 mAP 0.698 0.698 @@ -1568,6 +1568,18 @@ GPU: ncnn, TensorRT, PPLNN - - + + Oriented RCNN + Rotated Detection + DOTA-v1.0 + mAP + 0.756 + 0.756 + - + - + - + - + diff --git a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py index a42a3cf2d..2933ca8be 100644 --- a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py +++ b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403 from .delta_xywha_rbbox_coder import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py b/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py new file mode 100644 index 000000000..92098c4ba --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/bbox/delta_midpointoffset_rbbox_coder.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmrotate.core import poly2obb + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder' + '.delta2bbox', + backend='default') +def delta2bbox(ctx, + rois, + deltas, + means=(0., 0., 0., 0., 0., 0.), + stds=(1., 1., 1., 1., 1., 1.), + wh_ratio_clip=16 / 1000, + version='oc'): + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas + are network outputs used to shift/scale those boxes. This is the inverse + function of :func:`bbox2delta`. + + + Args: + rois (torch.Tensor): Boxes to be transformed. Has shape (N, 6). + deltas (torch.Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 6) or (N, 6). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1., 1., 1.). + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + version (str, optional): Angle representations. Defaults to 'oc'. + + Returns: + Tensor: Boxes with shape (N, num_classes * 5) or (N, 5), where 5 + represent cx, cy, w, h, a. 
+ """ + means = deltas.new_tensor(means).view(1, -1) + stds = deltas.new_tensor(stds).view(1, -1) + delta_shape = deltas.shape + reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 6)) + denorm_deltas = reshaped_deltas * stds + means + + # means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 6) + # stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 6) + # denorm_deltas = deltas * stds + means + dx = denorm_deltas[..., 0::6] + dy = denorm_deltas[..., 1::6] + dw = denorm_deltas[..., 2::6] + dh = denorm_deltas[..., 3::6] + da = denorm_deltas[..., 4::6] + db = denorm_deltas[..., 5::6] + max_ratio = np.abs(np.log(wh_ratio_clip)) + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Compute center of each roi + px = ((rois[..., None, None, 0] + rois[..., None, None, 2]) * 0.5) + py = ((rois[..., None, None, 1] + rois[..., None, None, 3]) * 0.5) + # Compute width/height of each roi + pw = (rois[..., None, None, 2] - rois[..., None, None, 0]) + ph = (rois[..., None, None, 3] - rois[..., None, None, 1]) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi + gx = px + pw * dx + gy = py + ph * dy + + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 + + da = da.clamp(min=-0.5, max=0.5) + db = db.clamp(min=-0.5, max=0.5) + ga = gx + da * gw + _ga = gx - da * gw + gb = gy + db * gh + _gb = gy - db * gh + polys = torch.stack([ga, y1, x2, gb, _ga, y2, x1, _gb], dim=-1) + + center = torch.stack([gx, gy, gx, gy, gx, gy, gx, gy], dim=-1) + center_polys = polys - center + diag_len = torch.sqrt(center_polys[..., 0::2] * center_polys[..., 0::2] + + center_polys[..., 1::2] * center_polys[..., 1::2]) + max_diag_len, _ = torch.max(diag_len, dim=-1, keepdim=True) + diag_scale_factor = max_diag_len / diag_len + center_polys_shape = center_polys.shape + center_polys = center_polys.view(*center_polys_shape[:3], 4, + -1) * diag_scale_factor.view( + *center_polys_shape[:3], 4, 1) + center_polys = center_polys.view(center_polys_shape) + rectpolys = center_polys + center + obboxes = poly2obb(rectpolys, version).view(delta_shape[:-1] + (5, )) + + return obboxes diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py index 01c3b72a4..f3fc975fb 100644 --- a/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py +++ b/mmdeploy/codebase/mmrotate/core/post_processing/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .bbox_nms import multiclass_nms_rotated +from .bbox_nms import fake_multiclass_nms_rotated, multiclass_nms_rotated -__all__ = ['multiclass_nms_rotated'] +__all__ = ['multiclass_nms_rotated', 'fake_multiclass_nms_rotated'] diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py index 5d9cf6efa..7c884937a 100644 --- a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py @@ -1,17 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch +from mmrotate.core import obb2xyxy from torch import Tensor import mmdeploy from mmdeploy.core import FUNCTION_REWRITER, mark -from mmdeploy.mmcv.ops import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop +from mmdeploy.mmcv.ops import (ONNXNMSop, ONNXNMSRotatedOp, + TRTBatchedRotatedNMSop) -def select_nms_index(scores: torch.Tensor, - boxes: torch.Tensor, - nms_index: torch.Tensor, - batch_size: int, - keep_top_k: int = -1): +def select_rnms_index(scores: torch.Tensor, + boxes: torch.Tensor, + nms_index: torch.Tensor, + batch_size: int, + keep_top_k: int = -1): """Transform NMSRotated output. Args: @@ -85,6 +87,7 @@ def _multiclass_nms_rotated(boxes: Tensor, op. It only supports class-agnostic detection results. That is, the scores is of shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes, 5). + Args: boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5]. scores (Tensor): The detection scores of shape @@ -114,7 +117,7 @@ def _multiclass_nms_rotated(boxes: Tensor, selected_indices = ONNXNMSRotatedOp.apply(boxes, scores, iou_threshold, score_threshold) - dets, labels = select_nms_index( + dets, labels = select_rnms_index( scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k) return dets, labels @@ -173,3 +176,47 @@ def multiclass_nms_rotated(*args, **kwargs): """Wrapper function for `_multiclass_nms`.""" return mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.\ _multiclass_nms_rotated(*args, **kwargs) + + +@mark( + 'fake_multiclass_nms_rotated', + inputs=['boxes', 'scores'], + outputs=['dets', 'labels']) +def fake_multiclass_nms_rotated(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.0, + pre_top_k: int = -1, + keep_top_k: int = -1, + version: str = 'le90'): + """Fake NMSRotated for multi-class bboxes which use horizontal bboxes for + NMS, but return the rotated bboxes result. + + This function helps exporting to onnx with batch and multiclass NMS op. It + only supports class-agnostic detection results. That is, the scores is of + shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes, + 5). + """ + max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) + iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) + score_threshold = torch.tensor([score_threshold], dtype=torch.float32) + batch_size = scores.shape[0] + + if pre_top_k > 0: + max_scores, _ = scores.max(-1) + _, topk_inds = max_scores.topk(pre_top_k) + batch_inds = torch.arange(batch_size).view(-1, 1).long() + boxes = boxes[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + + scores = scores.permute(0, 2, 1) + hboxes = obb2xyxy(boxes, version) + selected_indices = ONNXNMSop.apply(hboxes, scores, + max_output_boxes_per_class, + iou_threshold, score_threshold) + + dets, labels = select_rnms_index( + scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k) + + return dets, labels diff --git a/mmdeploy/codebase/mmrotate/models/__init__.py b/mmdeploy/codebase/mmrotate/models/__init__.py index 6fe59fd52..32a7d21e7 100644 --- a/mmdeploy/codebase/mmrotate/models/__init__.py +++ b/mmdeploy/codebase/mmrotate/models/__init__.py @@ -1,9 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .oriented_standard_roi_head import ( + oriented_standard_roi_head__simple_test, + oriented_standard_roi_head__simple_test_bboxes) from .rotated_anchor_head import rotated_anchor_head__get_bbox +from .rotated_bbox_head import rotated_bbox_head__get_bboxes +from .rotated_rpn_head import rotated_rpn_head__get_bboxes from .single_stage_rotated_detector import \ single_stage_rotated_detector__simple_test __all__ = [ 'single_stage_rotated_detector__simple_test', - 'rotated_anchor_head__get_bbox' + 'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes', + 'oriented_standard_roi_head__simple_test', + 'oriented_standard_roi_head__simple_test_bboxes', + 'rotated_bbox_head__get_bboxes' ] diff --git a/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py b/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py new file mode 100644 index 000000000..f977e20d6 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/oriented_standard_roi_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.oriented_standard_roi_head' + '.OrientedStandardRoIHead.simple_test') +def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas, + **kwargs): + """Rewrite `simple_test` of `StandardRoIHead` for default backend. + + This function returns detection result as Tensor instead of numpy + array. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + x (tuple[Tensor]): Features from upstream network. Each + has shape (batch_size, c, h, w). + proposals (list(Tensor)): Proposals from rpn head. + Each has shape (num_proposals, 6), last dimension + 6 represent (x, y, w, h, theta, score). + img_metas (list[dict]): Meta information of images. + + Returns: + tuple[Tensor, Tensor]: (det_bboxes, det_labels), + `det_bboxes` of shape [N, num_det, 6] and `det_labels` + of shape [N, num_det]. + """ + assert self.with_bbox, 'Bbox head must be implemented.' + det_bboxes, det_labels = self.simple_test_bboxes( + x, img_metas, proposals, self.test_cfg, rescale=False) + + return det_bboxes, det_labels + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.oriented_standard_roi_head' + '.OrientedStandardRoIHead.simple_test_bboxes') +def oriented_standard_roi_head__simple_test_bboxes(ctx, + self, + x, + img_metas, + proposals, + rcnn_test_cfg, + rescale=False): + """Test only det bboxes without augmentation. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + img_metas (list[dict]): Image meta info. + proposals (List[Tensor]): Region proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Default: False. + + Returns: + tuple[list[Tensor], list[Tensor]]: The first list contains \ + the boxes of the corresponding image in a batch, each \ + tensor has the shape (num_boxes, 6) and last dimension \ + 6 represent (x, y, w, h, theta, score). Each Tensor \ + in the second list is the labels with shape (num_boxes, ). \ + The length of both lists should be equal to batch_size. 
+ """ + + rois, labels = proposals + batch_index = torch.arange( + rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand( + rois.size(0), rois.size(1), 1) + rois = torch.cat([batch_index, rois[..., :5]], dim=-1) + batch_size = rois.shape[0] + num_proposals_per_img = rois.shape[1] + + # Eliminate the batch dimension + rois = rois.view(-1, 6) + bbox_results = self._bbox_forward(x, rois) + cls_score = bbox_results['cls_score'] + bbox_pred = bbox_results['bbox_pred'] + + # Recover the batch dimension + rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1)) + cls_score = cls_score.reshape(batch_size, num_proposals_per_img, + cls_score.size(-1)) + + bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, + bbox_pred.size(-1)) + det_bboxes, det_labels = self.bbox_head.get_bboxes( + rois, + cls_score, + bbox_pred, + img_metas[0]['img_shape'], + None, + rescale=rescale, + cfg=rcnn_test_cfg) + return det_bboxes, det_labels diff --git a/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py b/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py new file mode 100644 index 000000000..b7cfe91a2 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/rotated_bbox_head.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F + +from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmrotate.core.post_processing import \ + multiclass_nms_rotated +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.bbox_heads.RotatedBBoxHead.get_bboxes') +def rotated_bbox_head__get_bboxes(ctx, + self, + rois, + cls_score, + bbox_pred, + img_shape, + scale_factor, + rescale=False, + cfg=None): + """Transform network output for a batch into bbox predictions. + + Args: + rois (torch.Tensor): Boxes to be transformed. Has shape + (num_boxes, 6). last dimension 5 arrange as + (batch_index, x, y, w, h, theta). + cls_score (torch.Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor, optional): Box energies / deltas. + has shape (num_boxes, num_classes * 6). + img_shape (Sequence[int], optional): Maximum bounds for boxes, + specifies (H, W, C) or (H, W). + scale_factor (ndarray): Scale factor of the + image arrange as (w_scale, h_scale, w_scale, h_scale). + rescale (bool): If True, return boxes in original image space. + Default: False. + cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None + + Returns: + tuple[Tensor, Tensor]: + First tensor is `det_bboxes`, has the shape + (num_boxes, 6) and last + dimension 6 represent (cx, cy, w, h, theta, score). + Second tensor is the labels with shape (num_boxes, ). + """ + assert rois.ndim == 3, 'Only support export two stage ' \ + 'model to ONNX ' \ + 'with batch dimension. 
' + + if self.custom_cls_channels: + scores = self.loss_cls.get_activation(cls_score) + else: + scores = F.softmax( + cls_score, dim=-1) if cls_score is not None else None + + assert bbox_pred is not None + bboxes = self.bbox_coder.decode( + rois[..., 1:], bbox_pred, max_shape=img_shape) + + # ignore background class + scores = scores[..., :self.num_classes] + + post_params = get_post_processing_params(ctx.cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + return multiclass_nms_rotated( + bboxes, + scores, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py new file mode 100644 index 000000000..389c67215 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.codebase.mmdet import (get_post_processing_params, + pad_with_value_if_necessary) +from mmdeploy.codebase.mmrotate.core.post_processing import \ + fake_multiclass_nms_rotated +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.dense_heads.RotatedRPNHead.get_bboxes') +def rotated_rpn_head__get_bboxes(ctx, + self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Rewrite `get_bboxes` of `RPNHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + self (FoveaHead): The instance of the class FoveaHead. + cls_scores (list[Tensor]): Box scores for each scale level + with shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict]): Meta information of the image, e.g., + image size, scaling factor, etc. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. Default: None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. 
+ Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores + """ + assert len(cls_scores) == len(bbox_preds) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + device = cls_scores[0].device + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_anchors = self.anchor_generator.grid_anchors( + featmap_sizes, device=device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors) + + cfg = self.test_cfg if cfg is None else cfg + batch_size = mlvl_cls_scores[0].shape[0] + pre_topk = cfg.get('nms_pre', -1) + + # loop over features, decode boxes + mlvl_valid_bboxes = [] + mlvl_scores = [] + mlvl_valid_anchors = [] + for level_id, cls_score, bbox_pred, anchors in zip( + range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(0, 2, 3, 1) + if self.use_sigmoid_cls: + cls_score = cls_score.reshape(batch_size, -1) + scores = cls_score.sigmoid() + else: + cls_score = cls_score.reshape(batch_size, -1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. + scores = cls_score.softmax(-1)[..., 0] + scores = scores.reshape(batch_size, -1, 1) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6) + + # use static anchor if input shape is static + if not is_dynamic_flag: + anchors = anchors.data + + # anchors = anchors.expand_as(bbox_pred) + anchors = anchors.expand(batch_size, -1, anchors.size(-1)) + + # topk in tensorrt does not support shape 0: + _, topk_inds = scores.squeeze(2).topk(pre_topk) + batch_inds = torch.arange( + batch_size, device=device).view(-1, 1).expand_as(topk_inds) + anchors = anchors[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + mlvl_valid_bboxes.append(bbox_pred) + mlvl_scores.append(scores) + mlvl_valid_anchors.append(anchors) + + batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) + batch_mlvl_bboxes = self.bbox_coder.decode( + batch_mlvl_anchors, + batch_mlvl_bboxes, + max_shape=img_metas[0]['img_shape']) + # ignore background class + if not self.use_sigmoid_cls: + batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes] + if not with_nms: + return batch_mlvl_bboxes, batch_mlvl_scores + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + # only one class in rpn + max_output_boxes_per_class = keep_top_k + return fake_multiclass_nms_rotated( + batch_mlvl_bboxes, + batch_mlvl_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + keep_top_k=keep_top_k, + version=self.version) diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index a3333f547..fae3a6698 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .atan2 import atan2__default from .chunk import chunk__ncnn from .getattribute import tensor__getattribute__ncnn from .group_norm import group_norm__ncnn @@ -13,5 +14,5 @@ __all__ = [ 'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn', 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', - 'triu' + 'triu', 'atan2__default' ] diff --git a/mmdeploy/pytorch/functions/atan2.py b/mmdeploy/pytorch/functions/atan2.py new file mode 100644 index 000000000..a09986a8f --- /dev/null +++ b/mmdeploy/pytorch/functions/atan2.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.atan2', backend='default') +def atan2__default( + ctx, + input1: torch.Tensor, + input2: torch.Tensor, +): + """Rewrite `atan2` for default backend.""" + return torch.atan(input1 / (input2 + 1e-6)) diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py index dad51cd1a..487cbb627 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -134,8 +134,10 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k): @pytest.mark.parametrize('max_shape,proj_xy,edge_swap', [(None, False, False), (torch.tensor([100, 200]), True, True)]) -def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool, - max_shape: tuple, proj_xy: bool, edge_swap: bool): +def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend, + add_ctr_clamp: bool, + max_shape: tuple, proj_xy: bool, + edge_swap: bool): check_backend(backend_type) deploy_cfg = mmcv.Config( dict( @@ -181,3 +183,90 @@ def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool, model_output, rewrite_output, rtol=1e-03, atol=1e-05) else: assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend): + check_backend(backend_type) + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=None, input_shape=None), + backend_config=dict(type=backend_type.value, model_inputs=None), + codebase_config=dict(type='mmrotate', task='RotatedDetection'))) + + # wrap function to enable rewrite + def delta2bbox(*args, **kwargs): + import mmrotate + return mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder\ + .delta2bbox(*args, **kwargs) + + rois = torch.rand(5, 4) + deltas = torch.rand(5, 6) + original_outputs = delta2bbox(rois, deltas, version='le90') + + # wrap function to nn.Module, enable torch.onnx.export + wrapped_func = WrapFunction(delta2bbox) + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_func, + model_inputs={ + 'rois': rois.unsqueeze(0), + 'deltas': deltas.unsqueeze(0) + }, + deploy_cfg=deploy_cfg) + + if is_backend_output: + model_output = original_outputs.squeeze().cpu().numpy() + rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy() + assert np.allclose( + model_output[:, :4], rewrite_output[:, :4], rtol=1e-03, atol=1e-05) + else: + assert rewrite_outputs is not None + + +@backend_checker(Backend.ONNXRUNTIME) +def test_fake_multiclass_nms_rotated(): + from mmdeploy.codebase.mmrotate.core import fake_multiclass_nms_rotated + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=None, input_shape=None), + 
backend_config=dict( + type='onnxruntime', + common_config=dict( + fp16_mode=False, max_workspace_size=1 << 20), + model_inputs=[ + dict( + input_shapes=dict( + boxes=dict( + min_shape=[1, 5, 5], + opt_shape=[1, 5, 5], + max_shape=[1, 5, 5]), + scores=dict( + min_shape=[1, 5, 8], + opt_shape=[1, 5, 8], + max_shape=[1, 5, 8]))) + ]), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + pre_top_k=-1, + keep_top_k=10, + )))) + + boxes = torch.rand(1, 5, 5) + scores = torch.rand(1, 5, 8) + keep_top_k = 10 + wrapped_func = WrapFunction( + fake_multiclass_nms_rotated, keep_top_k=keep_top_k) + rewrite_outputs, _ = get_rewrite_outputs( + wrapped_func, + model_inputs={ + 'boxes': boxes, + 'scores': scores + }, + deploy_cfg=deploy_cfg) + + assert rewrite_outputs is not None, 'Got unexpected rewrite '\ + 'outputs: {}'.format(rewrite_outputs) diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py index 501411ce7..656a2b4e2 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -74,47 +74,46 @@ def _replace_r50_with_r18(model): return model -# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) -# @pytest.mark.parametrize('model_cfg_path', [ -# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json' -# ]) -# def test_forward_of_base_detector(model_cfg_path, backend): -# check_backend(backend) -# deploy_cfg = mmcv.Config( -# dict( -# backend_config=dict(type=backend.value), -# onnx_config=dict( -# output_names=['dets', 'labels'], input_shape=None), -# codebase_config=dict( -# type='mmrotate', -# task='RotatedDetection', -# post_processing=dict( -# score_threshold=0.05, -# iou_threshold=0.5, -# pre_top_k=-1, -# keep_top_k=100, -# )))) +@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) +@pytest.mark.parametrize( + 'model_cfg_path', + ['tests/test_codebase/test_mmrotate/data/single_stage_model.json']) +def test_forward_of_base_detector(model_cfg_path, backend): + check_backend(backend) + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend.value), + onnx_config=dict( + output_names=['dets', 'labels'], input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + pre_top_k=-1, + keep_top_k=100, + )))) -# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path))) -# model_cfg.model = _replace_r50_with_r18(model_cfg.model) + model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path))) + model_cfg.model = _replace_r50_with_r18(model_cfg.model) -# from mmrotate.models import build_detector + from mmrotate.models import build_detector -# model_cfg.model.pretrained = None -# model_cfg.model.train_cfg = None -# model = build_detector( -# model_cfg.model, test_cfg= model_cfg.get('test_cfg')) -# model.cfg = model_cfg -# model.to('cpu') + model_cfg.model.pretrained = None + model_cfg.model.train_cfg = None + model = build_detector(model_cfg.model, test_cfg=model_cfg.get('test_cfg')) + model.cfg = model_cfg + model.to('cpu') -# img = torch.randn(1, 3, 64, 64) -# rewrite_inputs = {'img': img} -# rewrite_outputs, _ = get_rewrite_outputs( -# wrapped_model=model, -# model_inputs=rewrite_inputs, -# deploy_cfg=deploy_cfg) + img = torch.randn(1, 3, 64, 64) + rewrite_inputs = {'img': img} + rewrite_outputs, _ = get_rewrite_outputs( + 
wrapped_model=model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) -# assert rewrite_outputs is not None + assert rewrite_outputs is not None def get_deploy_cfg(backend_type: Backend, ir_type: str): @@ -155,8 +154,8 @@ def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str): # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16), # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2). - # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16), - # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2) + # the bboxes's size: (1, 45, 32, 32), (1, 45, 16, 16), + # (1, 45, 8, 8), (1, 45, 4, 4), (1, 45, 2, 2) seed_everything(1234) cls_score = [ torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1) @@ -201,3 +200,135 @@ def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str): atol=1e-05) else: assert rewrite_outputs is not None + + +def get_single_roi_extractor(): + """SingleRoIExtractor Config.""" + from mmrotate.models.roi_heads import RotatedSingleRoIExtractor + roi_layer = dict( + type='RoIAlignRotated', out_size=7, sample_num=2, clockwise=True) + out_channels = 1 + featmap_strides = [4, 8, 16, 32] + model = RotatedSingleRoIExtractor(roi_layer, out_channels, + featmap_strides).eval() + + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_rotated_single_roi_extractor(backend_type: Backend): + check_backend(backend_type) + + single_roi_extractor = get_single_roi_extractor() + output_names = ['roi_feat'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + ))) + + seed_everything(1234) + out_channels = single_roi_extractor.out_channels + feats = [ + torch.rand((1, out_channels, 200, 336)), + torch.rand((1, out_channels, 100, 168)), + torch.rand((1, out_channels, 50, 84)), + torch.rand((1, out_channels, 25, 42)), + ] + seed_everything(5678) + rois = torch.tensor( + [[0.0000, 587.8285, 52.1405, 886.2484, 341.5644, 0.0000]]) + + model_inputs = { + 'feats': feats, + 'rois': rois, + } + model_outputs = get_model_outputs(single_roi_extractor, 'forward', + model_inputs) + + backend_outputs, _ = get_rewrite_outputs( + wrapped_model=single_roi_extractor, + model_inputs=model_inputs, + deploy_cfg=deploy_cfg) + if isinstance(backend_outputs, dict): + backend_outputs = backend_outputs.values() + for model_output, backend_output in zip(model_outputs[0], backend_outputs): + model_output = model_output.squeeze().cpu().numpy() + backend_output = backend_output.squeeze() + assert np.allclose( + model_output, backend_output, rtol=1e-03, atol=1e-05) + + +def get_oriented_rpn_head_model(): + """Oriented RPN Head Config.""" + test_cfg = mmcv.Config( + dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + from mmrotate.models.dense_heads import OrientedRPNHead + model = OrientedRPNHead( + in_channels=1, + version='le90', + bbox_coder=dict(type='MidpointOffsetCoder', angle_range='le90'), + test_cfg=test_cfg) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): + check_backend(backend_type) + head = get_oriented_rpn_head_model() + head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels'] + 
deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmrotate', + task='RotatedDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.1, + pre_top_k=2000, + keep_top_k=2000)))) + + # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16), + # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2). + # the bboxes's size: (1, 54, 32, 32), (1, 54, 16, 16), + # (1, 54, 8, 8), (1, 54, 4, 4), (1, 54, 2, 2) + seed_everything(1234) + cls_score = [ + torch.rand(1, 9, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(5678) + bboxes = [torch.rand(1, 54, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.Tensor([s, s]) + wrapped_model = WrapModel( + head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + assert rewrite_outputs is not None
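
Note on the `torch.atan2` rewrite introduced in `mmdeploy/pytorch/functions/atan2.py` above: it substitutes `torch.atan(input1 / (input2 + 1e-6))`, which matches `atan2` only when the second argument is positive (as it is for the box width/height terms used in angle regression); quadrant information is otherwise lost. The sketch below is illustrative only and not part of the patch — the sample tensors and tolerance are assumptions chosen to make the equivalence easy to verify.

```python
# Illustrative check (not part of the patch): the default-backend rewrite of
# torch.atan2 replaces it with atan(y / (x + 1e-6)). The two agree when the
# denominator is positive, e.g. box widths/heights in angle computation.
import torch

y = torch.tensor([0.5, -0.3, 1.2])
x = torch.tensor([2.0, 1.5, 0.8])       # positive denominators (assumed sample values)

reference = torch.atan2(y, x)
rewritten = torch.atan(y / (x + 1e-6))  # mirrors atan2__default in this patch

# For positive x the epsilon-perturbed form stays within ~1e-7 of atan2.
assert torch.allclose(reference, rewritten, atol=1e-5)
print(rewritten)
```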