[Enhancement] Support Slide Vertex TRT (#650)
* reorgnize mmrotate * fix * add hbb2obb * add ut * fix rotated nms * update docs * update benchmark * update test * remove ort regression test, remove commentpull/754/head
parent
14b2bfd524
commit
dace58e844
|
@ -6,4 +6,5 @@ codebase_config = dict(
|
|||
score_threshold=0.05,
|
||||
iou_threshold=0.1,
|
||||
pre_top_k=3000,
|
||||
keep_top_k=2000))
|
||||
keep_top_k=2000,
|
||||
max_output_boxes_per_class=2000))
|
||||
|
|
|
@ -295,7 +295,7 @@ __host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1
|
|||
const T area1 = box1.w * box1.h;
|
||||
const T area2 = box2.w * box2.h;
|
||||
if (area1 < 1e-14 || area2 < 1e-14) {
|
||||
return 0.f;
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
const T intersection = rotated_boxes_intersection<T>(box1, box2);
|
||||
|
|
|
@ -1641,6 +1641,18 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
|
|||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py">GlidingVertex</a></td>
|
||||
<td align="center">Rotated Detection</td>
|
||||
<td align="center">DOTA-v1.0</td>
|
||||
<td align="center">mAP</td>
|
||||
<td align="center">0.732</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.733</td>
|
||||
<td align="center">0.731</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
|
|
@ -73,6 +73,8 @@ The table below lists the models that are guaranteed to be exportable to other b
|
|||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
|
||||
### Note
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en
|
|||
| :--------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
|
||||
| RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
|
||||
### Example
|
||||
|
||||
|
|
|
@ -1638,6 +1638,18 @@ GPU: ncnn, TensorRT, PPLNN
|
|||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py">GlidingVertex</a></td>
|
||||
<td align="center">Rotated Detection</td>
|
||||
<td align="center">DOTA-v1.0</td>
|
||||
<td align="center">mAP</td>
|
||||
<td align="center">0.732</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">0.733</td>
|
||||
<td align="center">0.731</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
|
|
@ -72,6 +72,7 @@
|
|||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
|
||||
## Note
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403
|
||||
from .delta_xywha_rbbox_coder import * # noqa: F401,F403
|
||||
from .gliding_vertex_coder import * # noqa: F401,F403
|
||||
from .transforms import * # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.core.bbox.coder.gliding_vertex_coder'
|
||||
'.GVFixCoder.decode')
|
||||
def gvfixcoder__decode(ctx, self, hbboxes, fix_deltas):
|
||||
"""Rewriter for GVFixCoder decode, support more dimension input."""
|
||||
|
||||
from mmrotate.core.bbox.transforms import poly2obb
|
||||
x1 = hbboxes[..., 0::4]
|
||||
y1 = hbboxes[..., 1::4]
|
||||
x2 = hbboxes[..., 2::4]
|
||||
y2 = hbboxes[..., 3::4]
|
||||
w = hbboxes[..., 2::4] - hbboxes[..., 0::4]
|
||||
h = hbboxes[..., 3::4] - hbboxes[..., 1::4]
|
||||
|
||||
pred_t_x = x1 + w * fix_deltas[..., 0::4]
|
||||
pred_r_y = y1 + h * fix_deltas[..., 1::4]
|
||||
pred_d_x = x2 - w * fix_deltas[..., 2::4]
|
||||
pred_l_y = y2 - h * fix_deltas[..., 3::4]
|
||||
|
||||
polys = torch.stack(
|
||||
[pred_t_x, y1, x2, pred_r_y, pred_d_x, y2, x1, pred_l_y], dim=-1)
|
||||
polys = polys.flatten(2)
|
||||
rbboxes = poly2obb(polys, self.version)
|
||||
|
||||
return rbboxes
|
|
@ -77,6 +77,7 @@ def select_rnms_index(scores: torch.Tensor,
|
|||
|
||||
def _multiclass_nms_rotated(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.1,
|
||||
score_threshold: float = 0.05,
|
||||
pre_top_k: int = -1,
|
||||
|
|
|
@ -75,6 +75,33 @@ class End2EndModel(BaseBackendModel):
|
|||
output_names=output_names,
|
||||
deploy_cfg=self.deploy_cfg)
|
||||
|
||||
@staticmethod
|
||||
def __clear_outputs(
|
||||
test_outputs: List[Union[torch.Tensor, np.ndarray]]
|
||||
) -> List[Union[List[torch.Tensor], List[np.ndarray]]]:
|
||||
"""Removes additional outputs and detections with zero and negative
|
||||
score.
|
||||
|
||||
Args:
|
||||
test_outputs (List[Union[torch.Tensor, np.ndarray]]):
|
||||
outputs of forward_test.
|
||||
|
||||
Returns:
|
||||
List[Union[List[torch.Tensor], List[np.ndarray]]]:
|
||||
outputs with without zero score object.
|
||||
"""
|
||||
batch_size = len(test_outputs[0])
|
||||
|
||||
num_outputs = len(test_outputs)
|
||||
outputs = [[None for _ in range(batch_size)]
|
||||
for _ in range(num_outputs)]
|
||||
|
||||
for i in range(batch_size):
|
||||
inds = test_outputs[0][i, :, -1] > 0.0
|
||||
for output_id in range(num_outputs):
|
||||
outputs[output_id][i] = test_outputs[output_id][i, inds, ...]
|
||||
return outputs
|
||||
|
||||
def forward(self, img: Sequence[torch.Tensor],
|
||||
img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list:
|
||||
"""Run forward inference.
|
||||
|
@ -91,6 +118,7 @@ class End2EndModel(BaseBackendModel):
|
|||
input_img = img[0].contiguous()
|
||||
img_metas = img_metas[0]
|
||||
outputs = self.forward_test(input_img, img_metas, *args, **kwargs)
|
||||
outputs = End2EndModel.__clear_outputs(outputs)
|
||||
batch_dets, batch_labels = outputs[:2]
|
||||
batch_size = input_img.shape[0]
|
||||
rescale = kwargs.get('rescale', False)
|
||||
|
|
|
@ -1,19 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .oriented_standard_roi_head import (
|
||||
oriented_standard_roi_head__simple_test,
|
||||
oriented_standard_roi_head__simple_test_bboxes)
|
||||
from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt
|
||||
from .rotated_anchor_head import rotated_anchor_head__get_bbox
|
||||
from .rotated_bbox_head import rotated_bbox_head__get_bboxes
|
||||
from .rotated_rpn_head import rotated_rpn_head__get_bboxes
|
||||
from .single_stage_rotated_detector import \
|
||||
single_stage_rotated_detector__simple_test
|
||||
|
||||
__all__ = [
|
||||
'single_stage_rotated_detector__simple_test',
|
||||
'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes',
|
||||
'oriented_standard_roi_head__simple_test',
|
||||
'oriented_standard_roi_head__simple_test_bboxes',
|
||||
'rotated_bbox_head__get_bboxes',
|
||||
'rotated_single_roi_extractor__forward__tensorrt'
|
||||
]
|
||||
from .dense_heads import * # noqa: F401,F403
|
||||
from .roi_heads import * # noqa: F401,F403
|
||||
from .single_stage_rotated_detector import * # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .oriented_rpn_head import oriented_rpn_head__get_bboxes
|
||||
from .rotated_anchor_head import rotated_anchor_head__get_bbox
|
||||
from .rotated_rpn_head import rotated_rpn_head__get_bboxes
|
||||
|
||||
__all__ = [
|
||||
'oriented_rpn_head__get_bboxes', 'rotated_anchor_head__get_bbox',
|
||||
'rotated_rpn_head__get_bboxes'
|
||||
]
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
pad_with_value_if_necessary)
|
||||
from mmdeploy.codebase.mmrotate.core.post_processing import \
|
||||
fake_multiclass_nms_rotated
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.dense_heads.OrientedRPNHead.get_bboxes')
|
||||
def oriented_rpn_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
score_factors=None,
|
||||
img_metas=None,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of `RPNHead` for default backend.
|
||||
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (FoveaHead): The instance of the class FoveaHead.
|
||||
cls_scores (list[Tensor]): Box scores for each scale level
|
||||
with shape (N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each scale
|
||||
level with shape (N, num_anchors * 4, H, W).
|
||||
score_factors (list[Tensor], Optional): Score factor for
|
||||
all scale level, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * 1, H, W). Default None.
|
||||
img_metas (list[dict]): Meta information of the image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Default False.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
||||
device = cls_scores[0].device
|
||||
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
|
||||
mlvl_anchors = self.anchor_generator.grid_anchors(
|
||||
featmap_sizes, device=device)
|
||||
|
||||
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
|
||||
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
|
||||
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
batch_size = mlvl_cls_scores[0].shape[0]
|
||||
pre_topk = cfg.get('nms_pre', -1)
|
||||
|
||||
# loop over features, decode boxes
|
||||
mlvl_valid_bboxes = []
|
||||
mlvl_scores = []
|
||||
mlvl_valid_anchors = []
|
||||
for level_id, cls_score, bbox_pred, anchors in zip(
|
||||
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds, mlvl_anchors):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
cls_score = cls_score.permute(0, 2, 3, 1)
|
||||
if self.use_sigmoid_cls:
|
||||
cls_score = cls_score.reshape(batch_size, -1)
|
||||
scores = cls_score.sigmoid()
|
||||
else:
|
||||
cls_score = cls_score.reshape(batch_size, -1, 2)
|
||||
# We set FG labels to [0, num_class-1] and BG label to
|
||||
# num_class in RPN head since mmdet v2.5, which is unified to
|
||||
# be consistent with other head since mmdet v2.0. In mmdet v2.0
|
||||
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
|
||||
scores = cls_score.softmax(-1)[..., 0]
|
||||
scores = scores.reshape(batch_size, -1, 1)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)
|
||||
|
||||
# use static anchor if input shape is static
|
||||
if not is_dynamic_flag:
|
||||
anchors = anchors.data
|
||||
|
||||
anchors = anchors.unsqueeze(0)
|
||||
|
||||
# topk in tensorrt does not support shape<k
|
||||
# concate zero to enable topk,
|
||||
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
|
||||
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
|
||||
anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)
|
||||
|
||||
if pre_topk > 0:
|
||||
_, topk_inds = scores.squeeze(2).topk(pre_topk)
|
||||
batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1)
|
||||
prior_inds = topk_inds.new_zeros((1, 1))
|
||||
anchors = anchors[prior_inds, topk_inds, :]
|
||||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(scores)
|
||||
mlvl_valid_anchors.append(anchors)
|
||||
|
||||
batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
|
||||
batch_mlvl_bboxes = self.bbox_coder.decode(
|
||||
batch_mlvl_anchors,
|
||||
batch_mlvl_bboxes,
|
||||
max_shape=img_metas[0]['img_shape'])
|
||||
# ignore background class
|
||||
if not self.use_sigmoid_cls:
|
||||
batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
|
||||
if not with_nms:
|
||||
return batch_mlvl_bboxes, batch_mlvl_scores
|
||||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
# only one class in rpn
|
||||
max_output_boxes_per_class = keep_top_k
|
||||
return fake_multiclass_nms_rotated(
|
||||
batch_mlvl_bboxes,
|
||||
batch_mlvl_scores,
|
||||
max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
keep_top_k=keep_top_k,
|
||||
version=self.version)
|
|
@ -3,8 +3,7 @@ import torch
|
|||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
pad_with_value_if_necessary)
|
||||
from mmdeploy.codebase.mmrotate.core.post_processing import \
|
||||
fake_multiclass_nms_rotated
|
||||
from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
@ -89,7 +88,7 @@ def rotated_rpn_head__get_bboxes(ctx,
|
|||
# to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
|
||||
scores = cls_score.softmax(-1)[..., 0]
|
||||
scores = scores.reshape(batch_size, -1, 1)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 6)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
|
||||
# use static anchor if input shape is static
|
||||
if not is_dynamic_flag:
|
||||
|
@ -129,13 +128,16 @@ def rotated_rpn_head__get_bboxes(ctx,
|
|||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
# only one class in rpn
|
||||
max_output_boxes_per_class = keep_top_k
|
||||
return fake_multiclass_nms_rotated(
|
||||
return multiclass_nms(
|
||||
batch_mlvl_bboxes,
|
||||
batch_mlvl_scores,
|
||||
max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
keep_top_k=keep_top_k,
|
||||
version=self.version)
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .gv_bbox_head import gv_bbox_head__get_bboxes
|
||||
from .gv_ratio_roi_head import gv_ratio_roi_head__simple_test_bboxes
|
||||
from .oriented_standard_roi_head import \
|
||||
oriented_standard_roi_head__simple_test_bboxes
|
||||
from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt
|
||||
from .rotated_bbox_head import rotated_bbox_head__get_bboxes
|
||||
|
||||
__all__ = [
|
||||
'gv_bbox_head__get_bboxes', 'gv_ratio_roi_head__simple_test_bboxes',
|
||||
'oriented_standard_roi_head__simple_test_bboxes',
|
||||
'rotated_single_roi_extractor__forward__tensorrt',
|
||||
'rotated_bbox_head__get_bboxes'
|
||||
]
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.codebase.mmdet import get_post_processing_params
|
||||
from mmdeploy.codebase.mmrotate.core.post_processing import \
|
||||
multiclass_nms_rotated
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.get_bboxes')
|
||||
def gv_bbox_head__get_bboxes(ctx,
|
||||
self,
|
||||
rois,
|
||||
cls_score,
|
||||
bbox_pred,
|
||||
fix_pred,
|
||||
ratio_pred,
|
||||
img_shape,
|
||||
scale_factor,
|
||||
rescale=False,
|
||||
cfg=None):
|
||||
"""Transform network output for a batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
rois (torch.Tensor): Boxes to be transformed. Has shape
|
||||
(num_boxes, 6). last dimension 5 arrange as
|
||||
(batch_index, x, y, w, h, theta).
|
||||
cls_score (torch.Tensor): Box scores, has shape
|
||||
(num_boxes, num_classes + 1).
|
||||
bbox_pred (Tensor, optional): Box energies / deltas.
|
||||
has shape (num_boxes, num_classes * 6).
|
||||
img_shape (Sequence[int], optional): Maximum bounds for boxes,
|
||||
specifies (H, W, C) or (H, W).
|
||||
scale_factor (ndarray): Scale factor of the
|
||||
image arrange as (w_scale, h_scale, w_scale, h_scale).
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Default: False.
|
||||
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]:
|
||||
First tensor is `det_bboxes`, has the shape
|
||||
(num_boxes, 6) and last
|
||||
dimension 6 represent (cx, cy, w, h, theta, score).
|
||||
Second tensor is the labels with shape (num_boxes, ).
|
||||
"""
|
||||
assert rois.ndim == 3, 'Only support export two stage ' \
|
||||
'model to ONNX ' \
|
||||
'with batch dimension. '
|
||||
|
||||
if self.custom_cls_channels:
|
||||
scores = self.loss_cls.get_activation(cls_score)
|
||||
else:
|
||||
scores = F.softmax(
|
||||
cls_score, dim=-1) if cls_score is not None else None
|
||||
|
||||
assert bbox_pred is not None
|
||||
bboxes = self.bbox_coder.decode(
|
||||
rois[..., 1:], bbox_pred, max_shape=img_shape)
|
||||
|
||||
rbboxes = self.fix_coder.decode(bboxes, fix_pred)
|
||||
|
||||
bboxes = bboxes.view(*ratio_pred.size(), 4)
|
||||
rbboxes = rbboxes.view(*ratio_pred.size(), 5)
|
||||
|
||||
from mmrotate.core import hbb2obb
|
||||
rbboxes = rbboxes.where(
|
||||
ratio_pred.unsqueeze(-1) < self.ratio_thr,
|
||||
hbb2obb(bboxes, self.version))
|
||||
rbboxes = rbboxes.squeeze(2)
|
||||
|
||||
# ignore background class
|
||||
scores = scores[..., :self.num_classes]
|
||||
|
||||
post_params = get_post_processing_params(ctx.cfg)
|
||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
return multiclass_nms_rotated(
|
||||
rbboxes,
|
||||
scores,
|
||||
max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.roi_heads.gv_ratio_roi_head'
|
||||
'.GVRatioRoIHead.simple_test_bboxes')
|
||||
def gv_ratio_roi_head__simple_test_bboxes(ctx,
|
||||
self,
|
||||
x,
|
||||
img_metas,
|
||||
proposals,
|
||||
rcnn_test_cfg,
|
||||
rescale=False):
|
||||
"""Test only det bboxes without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Feature maps of all scale level.
|
||||
img_metas (list[dict]): Image meta info.
|
||||
proposals (List[Tensor]): Region proposals.
|
||||
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
tuple[list[Tensor], list[Tensor]]: The first list contains \
|
||||
the boxes of the corresponding image in a batch, each \
|
||||
tensor has the shape (num_boxes, 6) and last dimension \
|
||||
6 represent (x, y, w, h, theta, score). Each Tensor \
|
||||
in the second list is the labels with shape (num_boxes, ). \
|
||||
The length of both lists should be equal to batch_size.
|
||||
"""
|
||||
|
||||
rois, labels = proposals
|
||||
batch_index = torch.arange(
|
||||
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
|
||||
rois.size(0), rois.size(1), 1)
|
||||
rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
|
||||
batch_size = rois.shape[0]
|
||||
num_proposals_per_img = rois.shape[1]
|
||||
|
||||
# Eliminate the batch dimension
|
||||
rois = rois.view(-1, 5)
|
||||
bbox_results = self._bbox_forward(x, rois)
|
||||
cls_score = bbox_results['cls_score']
|
||||
bbox_pred = bbox_results['bbox_pred']
|
||||
fix_pred = bbox_results['fix_pred']
|
||||
ratio_pred = bbox_results['ratio_pred']
|
||||
|
||||
# Recover the batch dimension
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
|
||||
cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
|
||||
cls_score.size(-1))
|
||||
|
||||
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
|
||||
bbox_pred.size(-1))
|
||||
fix_pred = fix_pred.reshape(batch_size, num_proposals_per_img,
|
||||
fix_pred.size(-1))
|
||||
ratio_pred = ratio_pred.reshape(batch_size, num_proposals_per_img,
|
||||
ratio_pred.size(-1))
|
||||
det_bboxes, det_labels = self.bbox_head.get_bboxes(
|
||||
rois,
|
||||
cls_score,
|
||||
bbox_pred,
|
||||
fix_pred,
|
||||
ratio_pred,
|
||||
img_metas[0]['img_shape'],
|
||||
None,
|
||||
rescale=rescale,
|
||||
cfg=self.test_cfg)
|
||||
return det_bboxes, det_labels
|
|
@ -5,10 +5,10 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.roi_heads.oriented_standard_roi_head'
|
||||
'.OrientedStandardRoIHead.simple_test')
|
||||
def oriented_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
'mmrotate.models.roi_heads.rotate_standard_roi_head'
|
||||
'.RotatedStandardRoIHead.simple_test')
|
||||
def rotate_standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` of `StandardRoIHead` for default backend.
|
||||
|
||||
This function returns detection result as Tensor instead of numpy
|
|
@ -48,3 +48,20 @@ models:
|
|||
- *pipeline_ort_detection_dynamic_fp32
|
||||
- *pipeline_trt_detection_dynamic_fp32
|
||||
- *pipeline_trt_detection_dynamic_fp16
|
||||
|
||||
- name: oriented_rcnn
|
||||
metafile: configs/oriented_rcnn/metafile.yml
|
||||
model_configs:
|
||||
- configs/oriented_rcnn/oriented_rcnn_r50_fpn_fp16_1x_dota_le90.py
|
||||
pipelines:
|
||||
- *pipeline_ort_detection_dynamic_fp32
|
||||
- *pipeline_trt_detection_dynamic_fp32
|
||||
- *pipeline_trt_detection_dynamic_fp16
|
||||
|
||||
- name: gliding_vertex
|
||||
metafile: configs/gliding_vertex/metafile.yml
|
||||
model_configs:
|
||||
- configs/gliding_vertex/gliding_vertex_r50_fpn_1x_dota_le90.py
|
||||
pipelines:
|
||||
- *pipeline_trt_detection_dynamic_fp32
|
||||
- *pipeline_trt_detection_dynamic_fp16
|
||||
|
|
|
@ -6,8 +6,9 @@ import torch
|
|||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend,
|
||||
get_onnx_model, get_rewrite_outputs)
|
||||
from mmdeploy.utils.test import (WrapFunction, WrapModel, backend_checker,
|
||||
check_backend, get_onnx_model,
|
||||
get_rewrite_outputs)
|
||||
|
||||
try:
|
||||
import_codebase(Codebase.MMROTATE)
|
||||
|
@ -309,3 +310,32 @@ def test_poly2obb_le90(backend_type: Backend):
|
|||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_gvfixcoder__decode(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(output_names=['output'], input_shape=None),
|
||||
backend_config=dict(type=backend_type.value),
|
||||
codebase_config=dict(type='mmrotate', task='RotatedDetection')))
|
||||
|
||||
from mmrotate.core.bbox import GVFixCoder
|
||||
coder = GVFixCoder(angle_range='le90')
|
||||
|
||||
hbboxes = torch.rand(1, 10, 4)
|
||||
fix_deltas = torch.rand(1, 10, 4)
|
||||
|
||||
wrapped_model = WrapModel(coder, 'decode')
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model,
|
||||
model_inputs={
|
||||
'hbboxes': hbboxes,
|
||||
'fix_deltas': fix_deltas
|
||||
},
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None
|
||||
|
|
|
@ -332,3 +332,223 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend):
|
|||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
def get_rotated_rpn_head_model():
|
||||
"""Oriented RPN Head Config."""
|
||||
test_cfg = mmcv.Config(
|
||||
dict(
|
||||
nms_pre=2000,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(iou_thr=0.1),
|
||||
max_per_img=2000))
|
||||
from mmrotate.models.dense_heads import RotatedRPNHead
|
||||
model = RotatedRPNHead(
|
||||
version='le90',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
anchor_generator=dict(
|
||||
type='AnchorGenerator',
|
||||
scales=[8],
|
||||
ratios=[0.5, 1.0, 2.0],
|
||||
strides=[4, 8, 16, 32, 64]),
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[0.0, 0.0, 0.0, 0.0],
|
||||
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
||||
test_cfg=test_cfg)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
head = get_rotated_rpn_head_model()
|
||||
head.cpu().eval()
|
||||
s = 128
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
output_names = ['dets', 'labels']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmrotate',
|
||||
task='RotatedDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.1,
|
||||
pre_top_k=2000,
|
||||
keep_top_k=2000))))
|
||||
|
||||
# the cls_score's size: (1, 3, 32, 32), (1, 3, 16, 16),
|
||||
# (1, 3, 8, 8), (1, 3, 4, 4), (1, 3, 2, 2).
|
||||
# the bboxes's size: (1, 18, 32, 32), (1, 18, 16, 16),
|
||||
# (1, 18, 8, 8), (1, 18, 4, 4), (1, 18, 2, 2)
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [torch.rand(1, 18, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
||||
wrapped_model = WrapModel(
|
||||
head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_rotate_standard_roi_head__simple_test(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
from mmrotate.models.roi_heads import OrientedStandardRoIHead
|
||||
output_names = ['dets', 'labels']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmrotate',
|
||||
task='RotatedDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.1,
|
||||
pre_top_k=2000,
|
||||
keep_top_k=2000))))
|
||||
angle_version = 'le90'
|
||||
test_cfg = mmcv.Config(
|
||||
dict(
|
||||
nms_pre=2000,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(iou_thr=0.1),
|
||||
max_per_img=2000))
|
||||
head = OrientedStandardRoIHead(
|
||||
bbox_roi_extractor=dict(
|
||||
type='RotatedSingleRoIExtractor',
|
||||
roi_layer=dict(
|
||||
type='RoIAlignRotated',
|
||||
out_size=7,
|
||||
sample_num=2,
|
||||
clockwise=True),
|
||||
out_channels=3,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
bbox_head=dict(
|
||||
type='RotatedShared2FCBBoxHead',
|
||||
in_channels=3,
|
||||
fc_out_channels=1024,
|
||||
roi_feat_size=7,
|
||||
num_classes=15,
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHAOBBoxCoder',
|
||||
angle_range=angle_version,
|
||||
norm_factor=None,
|
||||
edge_swap=True,
|
||||
proj_xy=True,
|
||||
target_means=(.0, .0, .0, .0, .0),
|
||||
target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)),
|
||||
reg_class_agnostic=True),
|
||||
test_cfg=test_cfg)
|
||||
head.cpu().eval()
|
||||
|
||||
seed_everything(1234)
|
||||
x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
|
||||
proposals = [torch.rand(1, 100, 6), torch.randint(0, 10, (1, 100))]
|
||||
img_metas = [{'img_shape': torch.tensor([224, 224])}]
|
||||
|
||||
wrapped_model = WrapModel(
|
||||
head, 'simple_test', proposals=proposals, img_metas=img_metas)
|
||||
rewrite_inputs = {'x': x}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_gv_ratio_roi_head__simple_test(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
from mmrotate.models.roi_heads import GVRatioRoIHead
|
||||
output_names = ['dets', 'labels']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmrotate',
|
||||
task='RotatedDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.1,
|
||||
pre_top_k=2000,
|
||||
keep_top_k=2000,
|
||||
max_output_boxes_per_class=1000))))
|
||||
angle_version = 'le90'
|
||||
test_cfg = mmcv.Config(
|
||||
dict(
|
||||
nms_pre=2000,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(iou_thr=0.1),
|
||||
max_per_img=2000))
|
||||
head = GVRatioRoIHead(
|
||||
version=angle_version,
|
||||
bbox_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
||||
out_channels=3,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
bbox_head=dict(
|
||||
type='GVBBoxHead',
|
||||
version=angle_version,
|
||||
num_shared_fcs=2,
|
||||
in_channels=3,
|
||||
fc_out_channels=1024,
|
||||
roi_feat_size=7,
|
||||
num_classes=15,
|
||||
ratio_thr=0.8,
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=(.0, .0, .0, .0),
|
||||
target_stds=(0.1, 0.1, 0.2, 0.2)),
|
||||
fix_coder=dict(type='GVFixCoder', angle_range=angle_version),
|
||||
ratio_coder=dict(type='GVRatioCoder', angle_range=angle_version),
|
||||
reg_class_agnostic=True),
|
||||
test_cfg=test_cfg)
|
||||
head.cpu().eval()
|
||||
|
||||
seed_everything(1234)
|
||||
x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
|
||||
bboxes = torch.rand(1, 100, 2)
|
||||
bboxes = torch.cat(
|
||||
[bboxes, bboxes + torch.rand(1, 100, 2) + torch.rand(1, 100, 1)],
|
||||
dim=-1)
|
||||
proposals = [bboxes, torch.randint(0, 10, (1, 100))]
|
||||
img_metas = [{'img_shape': torch.tensor([224, 224])}]
|
||||
|
||||
wrapped_model = WrapModel(
|
||||
head, 'simple_test', proposals=proposals, img_metas=img_metas)
|
||||
rewrite_inputs = {'x': x}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert rewrite_outputs is not None
|
||||
|
|
Loading…
Reference in New Issue