[Docstring]: Update docstring and names of rewrite in mmdet (#194)

* refactor mmdet docstring

* fix yapf

* fix lint

* fix docformatter

* fix docformatter --wrap-descriptions 79

* reply comments

* fix runningleon comments;

* fix misleading
hanrui1sensetime 2021-11-15 10:41:27 +08:00 committed by GitHub
parent fa626a58a0
commit 6f98c423a1
25 changed files with 757 additions and 207 deletions

View File

@ -563,7 +563,7 @@ class PartitionTwoStageDetector(DeployBaseDetector):
from mmdet.models.builder import build_head, build_roi_extractor
from mmdeploy.mmdet.models.roi_heads.bbox_heads import \
get_bboxes_of_bbox_head
bbox_head__get_bboxes
# load cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
@ -580,7 +580,7 @@ class PartitionTwoStageDetector(DeployBaseDetector):
ctx = Context()
ctx.cfg = self.deploy_cfg
self.get_bboxes_of_bbox_head = partial(get_bboxes_of_bbox_head, ctx)
self.bbox_head__get_bboxes = partial(bbox_head__get_bboxes, ctx)
def partition0_postprocess(self, x: Sequence[torch.Tensor],
scores: torch.Tensor, bboxes: torch.Tensor):
@ -656,10 +656,10 @@ class PartitionTwoStageDetector(DeployBaseDetector):
bbox_pred.size(-1))
rcnn_test_cfg = self.model_cfg.model.test_cfg.rcnn
return self.get_bboxes_of_bbox_head(self.bbox_head, rois, cls_score,
bbox_pred,
img_metas[0][0]['img_shape'],
rcnn_test_cfg)
return self.bbox_head__get_bboxes(self.bbox_head, rois, cls_score,
bbox_pred,
img_metas[0][0]['img_shape'],
rcnn_test_cfg)
class ONNXRuntimePTSDetector(PartitionTwoStageDetector):
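The hunk above binds a rewriter with functools.partial so it can later be called like a bound method; below is a minimal, self-contained sketch of that pattern, where the Context class and the head arguments are simplified stand-ins rather than the real mmdeploy objects.

from functools import partial


class Context:
    """Simplified stand-in for mmdeploy's rewriter context (assumption)."""

    cfg = None


def bbox_head__get_bboxes(ctx, bbox_head, rois, cls_score, bbox_pred,
                          img_shape, cfg):
    # A real rewrite reads ctx.cfg to choose backend-specific behaviour;
    # this body only shows the context arriving as the first argument.
    print('deploy cfg:', ctx.cfg)
    return rois, cls_score


ctx = Context()
ctx.cfg = dict(backend='onnxruntime')
get_bboxes = partial(bbox_head__get_bboxes, ctx)
# get_bboxes(...) can now be called with the same arguments as the original
# bound method, which is how self.bbox_head__get_bboxes is used above.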

View File

@ -17,7 +17,40 @@ def delta2bbox(ctx,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Rewrite for ONNX exporting of default backend."""
"""Rewrite `delta2bbox` for default backend.
Because the clip op needs dynamic min and max values, this function uses
the clip_bboxes function to support dynamic shape.
Args:
ctx (ContextCaller): The context with additional information.
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
deltas (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors. Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]], optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
wh_ratio_clip (float): Maximum aspect ratio for boxes.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
add_ctr_clamp (bool): Whether to add a center clamp. When added, the
predicted box is clamped if its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
ctr_clamp (int): The maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Returns:
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
br_x, br_y.
"""
means = deltas.new_tensor(means).view(1,
-1).repeat(1,
deltas.size(-1) // 4)
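Where the docstring above mentions clip_bboxes, the point is that dynamic clipping bounds must stay tensor-valued; here is a hedged sketch of that idea, not mmdeploy's actual clip_bboxes implementation.

import torch


def clip_bboxes_sketch(x1, y1, x2, y2, max_shape):
    """Clip boxes to the image, where max_shape may be a runtime Tensor."""
    if isinstance(max_shape, torch.Tensor):
        h, w = max_shape[0], max_shape[1]  # bounds stay inside the graph
    else:
        h, w = float(max_shape[0]), float(max_shape[1])  # plain Python scalars
    x1 = torch.clamp(x1, min=0)
    y1 = torch.clamp(y1, min=0)
    x2 = torch.minimum(x2, (w - 1) * torch.ones_like(x2))
    y2 = torch.minimum(y2, (h - 1) * torch.ones_like(y2))
    return x1, y1, x2, y2


boxes = torch.rand(1, 8, 4) * 1000
x1, y1, x2, y2 = clip_bboxes_sketch(*boxes.unbind(-1),
                                    max_shape=torch.tensor([800, 600]))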
@ -82,7 +115,43 @@ def delta2bbox_ncnn(ctx,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Rewrite for ONNX exporting of NCNN backend."""
"""Rewrite `delta2bbox` for ncnn backend.
Batch dimension is not supported by ncnn, but it is supported by pytorch.
ncnn stores the lowest two dimensions as a contiguous, byte-aligned block
of memory, so the lowest two dimensions are not fully independent.
A Reshape operator with a -1 argument should operate on an ncnn::Mat with
dimension >= 3.
Args:
ctx (ContextCaller): The context with additional information.
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
deltas (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors. Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]], optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
wh_ratio_clip (float): Maximum aspect ratio for boxes.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
add_ctr_clamp (bool): Whether to add a center clamp. When added, the
predicted box is clamped if its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
ctr_clamp (int): The maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Returns:
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
br_x, br_y.
"""
means = deltas.new_tensor(means).view(1, 1,
-1).repeat(1, deltas.size(-2),
deltas.size(-1) // 4).data
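The ncnn caveat above about Reshape with a -1 argument is easiest to see on shapes; the small illustration below is an assumption about the intent, not the rewritten code itself: tensors are kept at three or more dimensions so the exported Reshape never collapses to a 2-D view with a wildcard.

import torch

deltas = torch.rand(1, 100, 8)                   # (batch=1, N, num_classes * 4)
# A 2-D view such as deltas.view(-1, 4) would hand ncnn a wildcard reshape on
# a 2-D Mat; keeping the view 3-D or higher avoids that:
denorm = deltas.view(1, deltas.size(-2), -1, 4)  # (1, N, num_classes, 4)
print(denorm.shape)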

View File

@ -13,7 +13,34 @@ def tblr2bboxes(ctx,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite for ONNX exporting of default backend."""
"""Rewrite `tblr2bboxes` for default backend.
Because the clip op needs dynamic min and max values, this function uses
the clip_bboxes function to support dynamic shape.
Args:
ctx (ContextCaller): The context with additional information.
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
Shape: (N,4) or (B, N, 4).
tblr (Tensor): Coords of network output in tblr form
Shape: (N, 4) or (B, N, 4).
normalizer (Sequence[float] | float): Normalization parameter of
encoded boxes. By list, it represents the normalization factors at
tblr dims. By float, it is the unified normalization factor at all
dims. Default: 4.0
normalize_by_wh (bool): Whether the tblr coordinates have been
normalized by the side length (wh) of prior bboxes.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]], optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
Returns:
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
"""
if not isinstance(normalizer, float):
normalizer = torch.tensor(normalizer, device=priors.device)
assert len(normalizer) == 4, 'Normalizer must have length = 4'
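For readers new to the tblr format, a minimal decoding sketch follows; it is simplified (no wh normalization and no clipping, unlike the real tblr2bboxes).

import torch


def tblr_decode_sketch(priors, tblr, normalizer=4.0):
    """Decode top/bottom/left/right distances from the prior centers."""
    cx = (priors[..., 0] + priors[..., 2]) * 0.5
    cy = (priors[..., 1] + priors[..., 3]) * 0.5
    top, bottom, left, right = (tblr * normalizer).unbind(-1)
    # (x0, y0, x1, y1) measured outward from the prior center
    return torch.stack([cx - left, cy - top, cx + right, cy + bottom], dim=-1)


priors = torch.tensor([[10., 10., 30., 30.]])
tblr = torch.tensor([[1., 1., 1., 1.]])
print(tblr_decode_sketch(priors, tblr))  # an 8 x 8 box centred on (20, 20)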
@ -55,7 +82,35 @@ def tblr2bboxes_ncnn(ctx,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite for ONNX exporting of NCNN backend."""
"""Rewrite `tblr2bboxes` for ncnn backend.
Batch dimension is not supported by ncnn, but it is supported by pytorch.
Negative axis values in torch.cat are rewritten as the corresponding
positive values to avoid axis shift.
Args:
ctx (ContextCaller): The context with additional information.
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
Shape: (N,4) or (B, N, 4).
tblr (Tensor): Coords of network output in tblr form
Shape: (N, 4) or (B, N, 4).
normalizer (Sequence[float] | float): Normalization parameter of
encoded boxes. By list, it represents the normalization factors at
tblr dims. By float, it is the unified normalization factor at all
dims. Default: 4.0
normalize_by_wh (bool): Whether the tblr coordinates have been
normalized by the side length (wh) of prior bboxes.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]], optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
Returns:
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
"""
assert priors.size(0) == tblr.size(0)
if priors.ndim == 3:
assert priors.size(1) == tblr.size(1)

View File

@ -1,14 +1,14 @@
from .anchor_head import get_bboxes_of_anchor_head
from .atss_head import get_bboxes_of_atss_head
from .fcos_head import get_bboxes_of_fcos_head
from .fovea_head import get_bboxes_of_fovea_head
from .rpn_head import get_bboxes_of_rpn_head
from .anchor_head import anchor_head__get_bboxes
from .atss_head import atss_head__get_bboxes
from .fcos_head import fcos_head__get_bboxes
from .fovea_head import fovea_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
from .yolox_head import get_bboxes_of_yolox_head
from .yolox_head import yolox_head__get_bboxes
__all__ = [
'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head',
'get_bboxes_of_atss_head', 'yolov3_head__get_bboxes',
'yolov3_head__get_bboxes__ncnn', 'get_bboxes_of_yolox_head'
'anchor_head__get_bboxes', 'atss_head__get_bboxes',
'fcos_head__get_bboxes', 'fovea_head__get_bboxes', 'rpn_head__get_bboxes',
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
'yolox_head__get_bboxes'
]
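Every rename above follows the same convention: <class_in_snake_case>__<method>, plus a __<backend> suffix for backend-specific variants. A sketch of how such a pair is registered follows; the decorators and signatures mirror the diff, while the bodies are placeholders rather than the actual rewrites.

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.AnchorHead.get_bboxes')
def anchor_head__get_bboxes(ctx, self, cls_scores, bbox_preds, img_metas,
                            with_nms=True, cfg=None, **kwargs):
    """Default-backend rewrite (placeholder body)."""
    ...


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.AnchorHead.get_bboxes', backend='ncnn')
def anchor_head__get_bboxes__ncnn(ctx, self, cls_scores, bbox_preds,
                                  img_metas, with_nms=True, cfg=None,
                                  **kwargs):
    """ncnn-specific rewrite (placeholder body)."""
    ...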

View File

@ -9,15 +9,45 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.AnchorHead.get_bboxes')
def get_bboxes_of_anchor_head(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
def anchor_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of AnchorHead for default backend.
Rewrite this function to support deployment of default backend and
dynamic shape export.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -117,15 +147,45 @@ def get_bboxes_of_anchor_head(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.AnchorHead.get_bboxes', backend='ncnn')
def get_bboxes_of_anchor_head_ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
def anchor_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
Shape node and batch inference are not supported by ncnn. This function
transforms dynamic shape to constant shape and removes batch inference.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -6,18 +6,19 @@ from mmdeploy.utils import get_mmdet_params
@FUNCTION_REWRITER.register_rewriter('mmdet.models.ATSSHead.get_bboxes')
def get_bboxes_of_atss_head(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
def atss_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
"""Rewrite `get_bboxes` from ATSSHead for default backend.
Transform network output for a batch into bbox predictions.
Rewrite this function to support model deployment; it transforms the
network output for a batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
@ -39,12 +40,9 @@ def get_bboxes_of_atss_head(ctx,
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: The first item is an (n, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score) and
the score between 0 and 1.
The shape of the second tensor in the tuple is (n,), and
each element represents the class label of the corresponding
box.
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness

View File

@ -9,16 +9,48 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.FCOSHead.get_bboxes')
def get_bboxes_of_fcos_head(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
def fcos_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of FCOSHead for default backend.
Rewrite this function to support deployment of default backend
and dynamic shape export. Transform network output for a batch into
bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
self (FCOSHead): The instance of the class FCOSHead.
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for each scale level with
shape (N, num_anchors * 1, H, W).
img_metas (dict): Meta information of the image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -111,16 +143,49 @@ def get_bboxes_of_fcos_head(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.FCOSHead.get_bboxes', backend='ncnn')
def get_bboxes_of_fcos_head_ncnn(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
def fcos_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of FCOSHead for ncnn backend.
1. Shape node and batch inference are not supported by ncnn. This function
transforms dynamic shape to constant shape and removes batch inference.
2. Broadcasting between 2-dimensional tensors in the `BinaryOps` operator is
not supported by ncnn. This function unsqueezes 2-dimensional tensors to
3-dimensional tensors so that `BinaryOps` is computed correctly by ncnn.
Args:
ctx (ContextCaller): The context with additional information.
self (FCOSHead): The instance of the class FCOSHead.
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for each scale level with
shape (N, num_anchors * 1, H, W).
img_metas (dict): Meta information of the image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)
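Point 2 of the docstring above, the BinaryOps broadcast limitation, in a tiny hedged example (generic PyTorch, not the rewritten FCOS code):

import torch

scores = torch.rand(1000, 80)      # 2-D tensor, e.g. per-anchor class scores
centerness = torch.rand(1000, 1)   # 2-D tensor to be broadcast against it
# scores * centerness relies on 2-D broadcasting, which the exported ncnn
# BinaryOps layer may not handle; lifting both operands to 3-D keeps the
# broadcast in a shape ncnn accepts.
weighted = scores.unsqueeze(0) * centerness.unsqueeze(0)  # (1, 1000, 80)
print(weighted.shape)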

View File

@ -6,16 +6,17 @@ from mmdeploy.utils import get_mmdet_params
@FUNCTION_REWRITER.register_rewriter('mmdet.models.FoveaHead.get_bboxes')
def get_bboxes_of_fovea_head(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=None):
def fovea_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=None):
"""Rewrite `get_bboxes` from FoveaHead for default backend.
Transform network output for a batch into bbox predictions.
Rewrite this function to support model deployment; it transforms the
network output for a batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
@ -32,12 +33,9 @@ def get_bboxes_of_fovea_head(ctx,
Default: False.
Returns:
tuple[Tensor, Tensor]: The first item is an (n, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score) and
the score between 0 and 1.
The shape of the second tensor in the tuple is (n,), and
each element represents the class label of the corresponding
box.
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
"""
assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg

View File

@ -8,15 +8,41 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
@FUNCTION_REWRITER.register_rewriter('mmdet.models.RPNHead.get_bboxes')
def get_bboxes_of_rpn_head(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
def rpn_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of RPNHead for default backend.
Rewrite this function to support model deployment; it transforms the
network output for a batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
self (RPNHead): The instance of the class RPNHead.
cls_scores (list[Tensor]): Box scores for each scale level
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
img_metas (dict): Meta information of the image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -113,15 +139,44 @@ def get_bboxes_of_rpn_head(ctx,
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.RPNHead.get_bboxes', backend='ncnn')
def get_bboxes_of_rpn_head_ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
def rpn_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of RPNHead for ncnn backend.
Shape node and batch inference are not supported by ncnn. This function
transforms dynamic shape to constant shape and removes batch inference.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -15,25 +15,27 @@ def yolov3_head__get_bboxes(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend.
"""Rewrite `get_bboxes` of YOLOV3Head for default backend.
Transform network output for a batch into bbox predictions.
Rewrite this function to support model deployment; it transforms the
network output for a batch into bbox predictions.
Args:
ctx: Context that contains original meta information.
self: Represent the instance of the original class.
pred_maps (list[Tensor]): Raw predictions for a batch of images.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
size and the score between 0 and 1. The shape of the second
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
is_dynamic_flag = is_dynamic_shape(ctx.cfg)
num_levels = len(pred_maps)
@ -161,25 +163,34 @@ def yolov3_head__get_bboxes__ncnn(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for ncnn backend.
"""Rewrite `get_bboxes` of YOLOV3Head for ncnn backend.
1. Shape node and batch inference are not supported by ncnn. This function
transforms dynamic shape to constant shape and removes batch inference.
2. Batch dimension is not supported by ncnn, but it is supported by pytorch.
Negative axis values in torch.cat are rewritten as the corresponding
positive values to avoid axis shift.
3. Broadcasting between 2-dimensional tensors in the `BinaryOps` operator is
not supported by ncnn. This function unsqueezes 2-dimensional tensors to
3-dimensional tensors so that `BinaryOps` is computed correctly by ncnn.
Transform network output for a batch into bbox predictions.
Args:
ctx: Context that contains original meta information.
self: Represent the instance of the original class.
pred_maps (list[Tensor]): Raw predictions for a batch of images.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
size and the score between 0 and 1. The shape of the second
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
If with_nms == True:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
num_levels = len(pred_maps)
pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]

View File

@ -7,18 +7,19 @@ from mmdeploy.utils import get_mmdet_params
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.YOLOXHead.get_bboxes')
def get_bboxes_of_yolox_head(ctx,
self,
cls_scores,
bbox_preds,
objectnesses,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True):
"""Rewrite `get_bboxes` for default backend.
def yolox_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
objectnesses,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True):
"""Rewrite `get_bboxes` of YOLOXHead for default backend.
Transform network outputs of a batch into bbox results.
Rewrite this function to support model deployment; it transforms the
network output for a batch into bbox predictions.
Args:
ctx: Context that contains original meta information.
@ -39,6 +40,7 @@ def get_bboxes_of_yolox_head(ctx,
Default False.
with_nms (bool): If True, do nms before return boxes.
Default True.
Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch

View File

@ -1,9 +1,9 @@
from .base import forward_of_base_detector
from .rpn import simple_test_of_rpn
from .single_stage import simple_test_of_single_stage
from .two_stage import extract_feat_of_two_stage
from .base import base_detector__forward
from .rpn import rpn__simple_test
from .single_stage import single_stage__simple_test
from .two_stage import two_stage__extract_feat
__all__ = [
'simple_test_of_single_stage', 'extract_feat_of_two_stage',
'forward_of_base_detector', 'simple_test_of_rpn'
'single_stage__simple_test', 'two_stage__extract_feat',
'base_detector__forward', 'rpn__simple_test'
]

View File

@ -6,8 +6,13 @@ from mmdeploy.utils import is_dynamic_shape
@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite and adding mark for `forward`."""
def _base_detector__forward_impl(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite and adding mark for `forward`.
Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx.
"""
assert isinstance(img_metas, dict)
assert isinstance(img, torch.Tensor)
@ -23,8 +28,31 @@ def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.BaseDetector.forward')
def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite `forward` for default backend."""
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite `forward` of BaseDetector for default backend.
Rewrite this function to:
1. Create img_metas for exporting the model to ONNX.
2. Call `simple_test` directly to skip `aug_test`.
3. Remove `return_loss` because deployment has no need for training
functions.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the class BaseDetector.
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (Optional[list[dict]]): A list of image info dict where each
dict has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
if img_metas is None:
img_metas = {}
@ -36,5 +64,5 @@ def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
if 'return_loss' in kwargs:
kwargs.pop('return_loss')
return _forward_of_base_detector_impl(
return _base_detector__forward_impl(
ctx, self, img, img_metas=img_metas, **kwargs)
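A condensed, hypothetical sketch of what the rewritten forward reduces to for export; the real implementation above also adds marks and handles dynamic shapes, so this only shows the img_metas construction and the direct simple_test call.

import torch


def export_forward_sketch(detector, img):
    """Build minimal img_metas and call simple_test, skipping aug_test."""
    img_metas = [{'img_shape': img.shape[2:]}]  # (H, W) taken from the input
    return detector.simple_test(img, img_metas)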

View File

@ -2,7 +2,25 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend."""
def rpn__simple_test(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend.
Support the configured dynamic/static shape for model input and return
the detection result as a Tensor instead of a numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s).
img_metas (dict): Dict containing the image's meta information,
such as `img_shape`.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
The first item is ``bboxes`` with shape (n, 5),
where 5 represents (tl_x, tl_y, br_x, br_y, score).
The second item is ``labels`` with shape (n,).
"""
x = self.extract_feat(img)
return self.rpn_head.simple_test_rpn(x, img_metas)

View File

@ -3,7 +3,25 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.SingleStageDetector.simple_test')
def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend."""
def single_stage__simple_test(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend.
Support the configured dynamic/static shape for model input and return
the detection result as a Tensor instead of a numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s).
img_metas (dict): Dict containing the image's meta information,
such as `img_shape`.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
The first item is ``bboxes`` with shape (n, 5),
where 5 represents (tl_x, tl_y, br_x, br_y, score).
The second item is ``labels`` with shape (n,).
"""
feat = self.extract_feat(img)
return self.bbox_head.simple_test(feat, img_metas, **kwargs)

View File

@ -4,20 +4,53 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.TwoStageDetector.extract_feat')
@mark('extract_feat', inputs='img', outputs='feat')
def extract_feat_of_two_stage(ctx, self, img):
"""Rewrite `extract_feat` for default backend."""
def two_stage__extract_feat(ctx, self, img):
"""Rewrite `extract_feat` for default backend.
This function uses the specific `extract_feat` function for the two
stage detector after adding marks.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s).
Returns:
list[Tensor]: Each item with shape (N, C, H, W) corresponds to one
level of backbone and neck features.
"""
return ctx.origin_func(self, img)
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.TwoStageDetector.simple_test')
def simple_test_of_two_stage(ctx,
self,
img,
img_metas,
proposals=None,
**kwargs):
"""Rewrite `simple_test` for default backend."""
def two_stage__simple_test(ctx,
self,
img,
img_metas,
proposals=None,
**kwargs):
"""Rewrite `simple_test` for default backend.
Support the configured dynamic/static shape for model input and return
the detection result as a Tensor instead of a numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s).
img_metas (dict): Dict containing the image's meta information,
such as `img_shape`.
proposals (List[Tensor]): Region proposals.
Default is None.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
The first item is ``bboxes`` with shape (n, 5),
where 5 represents (tl_x, tl_y, br_x, br_y, score).
The second item is ``labels`` with shape (n,).
"""
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:

View File

@ -1,3 +1,3 @@
from .bbox_head import get_bboxes_of_bbox_head
from .bbox_head import bbox_head__get_bboxes
__all__ = ['get_bboxes_of_bbox_head']
__all__ = ['bbox_head__get_bboxes']

View File

@ -14,16 +14,53 @@ from mmdeploy.utils import get_mmdet_params
'bbox_head_forward',
inputs=['bbox_feats'],
outputs=['cls_score', 'bbox_pred'])
def forward_of_bbox_head(ctx, self, x):
"""Rewrite `forward` for default backend."""
def bbox_head__forward(ctx, self, x):
"""Rewrite `forward` for default backend.
This function uses the specific `forward` function for the BBoxHead
or ConvFCBBoxHead after adding marks.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (Tensor): Input image tensor.
Returns:
tuple(Tensor, Tensor): The (cls_score, bbox_pred). The cls_score
has shape (N, num_det, num_cls) and the bbox_pred has shape
(N, num_det, 4).
"""
return ctx.origin_func(self, x)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
cfg, **kwargs):
"""Rewrite `get_bboxes` for default backend."""
def bbox_head__get_bboxes(ctx, self, rois, cls_score, bbox_pred, img_shape,
cfg, **kwargs):
"""Rewrite `get_bboxes` for default backend.
Transform network output for a batch into bbox predictions. Support
`reg_class_agnostic == False` case.
Args:
ctx (ContextCaller): The context with additional information.
self (BBoxHead): The instance of the class BBoxHead.
rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
The last dimension 5 is arranged as (batch_index, x1, y1, x2, y2).
cls_score (Tensor): Box scores, has shape
(num_boxes, num_classes + 1).
bbox_pred (Tensor, optional): Box energies / deltas.
has shape (num_boxes, num_classes * 4).
img_shape (Sequence[int], optional): Maximum bounds for boxes,
specifies (H, W, C) or (H, W).
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
Returns:
tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
"""
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '

View File

@ -5,9 +5,37 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.CascadeRoIHead.simple_test')
def simple_test_of_cascade_roi_head(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` for default backend."""
def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` for default backend.
1. This function eliminates the batch dimension to get the forward bbox
results, and recovers the batch dimension to calculate the final result
for deployment.
2. This function returns the detection result as a Tensor instead of a
numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (tuple[Tensor]): Features from upstream network. Each
has shape (batch_size, c, h, w).
proposals (list(Tensor)): Proposals from rpn head.
Each has shape (num_proposals, 5), where the last dimension
5 represents (x1, y1, x2, y2, score).
img_metas (list[dict]): Meta information of images.
Returns:
If self.with_mask == True:
tuple[Tensor, Tensor, Tensor]: (det_bboxes, det_labels,
segm_results), `det_bboxes` of shape [N, num_det, 5],
`det_labels` of shape [N, num_det], and `segm_results`
of shape [N, num_det, roi_H, roi_W].
Else:
tuple[Tensor, Tensor]: (det_bboxes, det_labels),
`det_bboxes` of shape [N, num_det, 5] and `det_labels`
of shape [N, num_det].
"""
assert self.with_bbox, 'Bbox head must be implemented.'
assert proposals.shape[0] == 1, 'Only support one input image ' \
'while in exporting to ONNX'
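The "eliminate the batch dimension, then recover it" wording above is easiest to see on tensor shapes; here is a toy-sized illustration (shapes only, no real heads).

import torch

proposals = torch.rand(1, 100, 5)             # (batch=1, num_proposals, 5)
rois = proposals.reshape(-1, 5)               # drop the batch dim for the heads
cls_score = torch.rand(rois.size(0), 81)      # stand-in per-roi head output
batched_score = cls_score.reshape(1, -1, 81)  # recover the batch dimension
print(rois.shape, batched_score.shape)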

View File

@ -1,3 +1,3 @@
from .fcn_mask_head import get_seg_masks_of_fcn_mask_head
from .fcn_mask_head import fcn_mask_head__get_seg_masks
__all__ = ['get_seg_masks_of_fcn_mask_head']
__all__ = ['fcn_mask_head__get_seg_masks']

View File

@ -7,11 +7,12 @@ from mmdeploy.utils import Backend, get_backend, get_mmdet_params
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.FCNMaskHead.get_seg_masks')
def get_seg_masks_of_fcn_mask_head(ctx, self, mask_pred, det_bboxes,
det_labels, rcnn_test_cfg, ori_shape,
**kwargs):
def fcn_mask_head__get_seg_masks(ctx, self, mask_pred, det_bboxes, det_labels,
rcnn_test_cfg, ori_shape, **kwargs):
"""Get segmentation masks from mask_pred and bboxes.
Rewrite get_seg_masks for fcn_mask_head inference only.
Args:
mask_pred (Tensor): shape (n, #class, h, w).
det_bboxes (Tensor): shape (n, 4/5)

View File

@ -1,10 +1,8 @@
from .single_level_roi_extractor import (
forward_of_single_roi_extractor_dynamic,
forward_of_single_roi_extractor_dynamic_openvino,
forward_of_single_roi_extractor_static)
single_roi_extractor__forward, single_roi_extractor__forward__openvino,
single_roi_extractor__forward__tensorrt)
__all__ = [
'forward_of_single_roi_extractor_dynamic',
'forward_of_single_roi_extractor_static',
'forward_of_single_roi_extractor_dynamic_openvino'
'single_roi_extractor__forward', 'single_roi_extractor__forward__openvino',
'single_roi_extractor__forward__tensorrt'
]

View File

@ -61,12 +61,15 @@ class MultiLevelRoiAlign(Function):
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend='tensorrt')
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def forward_of_single_roi_extractor_static(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for TensorRT backend."""
def single_roi_extractor__forward__tensorrt(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for TensorRT backend.
This function uses MMCVMultiLevelRoiAlign op for TensorRT deployment.
"""
featmap_strides = self.featmap_strides
finest_scale = self.finest_scale
@ -86,12 +89,16 @@ def forward_of_single_roi_extractor_static(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def forward_of_single_roi_extractor_dynamic(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for default backend."""
def single_roi_extractor__forward(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for default backend.
Add a mark for roi_extractor forward and remove unnecessary code from the
original forward function.
"""
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
@ -159,13 +166,16 @@ class SingleRoIExtractorOpenVINO(Function):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend='openvino')
def forward_of_single_roi_extractor_dynamic_openvino(ctx,
self,
feats,
rois,
roi_scale_factor=None):
def single_roi_extractor__forward__openvino(ctx,
self,
feats,
rois,
roi_scale_factor=None):
"""Replaces SingleRoIExtractor with SingleRoIExtractorOpenVINO when
exporting to OpenVINO."""
exporting to OpenVINO.
This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO.
"""
# Adding original output to SingleRoIExtractorOpenVINO.
state = torch._C._get_tracing_state()

View File

@ -3,9 +3,34 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` for default backend."""
def standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` for default backend.
This function returns the detection result as a Tensor instead of a
numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (tuple[Tensor]): Features from upstream network. Each
has shape (batch_size, c, h, w).
proposals (list(Tensor)): Proposals from rpn head.
Each has shape (num_proposals, 5), where the last dimension
5 represents (x1, y1, x2, y2, score).
img_metas (list[dict]): Meta information of images.
Returns:
If self.with_mask == True:
tuple[Tensor, Tensor, Tensor]: (det_bboxes, det_labels,
segm_results), `det_bboxes` of shape [N, num_det, 5],
`det_labels` of shape [N, num_det], and `segm_results`
of shape [N, num_det, roi_H, roi_W].
Else:
tuple[Tensor, Tensor]: (det_bboxes, det_labels),
`det_bboxes` of shape [N, num_det, 5] and `det_labels`
of shape [N, num_det].
"""
assert self.with_bbox, 'Bbox head must be implemented.'
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_metas, proposals, self.test_cfg, rescale=False)

View File

@ -6,9 +6,31 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.test_mixins.\
BBoxTestMixin.simple_test_bboxes')
def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
rcnn_test_cfg, **kwargs):
"""Rewrite `simple_test_bboxes` for default backend."""
def bbox_test_mixin__simple_test_bboxes(ctx, self, x, img_metas, proposals,
rcnn_test_cfg, **kwargs):
"""Rewrite `simple_test_bboxes` for default backend.
1. This function eliminates the batch dimension to get the forward bbox
results, and recovers the batch dimension to calculate the final result
for deployment.
2. This function returns the detection result as a Tensor instead of a
numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (tuple[Tensor]): Features from upstream network. Each
has shape (batch_size, c, h, w).
img_metas (list[dict]): Meta information of images.
proposals (list(Tensor)): Proposals from rpn head.
Each has shape (num_proposals, 5), where the last dimension
5 represents (x1, y1, x2, y2, score).
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
Returns:
tuple[Tensor, Tensor]: (det_bboxes, det_labels), `det_bboxes` of
shape [N, num_det, 5] and `det_labels` of shape [N, num_det].
"""
rois = proposals
batch_index = torch.arange(
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
@ -38,9 +60,28 @@ def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.test_mixins.\
MaskTestMixin.simple_test_mask')
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
det_labels, **kwargs):
"""Rewrite `simple_test_mask` for default backend."""
def mask_test_mixin__simple_test_mask(ctx, self, x, img_metas, det_bboxes,
det_labels, **kwargs):
"""Rewrite `simple_test_mask` for default backend.
This function returns the detection result as a Tensor instead of a
numpy array.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (tuple[Tensor]): Features from upstream network. Each
has shape (batch_size, c, h, w).
img_metas (list[dict]): Meta information of images.
det_bboxes (tuple[Tensor]): Detection bounding-boxes from features.
Each has shape of (batch_size, num_det, 5).
det_labels (tuple[Tensor]): Detection labels from features. Each
has shape of (batch_size, num_det).
Returns:
tuple[Tensor]: (segm_results), `segm_results` of shape
[N, num_det, roi_H, roi_W].
"""
if det_bboxes.shape[1] == 0:
bboxes_shape, labels_shape = list(det_bboxes.shape), list(
det_labels.shape)