[Enhancement] Support Slide Vertex TRT (#650)

* reorgnize mmrotate

* fix

* add hbb2obb

* add ut

* fix rotated nms

* update docs

* update benchmark

* update test

* remove ort regression test, remove comment
pull/754/head
q.yao 2022-07-13 16:09:09 +08:00 committed by GitHub
parent 14b2bfd524
commit dace58e844
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 703 additions and 32 deletions

View File

@ -6,4 +6,5 @@ codebase_config = dict(
score_threshold=0.05,
iou_threshold=0.1,
pre_top_k=3000,
keep_top_k=2000))
keep_top_k=2000,
max_output_boxes_per_class=2000))

View File

@ -295,7 +295,7 @@ __host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1
const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
return 1.0f;
}
const T intersection = rotated_boxes_intersection<T>(box1, box2);

View File

@ -1641,6 +1641,18 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py">GlidingVertex</a></td>
<td align="center">Rotated Detection</td>
<td align="center">DOTA-v1.0</td>
<td align="center">mAP</td>
<td align="center">0.732</td>
<td align="center">-</td>
<td align="center">0.733</td>
<td align="center">0.731</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
</tbody>
</table>
</div>

View File

@ -73,6 +73,8 @@ The table below lists the models that are guaranteed to be exportable to other b
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
### Note

View File

@ -12,6 +12,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en
| :--------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
| RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
| Oriented RCNN | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
| Gliding Vertex | RotatedDetection | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
### Example

View File

@ -1638,6 +1638,18 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py">GlidingVertex</a></td>
<td align="center">Rotated Detection</td>
<td align="center">DOTA-v1.0</td>
<td align="center">mAP</td>
<td align="center">0.732</td>
<td align="center">-</td>
<td align="center">0.733</td>
<td align="center">0.731</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
</tbody>
</table>
</div>

View File

@ -72,6 +72,7 @@
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
## Note

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403
from .delta_xywha_rbbox_coder import * # noqa: F401,F403
from .gliding_vertex_coder import * # noqa: F401,F403
from .transforms import * # noqa: F401,F403

View File

@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.core.bbox.coder.gliding_vertex_coder'
'.GVFixCoder.decode')
def gvfixcoder__decode(ctx, self, hbboxes, fix_deltas):
"""Rewriter for GVFixCoder decode, support more dimension input."""
from mmrotate.core.bbox.transforms import poly2obb
x1 = hbboxes[..., 0::4]
y1 = hbboxes[..., 1::4]
x2 = hbboxes[..., 2::4]
y2 = hbboxes[..., 3::4]
w = hbboxes[..., 2::4] - hbboxes[..., 0::4]
h = hbboxes[..., 3::4] - hbboxes[..., 1::4]
pred_t_x = x1 + w * fix_deltas[..., 0::4]
pred_r_y = y1 + h * fix_deltas[..., 1::4]
pred_d_x = x2 - w * fix_deltas[..., 2::4]
pred_l_y = y2 - h * fix_deltas[..., 3::4]
polys = torch.stack(
[pred_t_x, y1, x2, pred_r_y, pred_d_x, y2, x1, pred_l_y], dim=-1)
polys = polys.flatten(2)
rbboxes = poly2obb(polys, self.version)
return rbboxes

View File

@ -77,6 +77,7 @@ def select_rnms_index(scores: torch.Tensor,
def _multiclass_nms_rotated(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.1,
score_threshold: float = 0.05,
pre_top_k: int = -1,

View File

@ -75,6 +75,33 @@ class End2EndModel(BaseBackendModel):
output_names=output_names,
deploy_cfg=self.deploy_cfg)
@staticmethod
def __clear_outputs(
test_outputs: List[Union[torch.Tensor, np.ndarray]]
) -> List[Union[List[torch.Tensor], List[np.ndarray]]]:
"""Removes additional outputs and detections with zero and negative
score.
Args:
test_outputs (List[Union[torch.Tensor, np.ndarray]]):
outputs of forward_test.
Returns:
List[Union[List[torch.Tensor], List[np.ndarray]]]:
outputs with without zero score object.
"""
batch_size = len(test_outputs[0])
num_outputs = len(test_outputs)
outputs = [[None for _ in range(batch_size)]
for _ in range(num_outputs)]
for i in range(batch_size):
inds = test_outputs[0][i, :, -1] > 0.0
for output_id in range(num_outputs):
outputs[output_id][i] = test_outputs[output_id][i, inds, ...]
return outputs
def forward(self, img: Sequence[torch.Tensor],
img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list:
"""Run forward inference.
@ -91,6 +118,7 @@ class End2EndModel(BaseBackendModel):
input_img = img[0].contiguous()
img_metas = img_metas[0]
outputs = self.forward_test(input_img, img_metas, *args, **kwargs)
outputs = End2EndModel.__clear_outputs(outputs)
batch_dets, batch_labels = outputs[:2]
batch_size = input_img.shape[0]
rescale = kwargs.get('rescale', False)

View File

@ -1,19 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .oriented_standard_roi_head import (
oriented_standard_roi_head__simple_test,
oriented_standard_roi_head__simple_test_bboxes)
from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt
from .rotated_anchor_head import rotated_anchor_head__get_bbox
from .rotated_bbox_head import rotated_bbox_head__get_bboxes
from .rotated_rpn_head import rotated_rpn_head__get_bboxes
from .single_stage_rotated_detector import \
single_stage_rotated_detector__simple_test
__all__ = [
'single_stage_rotated_detector__simple_test',
'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes',
'oriented_standard_roi_head__simple_test',
'oriented_standard_roi_head__simple_test_bboxes',
'rotated_bbox_head__get_bboxes',
'rotated_single_roi_extractor__forward__tensorrt'
]
from .dense_heads import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403
from .single_stage_rotated_detector import * # noqa: F401,F403

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .oriented_rpn_head import oriented_rpn_head__get_bboxes
from .rotated_anchor_head import rotated_anchor_head__get_bbox
from .rotated_rpn_head import rotated_rpn_head__get_bboxes
__all__ = [
'oriented_rpn_head__get_bboxes', 'rotated_anchor_head__get_bbox',
'rotated_rpn_head__get_bboxes'
]

View File

@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.codebase.mmdet import (get_post_processing_params,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmrotate.core.post_processing import \
fake_multiclass_nms_rotated
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.dense_heads.OrientedRPNHead.get_bboxes')
def oriented_rpn_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
score_factors=None,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True,
**kwargs):
"""Rewrite `get_bboxes` of `RPNHead` for default backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
self (FoveaHead): The instance of the class FoveaHead.
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
score_factors (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Default None.
img_metas (list[dict]): Meta information of the image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
cfg = self.test_cfg if cfg is None else cfg
batch_size = mlvl_cls_scores[0].shape[0]
pre_topk = cfg.get('nms_pre', -1)
# loop over features, decode boxes
mlvl_valid_bboxes = []
mlvl_scores = []
mlvl_valid_anchors = []
for level_id, cls_score, bbox_pred, anchors in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3, 1)
if self.use_sigmoid_cls:
cls_score = cls_score.reshape(batch_size, -1)
scores = cls_score.sigmoid()
else:
cls_score = cls_score.reshape(batch_size, -1, 2)
# We set FG labels to [0, num_class-1] and BG label to
# num_class in RPN head since mmdet v2.5, which is unified to
# be consistent with other head since mmdet v2.0. In mmdet v2.0
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = cls_score.softmax(-1)[..., 0]
scores = scores.reshape(batch_size, -1, 1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)
# use static anchor if input shape is static
if not is_dynamic_flag:
anchors = anchors.data
anchors = anchors.unsqueeze(0)
# topk in tensorrt does not support shape<k
# concate zero to enable topk,
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)
if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
prior_inds = topk_inds.new_zeros((1, 1))
anchors = anchors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_valid_anchors.append(anchors)
batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
batch_mlvl_bboxes = self.bbox_coder.decode(
batch_mlvl_anchors,
batch_mlvl_bboxes,
max_shape=img_metas[0]['img_shape'])
# ignore background class
if not self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores
post_params = get_post_processing_params(deploy_cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
# only one class in rpn
max_output_boxes_per_class = keep_top_k
return fake_multiclass_nms_rotated(
batch_mlvl_bboxes,
batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
keep_top_k=keep_top_k,
version=self.version)

View File

@ -3,8 +3,7 @@ import torch
from mmdeploy.codebase.mmdet import (get_post_processing_params,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmrotate.core.post_processing import \
fake_multiclass_nms_rotated
from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@ -89,7 +88,7 @@ def rotated_rpn_head__get_bboxes(ctx,
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
scores = cls_score.softmax(-1)[..., 0]
scores = scores.reshape(batch_size, -1, 1)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
# use static anchor if input shape is static
if not is_dynamic_flag:
@ -129,13 +128,16 @@ def rotated_rpn_head__get_bboxes(ctx,
post_params = get_post_processing_params(deploy_cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
# only one class in rpn
max_output_boxes_per_class = keep_top_k
return fake_multiclass_nms_rotated(
return multiclass_nms(
batch_mlvl_bboxes,
batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
keep_top_k=keep_top_k,
version=self.version)
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)

View File

@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .gv_bbox_head import gv_bbox_head__get_bboxes
from .gv_ratio_roi_head import gv_ratio_roi_head__simple_test_bboxes
from .oriented_standard_roi_head import \
oriented_standard_roi_head__simple_test_bboxes
from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt
from .rotated_bbox_head import rotated_bbox_head__get_bboxes
__all__ = [
'gv_bbox_head__get_bboxes', 'gv_ratio_roi_head__simple_test_bboxes',
'oriented_standard_roi_head__simple_test_bboxes',
'rotated_single_roi_extractor__forward__tensorrt',
'rotated_bbox_head__get_bboxes'
]

View File

@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.codebase.mmrotate.core.post_processing import \
multiclass_nms_rotated
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.get_bboxes')
def gv_bbox_head__get_bboxes(ctx,
self,
rois,
cls_score,
bbox_pred,
fix_pred,
ratio_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None):
"""Transform network output for a batch into bbox predictions.
Args:
rois (torch.Tensor): Boxes to be transformed. Has shape
(num_boxes, 6). last dimension 5 arrange as
(batch_index, x, y, w, h, theta).
cls_score (torch.Tensor): Box scores, has shape
(num_boxes, num_classes + 1).
bbox_pred (Tensor, optional): Box energies / deltas.
has shape (num_boxes, num_classes * 6).
img_shape (Sequence[int], optional): Maximum bounds for boxes,
specifies (H, W, C) or (H, W).
scale_factor (ndarray): Scale factor of the
image arrange as (w_scale, h_scale, w_scale, h_scale).
rescale (bool): If True, return boxes in original image space.
Default: False.
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
Returns:
tuple[Tensor, Tensor]:
First tensor is `det_bboxes`, has the shape
(num_boxes, 6) and last
dimension 6 represent (cx, cy, w, h, theta, score).
Second tensor is the labels with shape (num_boxes, ).
"""
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '
if self.custom_cls_channels:
scores = self.loss_cls.get_activation(cls_score)
else:
scores = F.softmax(
cls_score, dim=-1) if cls_score is not None else None
assert bbox_pred is not None
bboxes = self.bbox_coder.decode(
rois[..., 1:], bbox_pred, max_shape=img_shape)
rbboxes = self.fix_coder.decode(bboxes, fix_pred)
bboxes = bboxes.view(*ratio_pred.size(), 4)
rbboxes = rbboxes.view(*ratio_pred.size(), 5)
from mmrotate.core import hbb2obb
rbboxes = rbboxes.where(
ratio_pred.unsqueeze(-1) < self.ratio_thr,
hbb2obb(bboxes, self.version))
rbboxes = rbboxes.squeeze(2)
# ignore background class
scores = scores[..., :self.num_classes]
post_params = get_post_processing_params(ctx.cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms_rotated(
rbboxes,
scores,
max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)

View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.roi_heads.gv_ratio_roi_head'
'.GVRatioRoIHead.simple_test_bboxes')
def gv_ratio_roi_head__simple_test_bboxes(ctx,
self,
x,
img_metas,
proposals,
rcnn_test_cfg,
rescale=False):
"""Test only det bboxes without augmentation.
Args:
x (tuple[Tensor]): Feature maps of all scale level.
img_metas (list[dict]): Image meta info.
proposals (List[Tensor]): Region proposals.
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
rescale (bool): If True, return boxes in original image space.
Default: False.
Returns:
tuple[list[Tensor], list[Tensor]]: The first list contains \
the boxes of the corresponding image in a batch, each \
tensor has the shape (num_boxes, 6) and last dimension \
6 represent (x, y, w, h, theta, score). Each Tensor \
in the second list is the labels with shape (num_boxes, ). \
The length of both lists should be equal to batch_size.
"""
rois, labels = proposals
batch_index = torch.arange(
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
rois.size(0), rois.size(1), 1)
rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
batch_size = rois.shape[0]
num_proposals_per_img = rois.shape[1]
# Eliminate the batch dimension
rois = rois.view(-1, 5)
bbox_results = self._bbox_forward(x, rois)
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
fix_pred = bbox_results['fix_pred']
ratio_pred = bbox_results['ratio_pred']
# Recover the batch dimension
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
cls_score.size(-1))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
bbox_pred.size(-1))
fix_pred = fix_pred.reshape(batch_size, num_proposals_per_img,
fix_pred.size(-1))
ratio_pred = ratio_pred.reshape(batch_size, num_proposals_per_img,
ratio_pred.size(-1))
det_bboxes, det_labels = self.bbox_head.get_bboxes(
rois,
cls_score,
bbox_pred,
fix_pred,
ratio_pred,
img_metas[0]['img_shape'],
None,
rescale=rescale,
cfg=self.test_cfg)
return det_bboxes, det_labels

View File

@ -5,10 +5,10 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.roi_heads.oriented_standard_roi_head'
'.OrientedStandardRoIHead.simple_test')
def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
**kwargs):
'mmrotate.models.roi_heads.rotate_standard_roi_head'
'.RotatedStandardRoIHead.simple_test')
def rotate_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` of `StandardRoIHead` for default backend.
This function returns detection result as Tensor instead of numpy

View File

@ -48,3 +48,20 @@ models:
- *pipeline_ort_detection_dynamic_fp32
- *pipeline_trt_detection_dynamic_fp32
- *pipeline_trt_detection_dynamic_fp16
- name: oriented_rcnn
metafile: configs/oriented_rcnn/metafile.yml
model_configs:
- configs/oriented_rcnn/oriented_rcnn_r50_fpn_fp16_1x_dota_le90.py
pipelines:
- *pipeline_ort_detection_dynamic_fp32
- *pipeline_trt_detection_dynamic_fp32
- *pipeline_trt_detection_dynamic_fp16
- name: gliding_vertex
metafile: configs/gliding_vertex/metafile.yml
model_configs:
- configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py
pipelines:
- *pipeline_trt_detection_dynamic_fp32
- *pipeline_trt_detection_dynamic_fp16

View File

@ -6,8 +6,9 @@ import torch
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend,
get_onnx_model, get_rewrite_outputs)
from mmdeploy.utils.test import (WrapFunction, WrapModel, backend_checker,
check_backend, get_onnx_model,
get_rewrite_outputs)
try:
import_codebase(Codebase.MMROTATE)
@ -309,3 +310,32 @@ def test_poly2obb_le90(backend_type: Backend):
run_with_backend=False)
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_gvfixcoder__decode(backend_type: Backend):
check_backend(backend_type)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(output_names=['output'], input_shape=None),
backend_config=dict(type=backend_type.value),
codebase_config=dict(type='mmrotate', task='RotatedDetection')))
from mmrotate.core.bbox import GVFixCoder
coder = GVFixCoder(angle_range='le90')
hbboxes = torch.rand(1, 10, 4)
fix_deltas = torch.rand(1, 10, 4)
wrapped_model = WrapModel(coder, 'decode')
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model,
model_inputs={
'hbboxes': hbboxes,
'fix_deltas': fix_deltas
},
deploy_cfg=deploy_cfg,
run_with_backend=False)
assert rewrite_outputs is not None

View File

@ -332,3 +332,223 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend):
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None
def get_rotated_rpn_head_model():
"""Oriented RPN Head Config."""
test_cfg = mmcv.Config(
dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(iou_thr=0.1),
max_per_img=2000))
from mmrotate.models.dense_heads import RotatedRPNHead
model = RotatedRPNHead(
version='le90',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
test_cfg=test_cfg)
model.requires_grad_(False)
return model
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend):
check_backend(backend_type)
head = get_rotated_rpn_head_model()
head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['dets', 'labels']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmrotate',
task='RotatedDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.1,
pre_top_k=2000,
keep_top_k=2000))))
# the cls_score's size: (1, 3, 32, 32), (1, 3, 16, 16),
# (1, 3, 8, 8), (1, 3, 4, 4), (1, 3, 2, 2).
# the bboxes's size: (1, 18, 32, 32), (1, 18, 16, 16),
# (1, 18, 8, 8), (1, 18, 4, 4), (1, 18, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 18, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotate_standard_roi_head__simple_test(backend_type: Backend):
check_backend(backend_type)
from mmrotate.models.roi_heads import OrientedStandardRoIHead
output_names = ['dets', 'labels']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmrotate',
task='RotatedDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.1,
pre_top_k=2000,
keep_top_k=2000))))
angle_version = 'le90'
test_cfg = mmcv.Config(
dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(iou_thr=0.1),
max_per_img=2000))
head = OrientedStandardRoIHead(
bbox_roi_extractor=dict(
type='RotatedSingleRoIExtractor',
roi_layer=dict(
type='RoIAlignRotated',
out_size=7,
sample_num=2,
clockwise=True),
out_channels=3,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='RotatedShared2FCBBoxHead',
in_channels=3,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=15,
bbox_coder=dict(
type='DeltaXYWHAOBBoxCoder',
angle_range=angle_version,
norm_factor=None,
edge_swap=True,
proj_xy=True,
target_means=(.0, .0, .0, .0, .0),
target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)),
reg_class_agnostic=True),
test_cfg=test_cfg)
head.cpu().eval()
seed_everything(1234)
x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
proposals = [torch.rand(1, 100, 6), torch.randint(0, 10, (1, 100))]
img_metas = [{'img_shape': torch.tensor([224, 224])}]
wrapped_model = WrapModel(
head, 'simple_test', proposals=proposals, img_metas=img_metas)
rewrite_inputs = {'x': x}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_gv_ratio_roi_head__simple_test(backend_type: Backend):
check_backend(backend_type)
from mmrotate.models.roi_heads import GVRatioRoIHead
output_names = ['dets', 'labels']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmrotate',
task='RotatedDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.1,
pre_top_k=2000,
keep_top_k=2000,
max_output_boxes_per_class=1000))))
angle_version = 'le90'
test_cfg = mmcv.Config(
dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(iou_thr=0.1),
max_per_img=2000))
head = GVRatioRoIHead(
version=angle_version,
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=3,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='GVBBoxHead',
version=angle_version,
num_shared_fcs=2,
in_channels=3,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=15,
ratio_thr=0.8,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=(.0, .0, .0, .0),
target_stds=(0.1, 0.1, 0.2, 0.2)),
fix_coder=dict(type='GVFixCoder', angle_range=angle_version),
ratio_coder=dict(type='GVRatioCoder', angle_range=angle_version),
reg_class_agnostic=True),
test_cfg=test_cfg)
head.cpu().eval()
seed_everything(1234)
x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
bboxes = torch.rand(1, 100, 2)
bboxes = torch.cat(
[bboxes, bboxes + torch.rand(1, 100, 2) + torch.rand(1, 100, 1)],
dim=-1)
proposals = [bboxes, torch.randint(0, 10, (1, 100))]
img_metas = [{'img_shape': torch.tensor([224, 224])}]
wrapped_model = WrapModel(
head, 'simple_test', proposals=proposals, img_metas=img_metas)
rewrite_inputs = {'x': x}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert rewrite_outputs is not None