[Feature] Support two stage rotated detector in MMRotate (#473)
* upload
* add fake_multiclass_nms_rotated
* delete unused code
* align with pytorch
* Update delta_midpointoffset_rbbox_coder.py
* rewrite atan2
* Update bbox_nms.py

parent e3a8baac4c
commit 89204d16ce
@@ -1954,9 +1954,9 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
    <td>model config file</td>
  </tr>
  <tr>
    <td align="center" rowspan="2">RotatedRetinaNet</td>
    <td align="center" rowspan="2">Rotated Detection</td>
    <td align="center" rowspan="2">DOTA-v1.0</td>
    <td align="center">RotatedRetinaNet</td>
    <td align="center">Rotated Detection</td>
    <td align="center">DOTA-v1.0</td>
    <td align="center">mAP</td>
    <td align="center">0.698</td>
    <td align="center">0.698</td>
@@ -1964,7 +1964,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
    <td align="center">0.697</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td rowspan="2">$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
    <td>$MMROTATE_DIR/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py</td>
  </tr>
  <tr>
    <td align="center">Oriented RCNN</td>
    <td align="center">Rotated Detection</td>
    <td align="center">DOTA-v1.0</td>
    <td align="center">mAP</td>
    <td align="center">0.756</td>
    <td align="center">0.756</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td>$MMROTATE_DIR/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py</td>
  </tr>
  </tbody>
</table>
@@ -11,6 +11,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en

| Model            | Task             | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config                                                                                    |
|:-----------------|:-----------------|:------------:|:--------:|:----:|:-----:|:--------:|:-----------------------------------------------------------------------------------------------:|
| RotatedRetinaNet | RotatedDetection |      Y       |    Y     |  N   |   N   |    N     | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
| Oriented RCNN    | RotatedDetection |      Y       |    N     |  N   |   N   |    N     | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md)     |

### Example
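The body of the `### Example` section falls outside this hunk. As a hedged sketch only (not the file's actual text), a minimal `RotatedDetection` deploy config looks like the ones used in the unit tests added by this PR; every value below is illustrative:

```python
import mmcv

# Minimal deploy-config sketch for a rotated detector, assembled from the
# test configs in this PR; thresholds and top-k values are illustrative.
deploy_cfg = mmcv.Config(
    dict(
        onnx_config=dict(output_names=['dets', 'labels'], input_shape=None),
        backend_config=dict(type='onnxruntime'),
        codebase_config=dict(
            type='mmrotate',
            task='RotatedDetection',
            post_processing=dict(
                score_threshold=0.05,
                iou_threshold=0.1,
                pre_top_k=2000,
                keep_top_k=2000))))
```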
@@ -1557,9 +1557,9 @@ GPU: ncnn, TensorRT, PPLNN
    <td align="center">fp32</td>
  </tr>
  <tr>
    <td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py">RotatedRetinaNet</a></td>
    <td align="center" rowspan="2">Rotated Detection</td>
    <td align="center" rowspan="2">DOTA-v1.0</td>
    <td align="center"><a href="https://github.com/open-mmlab/mmrotate/tree/main/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le135.py">RotatedRetinaNet</a></td>
    <td align="center">Rotated Detection</td>
    <td align="center">DOTA-v1.0</td>
    <td align="center">mAP</td>
    <td align="center">0.698</td>
    <td align="center">0.698</td>
@@ -1568,6 +1568,18 @@ GPU: ncnn, TensorRT, PPLNN
    <td align="center">-</td>
    <td align="center">-</td>
  </tr>
  <tr>
    <td align="center"><a href="https://github.com/open-mmlab/mmrotate/tree/main/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py">Oriented RCNN</a></td>
    <td align="center">Rotated Detection</td>
    <td align="center">DOTA-v1.0</td>
    <td align="center">mAP</td>
    <td align="center">0.756</td>
    <td align="center">0.756</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td align="center">-</td>
    <td align="center">-</td>
  </tr>
  </tbody>
</table>
</div>
@@ -1,2 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delta_midpointoffset_rbbox_coder import *  # noqa: F401,F403
from .delta_xywha_rbbox_coder import *  # noqa: F401,F403
@@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmrotate.core import poly2obb

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder'
    '.delta2bbox',
    backend='default')
def delta2bbox(ctx,
               rois,
               deltas,
               means=(0., 0., 0., 0., 0., 0.),
               stds=(1., 1., 1., 1., 1., 1.),
               wh_ratio_clip=16 / 1000,
               version='oc'):
    """Apply deltas to shift/scale base boxes.

    Typically the rois are anchor or proposed bounding boxes and the deltas
    are network outputs used to shift/scale those boxes. This is the inverse
    function of :func:`bbox2delta`.

    Args:
        rois (torch.Tensor): Boxes to be transformed. Has shape (N, 4).
        deltas (torch.Tensor): Encoded offsets relative to each roi.
            Has shape (N, num_classes * 6) or (N, 6). Note
            N = num_base_anchors * W * H, when rois is a grid of
            anchors.
        means (Sequence[float]): Denormalizing means for delta coordinates.
            Default (0., 0., 0., 0., 0., 0.).
        stds (Sequence[float]): Denormalizing standard deviation for delta
            coordinates. Default (1., 1., 1., 1., 1., 1.).
        wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
            16 / 1000.
        version (str, optional): Angle representations. Defaults to 'oc'.

    Returns:
        Tensor: Boxes with shape (N, num_classes * 5) or (N, 5), where 5
        represent cx, cy, w, h, a.
    """
    means = deltas.new_tensor(means).view(1, -1)
    stds = deltas.new_tensor(stds).view(1, -1)
    delta_shape = deltas.shape
    reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 6))
    denorm_deltas = reshaped_deltas * stds + means

    dx = denorm_deltas[..., 0::6]
    dy = denorm_deltas[..., 1::6]
    dw = denorm_deltas[..., 2::6]
    dh = denorm_deltas[..., 3::6]
    da = denorm_deltas[..., 4::6]
    db = denorm_deltas[..., 5::6]
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio)
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Compute center of each roi
    px = ((rois[..., None, None, 0] + rois[..., None, None, 2]) * 0.5)
    py = ((rois[..., None, None, 1] + rois[..., None, None, 3]) * 0.5)
    # Compute width/height of each roi
    pw = (rois[..., None, None, 2] - rois[..., None, None, 0])
    ph = (rois[..., None, None, 3] - rois[..., None, None, 1])
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + pw * dx
    gy = py + ph * dy

    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5

    da = da.clamp(min=-0.5, max=0.5)
    db = db.clamp(min=-0.5, max=0.5)
    ga = gx + da * gw
    _ga = gx - da * gw
    gb = gy + db * gh
    _gb = gy - db * gh
    polys = torch.stack([ga, y1, x2, gb, _ga, y2, x1, _gb], dim=-1)

    center = torch.stack([gx, gy, gx, gy, gx, gy, gx, gy], dim=-1)
    center_polys = polys - center
    diag_len = torch.sqrt(center_polys[..., 0::2] * center_polys[..., 0::2] +
                          center_polys[..., 1::2] * center_polys[..., 1::2])
    max_diag_len, _ = torch.max(diag_len, dim=-1, keepdim=True)
    diag_scale_factor = max_diag_len / diag_len
    center_polys_shape = center_polys.shape
    center_polys = center_polys.view(*center_polys_shape[:3], 4,
                                     -1) * diag_scale_factor.view(
                                         *center_polys_shape[:3], 4, 1)
    center_polys = center_polys.view(center_polys_shape)
    rectpolys = center_polys + center
    obboxes = poly2obb(rectpolys, version).view(delta_shape[:-1] + (5, ))

    return obboxes
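To make the midpoint-offset geometry above concrete, here is a hedged pure-PyTorch sketch of the polygon-construction step for a single decoded box (all numbers are made up); it reproduces only the `ga`/`gb` vertex logic from the rewriter, not the final `poly2obb` conversion:

```python
import torch

# One decoded horizontal box (gx, gy, gw, gh) and its two midpoint offsets
# (da, db); the values are illustrative.
gx, gy, gw, gh = torch.tensor([8.0, 6.0, 4.0, 2.0])
da, db = torch.tensor([0.25, -0.10])

x1, y1 = gx - gw * 0.5, gy - gh * 0.5
x2, y2 = gx + gw * 0.5, gy + gh * 0.5

# Each polygon vertex sits on one side of the horizontal box: the top/bottom
# vertices are shifted from the side midpoint by +-da * gw, the left/right
# vertices by +-db * gh, exactly as in the rewriter above.
ga, _ga = gx + da * gw, gx - da * gw
gb, _gb = gy + db * gh, gy - db * gh
polys = torch.stack([ga, y1, x2, gb, _ga, y2, x1, _gb])
print(polys)  # eight values: the (x, y) pairs of the quadrilateral
```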
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_nms import multiclass_nms_rotated
from .bbox_nms import fake_multiclass_nms_rotated, multiclass_nms_rotated

__all__ = ['multiclass_nms_rotated']
__all__ = ['multiclass_nms_rotated', 'fake_multiclass_nms_rotated']
@@ -1,17 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmrotate.core import obb2xyxy
from torch import Tensor

import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop
from mmdeploy.mmcv.ops import (ONNXNMSop, ONNXNMSRotatedOp,
                               TRTBatchedRotatedNMSop)


def select_nms_index(scores: torch.Tensor,
                     boxes: torch.Tensor,
                     nms_index: torch.Tensor,
                     batch_size: int,
                     keep_top_k: int = -1):
def select_rnms_index(scores: torch.Tensor,
                      boxes: torch.Tensor,
                      nms_index: torch.Tensor,
                      batch_size: int,
                      keep_top_k: int = -1):
    """Transform NMSRotated output.

    Args:
@@ -85,6 +87,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
    op. It only supports class-agnostic detection results. That is, the scores
    is of shape (N, num_bboxes, num_classes) and the boxes is of shape
    (N, num_boxes, 5).

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5].
        scores (Tensor): The detection scores of shape
@@ -114,7 +117,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
    selected_indices = ONNXNMSRotatedOp.apply(boxes, scores, iou_threshold,
                                              score_threshold)

    dets, labels = select_nms_index(
    dets, labels = select_rnms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)

    return dets, labels
@@ -173,3 +176,47 @@ def multiclass_nms_rotated(*args, **kwargs):
    """Wrapper function for `_multiclass_nms`."""
    return mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.\
        _multiclass_nms_rotated(*args, **kwargs)


@mark(
    'fake_multiclass_nms_rotated',
    inputs=['boxes', 'scores'],
    outputs=['dets', 'labels'])
def fake_multiclass_nms_rotated(boxes: Tensor,
                                scores: Tensor,
                                max_output_boxes_per_class: int = 1000,
                                iou_threshold: float = 0.5,
                                score_threshold: float = 0.0,
                                pre_top_k: int = -1,
                                keep_top_k: int = -1,
                                version: str = 'le90'):
    """Fake NMSRotated for multi-class bboxes which uses horizontal bboxes
    for NMS but returns the rotated bboxes result.

    This function helps exporting to onnx with batch and multiclass NMS op. It
    only supports class-agnostic detection results. That is, the scores is of
    shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
    5).
    """
    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
    iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
    score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
    batch_size = scores.shape[0]

    if pre_top_k > 0:
        max_scores, _ = scores.max(-1)
        _, topk_inds = max_scores.topk(pre_top_k)
        batch_inds = torch.arange(batch_size).view(-1, 1).long()
        boxes = boxes[batch_inds, topk_inds, :]
        scores = scores[batch_inds, topk_inds, :]

    scores = scores.permute(0, 2, 1)
    hboxes = obb2xyxy(boxes, version)
    selected_indices = ONNXNMSop.apply(hboxes, scores,
                                       max_output_boxes_per_class,
                                       iou_threshold, score_threshold)

    dets, labels = select_rnms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)

    return dets, labels
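The key trick above is that NMS runs on axis-aligned boxes derived from the rotated ones, while the rotated boxes are what get gathered and returned. A hedged pure-PyTorch sketch of that oriented-to-horizontal conversion (the actual `obb2xyxy` in mmrotate may differ in details such as the angle convention) looks like this:

```python
import torch


def obb_to_enclosing_xyxy(obboxes: torch.Tensor) -> torch.Tensor:
    """Axis-aligned box enclosing each (cx, cy, w, h, theta) rotated box.

    A sketch of the idea behind obb2xyxy; mmrotate's implementation may
    differ (e.g. per angle version).
    """
    cx, cy, w, h, theta = obboxes.unbind(dim=-1)
    cos, sin = torch.cos(theta).abs(), torch.sin(theta).abs()
    dw = (w * cos + h * sin) * 0.5
    dh = (w * sin + h * cos) * 0.5
    return torch.stack([cx - dw, cy - dh, cx + dw, cy + dh], dim=-1)


boxes = torch.rand(1, 5, 5)  # (batch, num_boxes, 5), same layout as above
print(obb_to_enclosing_xyxy(boxes).shape)  # torch.Size([1, 5, 4])
```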
@@ -1,9 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .oriented_standard_roi_head import (
    oriented_standard_roi_head__simple_test,
    oriented_standard_roi_head__simple_test_bboxes)
from .rotated_anchor_head import rotated_anchor_head__get_bbox
from .rotated_bbox_head import rotated_bbox_head__get_bboxes
from .rotated_rpn_head import rotated_rpn_head__get_bboxes
from .single_stage_rotated_detector import \
    single_stage_rotated_detector__simple_test

__all__ = [
    'single_stage_rotated_detector__simple_test',
    'rotated_anchor_head__get_bbox'
    'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes',
    'oriented_standard_roi_head__simple_test',
    'oriented_standard_roi_head__simple_test_bboxes',
    'rotated_bbox_head__get_bboxes'
]
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmrotate.models.roi_heads.oriented_standard_roi_head'
    '.OrientedStandardRoIHead.simple_test')
def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
                                            **kwargs):
    """Rewrite `simple_test` of `OrientedStandardRoIHead` for default backend.

    This function returns detection result as Tensor instead of numpy
    array.

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the original class.
        x (tuple[Tensor]): Features from upstream network. Each
            has shape (batch_size, c, h, w).
        proposals (list(Tensor)): Proposals from rpn head.
            Each has shape (num_proposals, 6), last dimension
            6 represent (x, y, w, h, theta, score).
        img_metas (list[dict]): Meta information of images.

    Returns:
        tuple[Tensor, Tensor]: (det_bboxes, det_labels),
        `det_bboxes` of shape [N, num_det, 6] and `det_labels`
        of shape [N, num_det].
    """
    assert self.with_bbox, 'Bbox head must be implemented.'
    det_bboxes, det_labels = self.simple_test_bboxes(
        x, img_metas, proposals, self.test_cfg, rescale=False)

    return det_bboxes, det_labels


@FUNCTION_REWRITER.register_rewriter(
    'mmrotate.models.roi_heads.oriented_standard_roi_head'
    '.OrientedStandardRoIHead.simple_test_bboxes')
def oriented_standard_roi_head__simple_test_bboxes(ctx,
                                                   self,
                                                   x,
                                                   img_metas,
                                                   proposals,
                                                   rcnn_test_cfg,
                                                   rescale=False):
    """Test only det bboxes without augmentation.

    Args:
        x (tuple[Tensor]): Feature maps of all scale level.
        img_metas (list[dict]): Image meta info.
        proposals (List[Tensor]): Region proposals.
        rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
        rescale (bool): If True, return boxes in original image space.
            Default: False.

    Returns:
        tuple[list[Tensor], list[Tensor]]: The first list contains \
            the boxes of the corresponding image in a batch, each \
            tensor has the shape (num_boxes, 6) and last dimension \
            6 represent (x, y, w, h, theta, score). Each Tensor \
            in the second list is the labels with shape (num_boxes, ). \
            The length of both lists should be equal to batch_size.
    """

    rois, labels = proposals
    batch_index = torch.arange(
        rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
            rois.size(0), rois.size(1), 1)
    rois = torch.cat([batch_index, rois[..., :5]], dim=-1)
    batch_size = rois.shape[0]
    num_proposals_per_img = rois.shape[1]

    # Eliminate the batch dimension
    rois = rois.view(-1, 6)
    bbox_results = self._bbox_forward(x, rois)
    cls_score = bbox_results['cls_score']
    bbox_pred = bbox_results['bbox_pred']

    # Recover the batch dimension
    rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
    cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
                                  cls_score.size(-1))

    bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
                                  bbox_pred.size(-1))
    det_bboxes, det_labels = self.bbox_head.get_bboxes(
        rois,
        cls_score,
        bbox_pred,
        img_metas[0]['img_shape'],
        None,
        rescale=rescale,
        cfg=rcnn_test_cfg)
    return det_bboxes, det_labels
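A small, self-contained sketch of the rois bookkeeping in `simple_test_bboxes` above: a batch-index column is prepended to the per-image proposals so the flattened (batch * proposals, 6) rois can be fed to the bbox head, and the batch dimension is recovered afterwards. The tensors are random placeholders:

```python
import torch

batch_size, num_proposals = 2, 4
# Proposals as produced by the rewritten RPN: (x, y, w, h, theta, score).
proposals = torch.rand(batch_size, num_proposals, 6)

batch_index = torch.arange(
    batch_size).float().view(-1, 1, 1).expand(batch_size, num_proposals, 1)
rois = torch.cat([batch_index, proposals[..., :5]], dim=-1)  # (2, 4, 6)

flat_rois = rois.view(-1, 6)  # (8, 6): what the bbox head forward consumes
rois = flat_rois.reshape(batch_size, num_proposals, 6)  # batch dim recovered
print(flat_rois.shape, rois.shape)
```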
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.codebase.mmrotate.core.post_processing import \
    multiclass_nms_rotated
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'mmrotate.models.roi_heads.bbox_heads.RotatedBBoxHead.get_bboxes')
def rotated_bbox_head__get_bboxes(ctx,
                                  self,
                                  rois,
                                  cls_score,
                                  bbox_pred,
                                  img_shape,
                                  scale_factor,
                                  rescale=False,
                                  cfg=None):
    """Transform network output for a batch into bbox predictions.

    Args:
        rois (torch.Tensor): Boxes to be transformed. Has shape
            (num_boxes, 6), last dimension 6 arranged as
            (batch_index, x, y, w, h, theta).
        cls_score (torch.Tensor): Box scores, has shape
            (num_boxes, num_classes + 1).
        bbox_pred (Tensor, optional): Box energies / deltas.
            Has shape (num_boxes, num_classes * 5).
        img_shape (Sequence[int], optional): Maximum bounds for boxes,
            specifies (H, W, C) or (H, W).
        scale_factor (ndarray): Scale factor of the
            image arranged as (w_scale, h_scale, w_scale, h_scale).
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None.

    Returns:
        tuple[Tensor, Tensor]:
            First tensor is `det_bboxes`, has the shape
            (num_boxes, 6) and last
            dimension 6 represent (cx, cy, w, h, theta, score).
            Second tensor is the labels with shape (num_boxes, ).
    """
    assert rois.ndim == 3, 'Only support export two stage ' \
                           'model to ONNX ' \
                           'with batch dimension. '

    if self.custom_cls_channels:
        scores = self.loss_cls.get_activation(cls_score)
    else:
        scores = F.softmax(
            cls_score, dim=-1) if cls_score is not None else None

    assert bbox_pred is not None
    bboxes = self.bbox_coder.decode(
        rois[..., 1:], bbox_pred, max_shape=img_shape)

    # ignore background class
    scores = scores[..., :self.num_classes]

    post_params = get_post_processing_params(ctx.cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    return multiclass_nms_rotated(
        bboxes,
        scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
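As a quick illustration of the "ignore background class" step above (a hedged sketch with random scores; `num_classes` is arbitrary here): the head predicts `num_classes + 1` logits per RoI and the trailing background column is dropped before rotated NMS:

```python
import torch
import torch.nn.functional as F

num_classes = 15  # illustrative; DOTA-v1.0 has 15 categories
cls_score = torch.rand(1, 1000, num_classes + 1)  # (batch, num_rois, C + 1)

scores = F.softmax(cls_score, dim=-1)
scores = scores[..., :num_classes]  # drop the background column
print(scores.shape)  # torch.Size([1, 1000, 15])
```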
@@ -0,0 +1,142 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
                                     pad_with_value_if_necessary)
from mmdeploy.codebase.mmrotate.core.post_processing import \
    fake_multiclass_nms_rotated
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    'mmrotate.models.dense_heads.RotatedRPNHead.get_bboxes')
def rotated_rpn_head__get_bboxes(ctx,
                                 self,
                                 cls_scores,
                                 bbox_preds,
                                 score_factors=None,
                                 img_metas=None,
                                 cfg=None,
                                 rescale=False,
                                 with_nms=True,
                                 **kwargs):
    """Rewrite `get_bboxes` of `RotatedRPNHead` for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        ctx (ContextCaller): The context with additional information.
        self (RotatedRPNHead): The instance of the class RotatedRPNHead.
        cls_scores (list[Tensor]): Box scores for each scale level
            with shape (N, num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for each scale
            level with shape (N, num_anchors * 6, H, W).
        score_factors (list[Tensor], Optional): Score factor for
            all scale level, each is a 4D-tensor, has shape
            (batch_size, num_priors * 1, H, W). Default None.
        img_metas (list[dict]): Meta information of the image, e.g.,
            image size, scaling factor, etc.
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used. Default: None.
        rescale (bool): If True, return boxes in original image space.
            Default False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        If with_nms == True:
            tuple[Tensor, Tensor]: (dets, labels), `dets` of shape
            [N, num_det, 6] and `labels` of shape [N, num_det].
        Else:
            tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
    """
    assert len(cls_scores) == len(bbox_preds)
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    num_levels = len(cls_scores)

    device = cls_scores[0].device
    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
    mlvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device=device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)

    cfg = self.test_cfg if cfg is None else cfg
    batch_size = mlvl_cls_scores[0].shape[0]
    pre_topk = cfg.get('nms_pre', -1)

    # loop over features, decode boxes
    mlvl_valid_bboxes = []
    mlvl_scores = []
    mlvl_valid_anchors = []
    for level_id, cls_score, bbox_pred, anchors in zip(
            range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        cls_score = cls_score.permute(0, 2, 3, 1)
        if self.use_sigmoid_cls:
            cls_score = cls_score.reshape(batch_size, -1)
            scores = cls_score.sigmoid()
        else:
            cls_score = cls_score.reshape(batch_size, -1, 2)
            # We set FG labels to [0, num_class-1] and BG label to
            # num_class in RPN head since mmdet v2.5, which is unified to
            # be consistent with other head since mmdet v2.0. In mmdet v2.0
            # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
            scores = cls_score.softmax(-1)[..., 0]
        scores = scores.reshape(batch_size, -1, 1)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)

        # use static anchor if input shape is static
        if not is_dynamic_flag:
            anchors = anchors.data

        anchors = anchors.expand(batch_size, -1, anchors.size(-1))

        # topk in tensorrt does not support shape<k
        # concatenate zeros to enable topk
        scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
        bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
        anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)

        if pre_topk > 0:
            _, topk_inds = scores.squeeze(2).topk(pre_topk)
            batch_inds = torch.arange(
                batch_size, device=device).view(-1, 1).expand_as(topk_inds)
            anchors = anchors[batch_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_scores.append(scores)
        mlvl_valid_anchors.append(anchors)

    batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
    batch_mlvl_bboxes = self.bbox_coder.decode(
        batch_mlvl_anchors,
        batch_mlvl_bboxes,
        max_shape=img_metas[0]['img_shape'])
    # ignore background class
    if not self.use_sigmoid_cls:
        batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores

    post_params = get_post_processing_params(deploy_cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    # only one class in rpn
    max_output_boxes_per_class = keep_top_k
    return fake_multiclass_nms_rotated(
        batch_mlvl_bboxes,
        batch_mlvl_scores,
        max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        keep_top_k=keep_top_k,
        version=self.version)
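To see what the per-level loop above does to the raw head outputs, here is a hedged sketch of its softmax branch for a single feature level (9 anchors and 6 midpoint-offset deltas as in the new tests; the spatial size is arbitrary):

```python
import torch

batch_size, num_anchors, num_deltas = 1, 9, 6
h = w = 32

# Raw per-level RPN outputs: 2-way objectness logits and midpoint-offset deltas.
cls_score = torch.rand(batch_size, num_anchors * 2, h, w)
bbox_pred = torch.rand(batch_size, num_anchors * num_deltas, h, w)

cls_score = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, 2)
scores = cls_score.softmax(-1)[..., 0].reshape(batch_size, -1, 1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, num_deltas)
print(scores.shape, bbox_pred.shape)  # (1, 9216, 1) and (1, 9216, 6)
```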
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .atan2 import atan2__default
from .chunk import chunk__ncnn
from .getattribute import tensor__getattribute__ncnn
from .group_norm import group_norm__ncnn
@@ -13,5 +14,5 @@ __all__ = [
    'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
    'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
    'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
    'triu'
    'triu', 'atan2__default'
]
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.atan2', backend='default')
def atan2__default(
    ctx,
    input1: torch.Tensor,
    input2: torch.Tensor,
):
    """Rewrite `atan2` for default backend."""
    return torch.atan(input1 / (input2 + 1e-6))
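Note that the rewrite above is only an approximation: `atan(y / x)` agrees with `atan2(y, x)` when `x > 0` but loses the quadrant information otherwise. A small sanity check in plain PyTorch:

```python
import torch

y = torch.tensor([1.0, 1.0])
x = torch.tensor([2.0, -2.0])

exact = torch.atan2(y, x)
approx = torch.atan(y / (x + 1e-6))

print(exact)   # approx. tensor([0.4636, 2.6779])
print(approx)  # approx. tensor([0.4636, -0.4636]): off by pi when x < 0
```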
@@ -134,8 +134,10 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k):
@pytest.mark.parametrize('max_shape,proj_xy,edge_swap',
                         [(None, False, False),
                          (torch.tensor([100, 200]), True, True)])
def test_delta2bbox(backend_type: Backend, add_ctr_clamp: bool,
                    max_shape: tuple, proj_xy: bool, edge_swap: bool):
def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend,
                                            add_ctr_clamp: bool,
                                            max_shape: tuple, proj_xy: bool,
                                            edge_swap: bool):
    check_backend(backend_type)
    deploy_cfg = mmcv.Config(
        dict(
@@ -181,3 +183,90 @@
            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
    else:
        assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend):
    check_backend(backend_type)
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(output_names=None, input_shape=None),
            backend_config=dict(type=backend_type.value, model_inputs=None),
            codebase_config=dict(type='mmrotate', task='RotatedDetection')))

    # wrap function to enable rewrite
    def delta2bbox(*args, **kwargs):
        import mmrotate
        return mmrotate.core.bbox.coder.delta_midpointoffset_rbbox_coder\
            .delta2bbox(*args, **kwargs)

    rois = torch.rand(5, 4)
    deltas = torch.rand(5, 6)
    original_outputs = delta2bbox(rois, deltas, version='le90')

    # wrap function to nn.Module, enable torch.onnx.export
    wrapped_func = WrapFunction(delta2bbox)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'rois': rois.unsqueeze(0),
            'deltas': deltas.unsqueeze(0)
        },
        deploy_cfg=deploy_cfg)

    if is_backend_output:
        model_output = original_outputs.squeeze().cpu().numpy()
        rewrite_output = rewrite_outputs[0].squeeze().cpu().numpy()
        assert np.allclose(
            model_output[:, :4], rewrite_output[:, :4], rtol=1e-03, atol=1e-05)
    else:
        assert rewrite_outputs is not None


@backend_checker(Backend.ONNXRUNTIME)
def test_fake_multiclass_nms_rotated():
    from mmdeploy.codebase.mmrotate.core import fake_multiclass_nms_rotated
    deploy_cfg = mmcv.Config(
        dict(
            onnx_config=dict(output_names=None, input_shape=None),
            backend_config=dict(
                type='onnxruntime',
                common_config=dict(
                    fp16_mode=False, max_workspace_size=1 << 20),
                model_inputs=[
                    dict(
                        input_shapes=dict(
                            boxes=dict(
                                min_shape=[1, 5, 5],
                                opt_shape=[1, 5, 5],
                                max_shape=[1, 5, 5]),
                            scores=dict(
                                min_shape=[1, 5, 8],
                                opt_shape=[1, 5, 8],
                                max_shape=[1, 5, 8])))
                ]),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    pre_top_k=-1,
                    keep_top_k=10,
                ))))

    boxes = torch.rand(1, 5, 5)
    scores = torch.rand(1, 5, 8)
    keep_top_k = 10
    wrapped_func = WrapFunction(
        fake_multiclass_nms_rotated, keep_top_k=keep_top_k)
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_func,
        model_inputs={
            'boxes': boxes,
            'scores': scores
        },
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None, 'Got unexpected rewrite '\
        'outputs: {}'.format(rewrite_outputs)
@@ -74,47 +74,46 @@ def _replace_r50_with_r18(model):
    return model


# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
# @pytest.mark.parametrize('model_cfg_path', [
#     'tests/test_codebase/test_mmrotate/data/single_stage_model.json'
# ])
# def test_forward_of_base_detector(model_cfg_path, backend):
#     check_backend(backend)
#     deploy_cfg = mmcv.Config(
#         dict(
#             backend_config=dict(type=backend.value),
#             onnx_config=dict(
#                 output_names=['dets', 'labels'], input_shape=None),
#             codebase_config=dict(
#                 type='mmrotate',
#                 task='RotatedDetection',
#                 post_processing=dict(
#                     score_threshold=0.05,
#                     iou_threshold=0.5,
#                     pre_top_k=-1,
#                     keep_top_k=100,
#                 ))))
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize(
    'model_cfg_path',
    ['tests/test_codebase/test_mmrotate/data/single_stage_model.json'])
def test_forward_of_base_detector(model_cfg_path, backend):
    check_backend(backend)
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend.value),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    pre_top_k=-1,
                    keep_top_k=100,
                ))))

# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
# model_cfg.model = _replace_r50_with_r18(model_cfg.model)
    model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
    model_cfg.model = _replace_r50_with_r18(model_cfg.model)

# from mmrotate.models import build_detector
    from mmrotate.models import build_detector

# model_cfg.model.pretrained = None
# model_cfg.model.train_cfg = None
# model = build_detector(
#     model_cfg.model, test_cfg= model_cfg.get('test_cfg'))
# model.cfg = model_cfg
# model.to('cpu')
    model_cfg.model.pretrained = None
    model_cfg.model.train_cfg = None
    model = build_detector(model_cfg.model, test_cfg=model_cfg.get('test_cfg'))
    model.cfg = model_cfg
    model.to('cpu')

# img = torch.randn(1, 3, 64, 64)
# rewrite_inputs = {'img': img}
# rewrite_outputs, _ = get_rewrite_outputs(
#     wrapped_model=model,
#     model_inputs=rewrite_inputs,
#     deploy_cfg=deploy_cfg)
    img = torch.randn(1, 3, 64, 64)
    rewrite_inputs = {'img': img}
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

# assert rewrite_outputs is not None
    assert rewrite_outputs is not None


def get_deploy_cfg(backend_type: Backend, ir_type: str):
@@ -155,8 +154,8 @@ def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):

    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
    # the bboxes's size: (1, 45, 32, 32), (1, 45, 16, 16),
    # (1, 45, 8, 8), (1, 45, 4, 4), (1, 45, 2, 2)
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
@@ -201,3 +200,135 @@
            atol=1e-05)
    else:
        assert rewrite_outputs is not None


def get_single_roi_extractor():
    """SingleRoIExtractor Config."""
    from mmrotate.models.roi_heads import RotatedSingleRoIExtractor
    roi_layer = dict(
        type='RoIAlignRotated', out_size=7, sample_num=2, clockwise=True)
    out_channels = 1
    featmap_strides = [4, 8, 16, 32]
    model = RotatedSingleRoIExtractor(roi_layer, out_channels,
                                      featmap_strides).eval()

    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotated_single_roi_extractor(backend_type: Backend):
    check_backend(backend_type)

    single_roi_extractor = get_single_roi_extractor()
    output_names = ['roi_feat']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
            )))

    seed_everything(1234)
    out_channels = single_roi_extractor.out_channels
    feats = [
        torch.rand((1, out_channels, 200, 336)),
        torch.rand((1, out_channels, 100, 168)),
        torch.rand((1, out_channels, 50, 84)),
        torch.rand((1, out_channels, 25, 42)),
    ]
    seed_everything(5678)
    rois = torch.tensor(
        [[0.0000, 587.8285, 52.1405, 886.2484, 341.5644, 0.0000]])

    model_inputs = {
        'feats': feats,
        'rois': rois,
    }
    model_outputs = get_model_outputs(single_roi_extractor, 'forward',
                                      model_inputs)

    backend_outputs, _ = get_rewrite_outputs(
        wrapped_model=single_roi_extractor,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)
    if isinstance(backend_outputs, dict):
        backend_outputs = backend_outputs.values()
    for model_output, backend_output in zip(model_outputs[0], backend_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        backend_output = backend_output.squeeze()
        assert np.allclose(
            model_output, backend_output, rtol=1e-03, atol=1e-05)


def get_oriented_rpn_head_model():
    """Oriented RPN Head Config."""
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))
    from mmrotate.models.dense_heads import OrientedRPNHead
    model = OrientedRPNHead(
        in_channels=1,
        version='le90',
        bbox_coder=dict(type='MidpointOffsetCoder', angle_range='le90'),
        test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend):
    check_backend(backend_type)
    head = get_oriented_rpn_head_model()
    head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000))))

    # the cls_score's size: (1, 9, 32, 32), (1, 9, 16, 16),
    # (1, 9, 8, 8), (1, 9, 4, 4), (1, 9, 2, 2).
    # the bboxes's size: (1, 54, 32, 32), (1, 54, 16, 16),
    # (1, 54, 8, 8), (1, 54, 4, 4), (1, 54, 2, 2)
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 9, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 54, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # to get outputs of onnx model after rewrite
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    assert rewrite_outputs is not None