mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Docstring]: Update docstring and names of rewrite in mmdet (#194)
* refactor mmdet docstring * fix yapf * fix lint * fix docformatter * fix docformatter --wrap-descriptions 79 * reply comments * fix runningleon comments; * fix misleading
This commit is contained in:
parent
fa626a58a0
commit
6f98c423a1
@ -563,7 +563,7 @@ class PartitionTwoStageDetector(DeployBaseDetector):
|
||||
from mmdet.models.builder import build_head, build_roi_extractor
|
||||
|
||||
from mmdeploy.mmdet.models.roi_heads.bbox_heads import \
|
||||
get_bboxes_of_bbox_head
|
||||
bbox_head__get_bboxes
|
||||
|
||||
# load cfg if necessary
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
|
||||
@ -580,7 +580,7 @@ class PartitionTwoStageDetector(DeployBaseDetector):
|
||||
|
||||
ctx = Context()
|
||||
ctx.cfg = self.deploy_cfg
|
||||
self.get_bboxes_of_bbox_head = partial(get_bboxes_of_bbox_head, ctx)
|
||||
self.bbox_head__get_bboxes = partial(bbox_head__get_bboxes, ctx)
|
||||
|
||||
def partition0_postprocess(self, x: Sequence[torch.Tensor],
|
||||
scores: torch.Tensor, bboxes: torch.Tensor):
|
||||
@ -656,10 +656,10 @@ class PartitionTwoStageDetector(DeployBaseDetector):
|
||||
bbox_pred.size(-1))
|
||||
|
||||
rcnn_test_cfg = self.model_cfg.model.test_cfg.rcnn
|
||||
return self.get_bboxes_of_bbox_head(self.bbox_head, rois, cls_score,
|
||||
bbox_pred,
|
||||
img_metas[0][0]['img_shape'],
|
||||
rcnn_test_cfg)
|
||||
return self.bbox_head__get_bboxes(self.bbox_head, rois, cls_score,
|
||||
bbox_pred,
|
||||
img_metas[0][0]['img_shape'],
|
||||
rcnn_test_cfg)
|
||||
|
||||
|
||||
class ONNXRuntimePTSDetector(PartitionTwoStageDetector):
|
||||
|
@ -17,7 +17,40 @@ def delta2bbox(ctx,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite for ONNX exporting of default backend."""
|
||||
"""Rewrite `delta2bbox` for default backend.
|
||||
|
||||
Since the need of clip op with dynamic min and max, this function uses
|
||||
clip_bboxes function to support dynamic shape.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
|
||||
deltas (Tensor): Encoded offsets with respect to each roi.
|
||||
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
||||
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
||||
when rois is a grid of anchors.Offset encoding follows [1]_.
|
||||
means (Sequence[float]): Denormalizing means for delta coordinates
|
||||
stds (Sequence[float]): Denormalizing standard deviation for delta
|
||||
coordinates
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
add_ctr_clamp (bool): Whether to add center clamp, when added, the
|
||||
predicted box is clamped is its center is too far away from
|
||||
the original anchor's center. Only used by YOLOF. Default False.
|
||||
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
|
||||
Default 32.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
|
||||
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
||||
br_x, br_y.
|
||||
"""
|
||||
means = deltas.new_tensor(means).view(1,
|
||||
-1).repeat(1,
|
||||
deltas.size(-1) // 4)
|
||||
@ -82,7 +115,43 @@ def delta2bbox_ncnn(ctx,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite for ONNX exporting of NCNN backend."""
|
||||
"""Rewrite `delta2bbox` for ncnn backend.
|
||||
|
||||
Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
NCNN regards the lowest two dimensions as continuous address with byte
|
||||
alignment, so the lowest two dimensions are not absolutely independent.
|
||||
Reshape operator with -1 arguments should operates ncnn::Mat with
|
||||
dimension >= 3.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
|
||||
deltas (Tensor): Encoded offsets with respect to each roi.
|
||||
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
||||
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
||||
when rois is a grid of anchors.Offset encoding follows [1]_.
|
||||
means (Sequence[float]): Denormalizing means for delta coordinates
|
||||
stds (Sequence[float]): Denormalizing standard deviation for delta
|
||||
coordinates
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
add_ctr_clamp (bool): Whether to add center clamp, when added, the
|
||||
predicted box is clamped is its center is too far away from
|
||||
the original anchor's center. Only used by YOLOF. Default False.
|
||||
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
|
||||
Default 32.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
|
||||
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
||||
br_x, br_y.
|
||||
"""
|
||||
means = deltas.new_tensor(means).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
|
@ -13,7 +13,34 @@ def tblr2bboxes(ctx,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
"""Rewrite for ONNX exporting of default backend."""
|
||||
"""Rewrite `tblr2bboxes` for default backend.
|
||||
|
||||
Since the need of clip op with dynamic min and max, this function uses
|
||||
clip_bboxes function to support dynamic shape.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
|
||||
Shape: (N,4) or (B, N, 4).
|
||||
tblr (Tensor): Coords of network output in tblr form
|
||||
Shape: (N, 4) or (B, N, 4).
|
||||
normalizer (Sequence[float] | float): Normalization parameter of
|
||||
encoded boxes. By list, it represents the normalization factors at
|
||||
tblr dims. By float, it is the unified normalization factor at all
|
||||
dims. Default: 4.0
|
||||
normalize_by_wh (bool): Whether the tblr coordinates have been
|
||||
normalized by the side length (wh) of prior bboxes.
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
|
||||
"""
|
||||
if not isinstance(normalizer, float):
|
||||
normalizer = torch.tensor(normalizer, device=priors.device)
|
||||
assert len(normalizer) == 4, 'Normalizer must have length = 4'
|
||||
@ -55,7 +82,35 @@ def tblr2bboxes_ncnn(ctx,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
"""Rewrite for ONNX exporting of NCNN backend."""
|
||||
"""Rewrite `tblr2bboxes` for ncnn backend.
|
||||
|
||||
Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
The negative value of axis in torch.cat is rewritten as corresponding
|
||||
positive value to avoid axis shift.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
|
||||
Shape: (N,4) or (B, N, 4).
|
||||
tblr (Tensor): Coords of network output in tblr form
|
||||
Shape: (N, 4) or (B, N, 4).
|
||||
normalizer (Sequence[float] | float): Normalization parameter of
|
||||
encoded boxes. By list, it represents the normalization factors at
|
||||
tblr dims. By float, it is the unified normalization factor at all
|
||||
dims. Default: 4.0
|
||||
normalize_by_wh (bool): Whether the tblr coordinates have been
|
||||
normalized by the side length (wh) of prior bboxes.
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
|
||||
"""
|
||||
assert priors.size(0) == tblr.size(0)
|
||||
if priors.ndim == 3:
|
||||
assert priors.size(1) == tblr.size(1)
|
||||
|
@ -1,14 +1,14 @@
|
||||
from .anchor_head import get_bboxes_of_anchor_head
|
||||
from .atss_head import get_bboxes_of_atss_head
|
||||
from .fcos_head import get_bboxes_of_fcos_head
|
||||
from .fovea_head import get_bboxes_of_fovea_head
|
||||
from .rpn_head import get_bboxes_of_rpn_head
|
||||
from .anchor_head import anchor_head__get_bboxes
|
||||
from .atss_head import atss_head__get_bboxes
|
||||
from .fcos_head import fcos_head__get_bboxes
|
||||
from .fovea_head import fovea_head__get_bboxes
|
||||
from .rpn_head import rpn_head__get_bboxes
|
||||
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
|
||||
from .yolox_head import get_bboxes_of_yolox_head
|
||||
from .yolox_head import yolox_head__get_bboxes
|
||||
|
||||
__all__ = [
|
||||
'get_bboxes_of_anchor_head', 'get_bboxes_of_fcos_head',
|
||||
'get_bboxes_of_rpn_head', 'get_bboxes_of_fovea_head',
|
||||
'get_bboxes_of_atss_head', 'yolov3_head__get_bboxes',
|
||||
'yolov3_head__get_bboxes__ncnn', 'get_bboxes_of_yolox_head'
|
||||
'anchor_head__get_bboxes', 'atss_head__get_bboxes',
|
||||
'fcos_head__get_bboxes', 'fovea_head__get_bboxes', 'rpn_head__get_bboxes',
|
||||
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
|
||||
'yolox_head__get_bboxes'
|
||||
]
|
||||
|
@ -9,15 +9,45 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.AnchorHead.get_bboxes')
|
||||
def get_bboxes_of_anchor_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
def anchor_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of AnchorHead for default backend.
|
||||
|
||||
Rewrite this function to support deployment of default backend and
|
||||
dynamic shape export.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Box scores for each level in the
|
||||
feature pyramid, has shape
|
||||
(N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each
|
||||
level in the feature pyramid, has shape
|
||||
(N, num_anchors * 4, H, W).
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -117,15 +147,45 @@ def get_bboxes_of_anchor_head(ctx,
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.AnchorHead.get_bboxes', backend='ncnn')
|
||||
def get_bboxes_of_anchor_head_ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
def anchor_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
|
||||
|
||||
Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Box scores for each level in the
|
||||
feature pyramid, has shape
|
||||
(N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each
|
||||
level in the feature pyramid, has shape
|
||||
(N, num_anchors * 4, H, W).
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -6,18 +6,19 @@ from mmdeploy.utils import get_mmdet_params
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('mmdet.models.ATSSHead.get_bboxes')
|
||||
def get_bboxes_of_atss_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True):
|
||||
def atss_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True):
|
||||
"""Rewrite `get_bboxes` from ATSSHead for default backend.
|
||||
|
||||
Transform network output for a batch into bbox predictions.
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
@ -39,12 +40,9 @@ def get_bboxes_of_atss_head(ctx,
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: The first item is an (n, 5) tensor,
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score) and
|
||||
the score between 0 and 1.
|
||||
The shape of the second tensor in the tuple is (n,), and
|
||||
each element represents the class label of the corresponding
|
||||
box.
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
|
@ -9,16 +9,48 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.FCOSHead.get_bboxes')
|
||||
def get_bboxes_of_fcos_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
def fcos_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of FCOSHead for default backend.
|
||||
|
||||
Rewrite this function to support deployment of default backend
|
||||
and dynamic shape export. Transform network output for a batch into
|
||||
bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (ATSSHead): The instance of the class ATSSHead.
|
||||
cls_scores (list[Tensor]): Box scores for each scale level
|
||||
with shape (N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each scale
|
||||
level with shape (N, num_anchors * 4, H, W).
|
||||
centernesses (list[Tensor]): Centerness for each scale level with
|
||||
shape (N, num_anchors * 1, H, W).
|
||||
img_metas (dict): Meta information of the image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -111,16 +143,49 @@ def get_bboxes_of_fcos_head(ctx,
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.FCOSHead.get_bboxes', backend='ncnn')
|
||||
def get_bboxes_of_fcos_head_ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
def fcos_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of FCOSHead for ncnn backend.
|
||||
|
||||
1. Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
2. 2-dimension tensor broadcast of `BinaryOps` operator is not supported by
|
||||
ncnn. This function unsqueeze 2-dimension tensor to 3-dimension tensor for
|
||||
correct `BinaryOps` calculation by ncnn.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (ATSSHead): The instance of the class ATSSHead.
|
||||
cls_scores (list[Tensor]): Box scores for each scale level
|
||||
with shape (N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each scale
|
||||
level with shape (N, num_anchors * 4, H, W).
|
||||
centernesses (list[Tensor]): Centerness for each scale level with
|
||||
shape (N, num_anchors * 1, H, W).
|
||||
img_metas (dict): Meta information of the image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -6,16 +6,17 @@ from mmdeploy.utils import get_mmdet_params
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('mmdet.models.FoveaHead.get_bboxes')
|
||||
def get_bboxes_of_fovea_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
rescale=None):
|
||||
def fovea_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
rescale=None):
|
||||
"""Rewrite `get_bboxes` from FoveaHead for default backend.
|
||||
|
||||
Transform network output for a batch into bbox predictions.
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
@ -32,12 +33,9 @@ def get_bboxes_of_fovea_head(ctx,
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: The first item is an (n, 5) tensor,
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score) and
|
||||
the score between 0 and 1.
|
||||
The shape of the second tensor in the tuple is (n,), and
|
||||
each element represents the class label of the corresponding
|
||||
box.
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
|
@ -8,15 +8,41 @@ from mmdeploy.utils import (Backend, get_backend, get_mmdet_params,
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('mmdet.models.RPNHead.get_bboxes')
|
||||
def get_bboxes_of_rpn_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
def rpn_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of RPNHead for default backend.
|
||||
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (FoveaHead): The instance of the class FoveaHead.
|
||||
cls_scores (list[Tensor]): Box scores for each scale level
|
||||
with shape (N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each scale
|
||||
level with shape (N, num_anchors * 4, H, W).
|
||||
img_metas (dict): Meta information of the image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -113,15 +139,44 @@ def get_bboxes_of_rpn_head(ctx,
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.RPNHead.get_bboxes', backend='ncnn')
|
||||
def get_bboxes_of_rpn_head_ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
def rpn_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of RPNHead for ncnn backend.
|
||||
|
||||
Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Box scores for each level in the
|
||||
feature pyramid, has shape
|
||||
(N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each
|
||||
level in the feature pyramid, has shape
|
||||
(N, num_anchors * 4, H, W).
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Default: None.
|
||||
|
||||
|
||||
Returns:
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -15,25 +15,27 @@ def yolov3_head__get_bboxes(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend.
|
||||
"""Rewrite `get_bboxes` of YOLOV3Head for default backend.
|
||||
|
||||
Transform network output for a batch into bbox predictions.
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx: Context that contains original meta information.
|
||||
self: Represent the instance of the original class.
|
||||
pred_maps (list[Tensor]): Raw predictions for a batch of images.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
|
||||
size and the score between 0 and 1. The shape of the second
|
||||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
is_dynamic_flag = is_dynamic_shape(ctx.cfg)
|
||||
num_levels = len(pred_maps)
|
||||
@ -161,25 +163,34 @@ def yolov3_head__get_bboxes__ncnn(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for ncnn backend.
|
||||
"""Rewrite `get_bboxes` of YOLOV3Head for ncnn backend.
|
||||
|
||||
1. Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
2. Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
The negative value of axis in torch.cat is rewritten as corresponding
|
||||
positive value to avoid axis shift.
|
||||
3. 2-dimension tensor broadcast of `BinaryOps` operator is not supported by
|
||||
ncnn. This function unsqueeze 2-dimension tensor to 3-dimension tensor for
|
||||
correct `BinaryOps` calculation by ncnn.
|
||||
|
||||
Transform network output for a batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx: Context that contains original meta information.
|
||||
self: Represent the instance of the original class.
|
||||
pred_maps (list[Tensor]): Raw predictions for a batch of images.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default: None.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
|
||||
size and the score between 0 and 1. The shape of the second
|
||||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
If with_nms == True:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
Else:
|
||||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
num_levels = len(pred_maps)
|
||||
pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]
|
||||
|
@ -7,18 +7,19 @@ from mmdeploy.utils import get_mmdet_params
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.YOLOXHead.get_bboxes')
|
||||
def get_bboxes_of_yolox_head(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
objectnesses,
|
||||
img_metas=None,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True):
|
||||
"""Rewrite `get_bboxes` for default backend.
|
||||
def yolox_head__get_bboxes(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
objectnesses,
|
||||
img_metas=None,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True):
|
||||
"""Rewrite `get_bboxes` of YOLOXHead for default backend.
|
||||
|
||||
Transform network outputs of a batch into bbox results.
|
||||
Rewrite this function to deploy model, transform network output for a
|
||||
batch into bbox predictions.
|
||||
|
||||
Args:
|
||||
ctx: Context that contains original meta information.
|
||||
@ -39,6 +40,7 @@ def get_bboxes_of_yolox_head(ctx,
|
||||
Default False.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default True.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
|
||||
|
@ -1,9 +1,9 @@
|
||||
from .base import forward_of_base_detector
|
||||
from .rpn import simple_test_of_rpn
|
||||
from .single_stage import simple_test_of_single_stage
|
||||
from .two_stage import extract_feat_of_two_stage
|
||||
from .base import base_detector__forward
|
||||
from .rpn import rpn__simple_test
|
||||
from .single_stage import single_stage__simple_test
|
||||
from .two_stage import two_stage__extract_feat
|
||||
|
||||
__all__ = [
|
||||
'simple_test_of_single_stage', 'extract_feat_of_two_stage',
|
||||
'forward_of_base_detector', 'simple_test_of_rpn'
|
||||
'single_stage__simple_test', 'two_stage__extract_feat',
|
||||
'base_detector__forward', 'rpn__simple_test'
|
||||
]
|
||||
|
@ -6,8 +6,13 @@ from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite and adding mark for `forward`."""
|
||||
def _base_detector__forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite and adding mark for `forward`.
|
||||
|
||||
Encapsulate this function for rewriting `forward` of BaseDetector.
|
||||
1. Add mark for BaseDetector.
|
||||
2. Support both dynamic and static export to onnx.
|
||||
"""
|
||||
assert isinstance(img_metas, dict)
|
||||
assert isinstance(img, torch.Tensor)
|
||||
|
||||
@ -23,8 +28,31 @@ def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.BaseDetector.forward')
|
||||
def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite `forward` of BaseDetector for default backend.
|
||||
|
||||
Rewrite this function to:
|
||||
1. Create img_metas for exporting model to onnx.
|
||||
2. Call `simple_test` directly to skip `aug_test`.
|
||||
3. Remove `return_loss` because deployment has no need for training
|
||||
functions.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the class BaseDetector.
|
||||
img (Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
img_metas (Optional[list[dict]]): A list of image info dict where each
|
||||
dict has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys, see
|
||||
:class:`mmdet.datasets.pipelines.Collect`.
|
||||
|
||||
Returns:
|
||||
list[list[np.ndarray]]: BBox results of each image and classes.
|
||||
The outer list corresponds to each image. The inner list
|
||||
corresponds to each class.
|
||||
"""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
|
||||
@ -36,5 +64,5 @@ def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
|
||||
|
||||
if 'return_loss' in kwargs:
|
||||
kwargs.pop('return_loss')
|
||||
return _forward_of_base_detector_impl(
|
||||
return _base_detector__forward_impl(
|
||||
ctx, self, img, img_metas=img_metas, **kwargs)
|
||||
|
@ -2,7 +2,25 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
|
||||
def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
def rpn__simple_test(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend.
|
||||
|
||||
Support configured dynamic/static shape for model input and return
|
||||
detection result as Tensor instead of numpy array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
img (Tensor | List[Tensor]): Input image tensor(s).
|
||||
img_meta (dict): Dict containing image's meta information
|
||||
such as `img_shape`.
|
||||
|
||||
Returns:
|
||||
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
|
||||
The first item is ``bboxes`` with shape (n, 5),
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score).
|
||||
The shape of the second tensor in the tuple is ``labels``
|
||||
with shape (n,)
|
||||
"""
|
||||
x = self.extract_feat(img)
|
||||
return self.rpn_head.simple_test_rpn(x, img_metas)
|
||||
|
@ -3,7 +3,25 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.SingleStageDetector.simple_test')
|
||||
def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
def single_stage__simple_test(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend.
|
||||
|
||||
Support configured dynamic/static shape for model input and return
|
||||
detection result as Tensor instead of numpy array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
img (Tensor | List[Tensor]): Input image tensor(s).
|
||||
img_meta (dict): Dict containing image's meta information
|
||||
such as `img_shape`.
|
||||
|
||||
Returns:
|
||||
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
|
||||
The first item is ``bboxes`` with shape (n, 5),
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score).
|
||||
The shape of the second tensor in the tuple is ``labels``
|
||||
with shape (n,)
|
||||
"""
|
||||
feat = self.extract_feat(img)
|
||||
return self.bbox_head.simple_test(feat, img_metas, **kwargs)
|
||||
|
@ -4,20 +4,53 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.TwoStageDetector.extract_feat')
|
||||
@mark('extract_feat', inputs='img', outputs='feat')
|
||||
def extract_feat_of_two_stage(ctx, self, img):
|
||||
"""Rewrite `extract_feat` for default backend."""
|
||||
def two_stage__extract_feat(ctx, self, img):
|
||||
"""Rewrite `extract_feat` for default backend.
|
||||
|
||||
This function uses the specific `extract_feat` function for the two
|
||||
stage detector after adding marks.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
img (Tensor | List[Tensor]): Input image tensor(s).
|
||||
|
||||
Returns:
|
||||
list[Tensor]: Each item with shape (N, C, H, W) corresponds one
|
||||
level of backbone and neck features.
|
||||
"""
|
||||
return ctx.origin_func(self, img)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.TwoStageDetector.simple_test')
|
||||
def simple_test_of_two_stage(ctx,
|
||||
self,
|
||||
img,
|
||||
img_metas,
|
||||
proposals=None,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
def two_stage__simple_test(ctx,
|
||||
self,
|
||||
img,
|
||||
img_metas,
|
||||
proposals=None,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend.
|
||||
|
||||
Support configured dynamic/static shape for model input and return
|
||||
detection result as Tensor instead of numpy array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
img (Tensor | List[Tensor]): Input image tensor(s).
|
||||
img_meta (dict): Dict containing image's meta information
|
||||
such as `img_shape`.
|
||||
proposals (List[Tensor]): Region proposals.
|
||||
Default is None.
|
||||
|
||||
Returns:
|
||||
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
|
||||
The first item is ``bboxes`` with shape (n, 5),
|
||||
where 5 represent (tl_x, tl_y, br_x, br_y, score).
|
||||
The shape of the second tensor in the tuple is ``labels``
|
||||
with shape (n,)
|
||||
"""
|
||||
assert self.with_bbox, 'Bbox head must be implemented.'
|
||||
x = self.extract_feat(img)
|
||||
if proposals is None:
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .bbox_head import get_bboxes_of_bbox_head
|
||||
from .bbox_head import bbox_head__get_bboxes
|
||||
|
||||
__all__ = ['get_bboxes_of_bbox_head']
|
||||
__all__ = ['bbox_head__get_bboxes']
|
||||
|
@ -14,16 +14,53 @@ from mmdeploy.utils import get_mmdet_params
|
||||
'bbox_head_forward',
|
||||
inputs=['bbox_feats'],
|
||||
outputs=['cls_score', 'bbox_pred'])
|
||||
def forward_of_bbox_head(ctx, self, x):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
def bbox_head__forward(ctx, self, x):
|
||||
"""Rewrite `forward` for default backend.
|
||||
|
||||
This function uses the specific `forward` function for the BBoxHead
|
||||
or ConvFCBBoxHead after adding marks.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
x (Tensor): Input image tensor.
|
||||
|
||||
Returns:
|
||||
tuple(Tensor, Tensor): The (cls_score, bbox_pred). The cls_score
|
||||
has shape (N, num_det, num_cls) and the bbox_pred has shape
|
||||
(N, num_det, 4).
|
||||
"""
|
||||
return ctx.origin_func(self, x)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
|
||||
def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
|
||||
cfg, **kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
def bbox_head__get_bboxes(ctx, self, rois, cls_score, bbox_pred, img_shape,
|
||||
cfg, **kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend.
|
||||
|
||||
Transform network output for a batch into bbox predictions. Support
|
||||
`reg_class_agnostic == False` case.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self (ATSSHead): The instance of the class ATSSHead.
|
||||
rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5).
|
||||
last dimension 5 arrange as (batch_index, x1, y1, x2, y2).
|
||||
cls_score (Tensor): Box scores, has shape
|
||||
(num_boxes, num_classes + 1).
|
||||
bbox_pred (Tensor, optional): Box energies / deltas.
|
||||
has shape (num_boxes, num_classes * 4).
|
||||
img_shape (Sequence[int], optional): Maximum bounds for boxes,
|
||||
specifies (H, W, C) or (H, W).
|
||||
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
|
||||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
"""
|
||||
assert rois.ndim == 3, 'Only support export two stage ' \
|
||||
'model to ONNX ' \
|
||||
'with batch dimension. '
|
||||
|
@ -5,9 +5,37 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.CascadeRoIHead.simple_test')
|
||||
def simple_test_of_cascade_roi_head(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend.
|
||||
|
||||
1. This function eliminates the batch dimension to get forward bbox
|
||||
results, and recover batch dimension to calculate final result
|
||||
for deployment.
|
||||
2. This function returns detection result as Tensor instead of numpy
|
||||
array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
x (tuple[Tensor]): Features from upstream network. Each
|
||||
has shape (batch_size, c, h, w).
|
||||
proposals (list(Tensor)): Proposals from rpn head.
|
||||
Each has shape (num_proposals, 5), last dimension
|
||||
5 represent (x1, y1, x2, y2, score).
|
||||
img_metas (list[dict]): Meta information of images.
|
||||
|
||||
Returns:
|
||||
If self.with_mask == True:
|
||||
tuple[Tensor, Tensor, Tensor]: (det_bboxes, det_labels,
|
||||
segm_results), `det_bboxes` of shape [N, num_det, 5],
|
||||
`det_labels` of shape [N, num_det], and `segm_results`
|
||||
of shape [N, num_det, roi_H, roi_W].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: (det_bboxes, det_labels),
|
||||
`det_bboxes` of shape [N, num_det, 5] and `det_labels`
|
||||
of shape [N, num_det].
|
||||
"""
|
||||
assert self.with_bbox, 'Bbox head must be implemented.'
|
||||
assert proposals.shape[0] == 1, 'Only support one input image ' \
|
||||
'while in exporting to ONNX'
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .fcn_mask_head import get_seg_masks_of_fcn_mask_head
|
||||
from .fcn_mask_head import fcn_mask_head__get_seg_masks
|
||||
|
||||
__all__ = ['get_seg_masks_of_fcn_mask_head']
|
||||
__all__ = ['fcn_mask_head__get_seg_masks']
|
||||
|
@ -7,11 +7,12 @@ from mmdeploy.utils import Backend, get_backend, get_mmdet_params
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.FCNMaskHead.get_seg_masks')
|
||||
def get_seg_masks_of_fcn_mask_head(ctx, self, mask_pred, det_bboxes,
|
||||
det_labels, rcnn_test_cfg, ori_shape,
|
||||
**kwargs):
|
||||
def fcn_mask_head__get_seg_masks(ctx, self, mask_pred, det_bboxes, det_labels,
|
||||
rcnn_test_cfg, ori_shape, **kwargs):
|
||||
"""Get segmentation masks from mask_pred and bboxes.
|
||||
|
||||
Rewrite the get_seg_masks for only fcn_mask_head inference.
|
||||
|
||||
Args:
|
||||
mask_pred (Tensor): shape (n, #class, h, w).
|
||||
det_bboxes (Tensor): shape (n, 4/5)
|
||||
|
@ -1,10 +1,8 @@
|
||||
from .single_level_roi_extractor import (
|
||||
forward_of_single_roi_extractor_dynamic,
|
||||
forward_of_single_roi_extractor_dynamic_openvino,
|
||||
forward_of_single_roi_extractor_static)
|
||||
single_roi_extractor__forward, single_roi_extractor__forward__openvino,
|
||||
single_roi_extractor__forward__tensorrt)
|
||||
|
||||
__all__ = [
|
||||
'forward_of_single_roi_extractor_dynamic',
|
||||
'forward_of_single_roi_extractor_static',
|
||||
'forward_of_single_roi_extractor_dynamic_openvino'
|
||||
'single_roi_extractor__forward', 'single_roi_extractor__forward__openvino',
|
||||
'single_roi_extractor__forward__tensorrt'
|
||||
]
|
||||
|
@ -61,12 +61,15 @@ class MultiLevelRoiAlign(Function):
|
||||
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
|
||||
backend='tensorrt')
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def forward_of_single_roi_extractor_static(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for TensorRT backend."""
|
||||
def single_roi_extractor__forward__tensorrt(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for TensorRT backend.
|
||||
|
||||
This function uses MMCVMultiLevelRoiAlign op for TensorRT deployment.
|
||||
"""
|
||||
featmap_strides = self.featmap_strides
|
||||
finest_scale = self.finest_scale
|
||||
|
||||
@ -86,12 +89,16 @@ def forward_of_single_roi_extractor_static(ctx,
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def forward_of_single_roi_extractor_dynamic(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
def single_roi_extractor__forward(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for default backend.
|
||||
|
||||
Add mark for roi_extractor forward. Remove unnecessary code of origin
|
||||
forward function.
|
||||
"""
|
||||
out_size = self.roi_layers[0].output_size
|
||||
num_levels = len(feats)
|
||||
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
|
||||
@ -159,13 +166,16 @@ class SingleRoIExtractorOpenVINO(Function):
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
|
||||
backend='openvino')
|
||||
def forward_of_single_roi_extractor_dynamic_openvino(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
def single_roi_extractor__forward__openvino(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Replaces SingleRoIExtractor with SingleRoIExtractorOpenVINO when
|
||||
exporting to OpenVINO."""
|
||||
exporting to OpenVINO.
|
||||
|
||||
This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO.
|
||||
"""
|
||||
|
||||
# Adding original output to SingleRoIExtractorOpenVINO.
|
||||
state = torch._C._get_tracing_state()
|
||||
|
@ -3,9 +3,34 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
|
||||
def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
def standard_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend.
|
||||
|
||||
This function returns detection result as Tensor instead of numpy
|
||||
array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
x (tuple[Tensor]): Features from upstream network. Each
|
||||
has shape (batch_size, c, h, w).
|
||||
proposals (list(Tensor)): Proposals from rpn head.
|
||||
Each has shape (num_proposals, 5), last dimension
|
||||
5 represent (x1, y1, x2, y2, score).
|
||||
img_metas (list[dict]): Meta information of images.
|
||||
|
||||
Returns:
|
||||
If self.with_mask == True:
|
||||
tuple[Tensor, Tensor, Tensor]: (det_bboxes, det_labels,
|
||||
segm_results), `det_bboxes` of shape [N, num_det, 5],
|
||||
`det_labels` of shape [N, num_det], and `segm_results`
|
||||
of shape [N, num_det, roi_H, roi_W].
|
||||
Else:
|
||||
tuple[Tensor, Tensor]: (det_bboxes, det_labels),
|
||||
`det_bboxes` of shape [N, num_det, 5] and `det_labels`
|
||||
of shape [N, num_det].
|
||||
"""
|
||||
assert self.with_bbox, 'Bbox head must be implemented.'
|
||||
det_bboxes, det_labels = self.simple_test_bboxes(
|
||||
x, img_metas, proposals, self.test_cfg, rescale=False)
|
||||
|
@ -6,9 +6,31 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.test_mixins.\
|
||||
BBoxTestMixin.simple_test_bboxes')
|
||||
def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
|
||||
rcnn_test_cfg, **kwargs):
|
||||
"""Rewrite `simple_test_bboxes` for default backend."""
|
||||
def bbox_test_mixin__simple_test_bboxes(ctx, self, x, img_metas, proposals,
|
||||
rcnn_test_cfg, **kwargs):
|
||||
"""Rewrite `simple_test_bboxes` for default backend.
|
||||
|
||||
1. This function eliminates the batch dimension to get forward bbox
|
||||
results, and recover batch dimension to calculate final result
|
||||
for deployment.
|
||||
2. This function returns detection result as Tensor instead of numpy
|
||||
array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
x (tuple[Tensor]): Features from upstream network. Each
|
||||
has shape (batch_size, c, h, w).
|
||||
img_metas (list[dict]): Meta information of images.
|
||||
proposals (list(Tensor)): Proposals from rpn head.
|
||||
Each has shape (num_proposals, 5), last dimension
|
||||
5 represent (x1, y1, x2, y2, score).
|
||||
rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (det_bboxes, det_labels), `det_bboxes` of
|
||||
shape [N, num_det, 5] and `det_labels` of shape [N, num_det].
|
||||
"""
|
||||
rois = proposals
|
||||
batch_index = torch.arange(
|
||||
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
|
||||
@ -38,9 +60,28 @@ def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.test_mixins.\
|
||||
MaskTestMixin.simple_test_mask')
|
||||
def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
|
||||
det_labels, **kwargs):
|
||||
"""Rewrite `simple_test_mask` for default backend."""
|
||||
def mask_test_mixin__simple_test_mask(ctx, self, x, img_metas, det_bboxes,
|
||||
det_labels, **kwargs):
|
||||
"""Rewrite `simple_test_mask` for default backend.
|
||||
|
||||
This function returns detection result as Tensor instead of numpy
|
||||
array.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
x (tuple[Tensor]): Features from upstream network. Each
|
||||
has shape (batch_size, c, h, w).
|
||||
img_metas (list[dict]): Meta information of images.
|
||||
det_bboxes (tuple[Tensor]): Detection bounding-boxes from features.
|
||||
Each has shape of (batch_size, num_det, 5).
|
||||
det_labels (tuple[Tensor]): Detection labels from features. Each
|
||||
has shape of (batch_size, num_det).
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: (segm_results), `segm_results` of shape
|
||||
[N, num_det, roi_H, roi_W].
|
||||
"""
|
||||
if det_bboxes.shape[1] == 0:
|
||||
bboxes_shape, labels_shape = list(det_bboxes.shape), list(
|
||||
det_labels.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user