From 3e8237d8bb361f5d84086a93b1798c9d0bf8e6fc Mon Sep 17 00:00:00 2001 From: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Date: Fri, 17 Dec 2021 10:46:54 +0800 Subject: [PATCH] [Feature] Support end2end mmdet2.19 retina mobilessd (#286) * support end2end mmdet2.19 retina mobilessd * fix yapf * add end2end fsaf * fix lint * fix comments * fix lint * add static configs * fix docformatter * move ssdhead * add rewrite for l2norm * fix ncnn ssd * fix isort * rename config * add ssd_head_ut * fix string * align ssd * remove unused bbox rewriter Co-authored-by: grimoire Co-authored-by: maningsheng --- .../single-stage_ncnn_static-300x300.py | 4 + .../single-stage_ncnn_static-800x1344.py | 4 + csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 9 + .../mmdet/core/bbox/delta_xywh_bbox_coder.py | 121 ----------- .../mmdet/core/bbox/tblr_bbox_coder.py | 72 ------ .../mmdet/core/ops/detection_output.py | 7 +- mmdeploy/codebase/mmdet/core/ops/prior_box.py | 6 + mmdeploy/codebase/mmdet/models/__init__.py | 1 + .../mmdet/models/dense_heads/__init__.py | 15 +- .../models/dense_heads/base_dense_head.py | 205 ++++++++++++++++++ .../mmdet/models/dense_heads/ssd_head.py | 124 +++++++++++ mmdeploy/codebase/mmdet/models/necks.py | 10 + .../test_mmdet/test_mmdet_models.py | 203 ++++++++--------- 13 files changed, 468 insertions(+), 313 deletions(-) create mode 100644 configs/mmdet/detection/single-stage_ncnn_static-300x300.py create mode 100644 configs/mmdet/detection/single-stage_ncnn_static-800x1344.py create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py create mode 100644 mmdeploy/codebase/mmdet/models/necks.py diff --git a/configs/mmdet/detection/single-stage_ncnn_static-300x300.py b/configs/mmdet/detection/single-stage_ncnn_static-300x300.py new file mode 100644 index 000000000..35f7ef720 --- /dev/null +++ b/configs/mmdet/detection/single-stage_ncnn_static-300x300.py @@ -0,0 +1,4 @@ +_base_ = ['../_base_/base_static.py', 
'../../_base_/backends/ncnn.py'] + +codebase_config = dict(model_type='ncnn_end2end') +onnx_config = dict(output_names=['detection_output'], input_shape=[300, 300]) diff --git a/configs/mmdet/detection/single-stage_ncnn_static-800x1344.py b/configs/mmdet/detection/single-stage_ncnn_static-800x1344.py new file mode 100644 index 000000000..110336fa8 --- /dev/null +++ b/configs/mmdet/detection/single-stage_ncnn_static-800x1344.py @@ -0,0 +1,4 @@ +_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py'] + +codebase_config = dict(model_type='ncnn_end2end') +onnx_config = dict(output_names=['detection_output'], input_shape=[1344, 800]) diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index 9bc2763ec..6333270d6 100644 --- a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -3765,11 +3765,16 @@ int main(int argc, char** argv) { int nms_top_k = get_node_attr_i(node, "nms_top_k"); int keep_top_k = get_node_attr_i(node, "keep_top_k"); int num_class = get_node_attr_i(node, "num_class"); + std::vector vars = get_node_attr_af(node, "vars"); fprintf(pp, " 0=%d", num_class); fprintf(pp, " 1=%f", nms_threshold); fprintf(pp, " 2=%d", nms_top_k); fprintf(pp, " 3=%d", keep_top_k); fprintf(pp, " 4=%f", score_threshold); + fprintf(pp, " 5=%f", vars[0]); + fprintf(pp, " 6=%f", vars[1]); + fprintf(pp, " 7=%f", vars[2]); + fprintf(pp, " 8=%f", vars[3]); } else if (op == "Div") { int op_type = 3; fprintf(pp, " 0=%d", op_type); @@ -4660,10 +4665,14 @@ int main(int argc, char** argv) { } int image_width = get_node_attr_i(node, "image_width"); int image_height = get_node_attr_i(node, "image_height"); + float step_width = get_node_attr_f(node, "step_width"); + float step_height = get_node_attr_f(node, "step_height"); float offset = get_node_attr_f(node, "offset"); int step_mmdetection = get_node_attr_i(node, "step_mmdetection"); fprintf(pp, " 9=%d", image_width); 
fprintf(pp, " 10=%d", image_height); + fprintf(pp, " 11=%f", step_width); + fprintf(pp, " 12=%f", step_height); fprintf(pp, " 13=%f", offset); fprintf(pp, " 14=%d", step_mmdetection); } else if (op == "PixelShuffle") { diff --git a/mmdeploy/codebase/mmdet/core/bbox/delta_xywh_bbox_coder.py b/mmdeploy/codebase/mmdet/core/bbox/delta_xywh_bbox_coder.py index c7aba3fac..02f2b4ffb 100644 --- a/mmdeploy/codebase/mmdet/core/bbox/delta_xywh_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/core/bbox/delta_xywh_bbox_coder.py @@ -141,124 +141,3 @@ def delta2bbox(ctx, bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) return bboxes - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox', # noqa - backend='ncnn') -def delta2bbox__ncnn(ctx, - rois, - deltas, - means=(0., 0., 0., 0.), - stds=(1., 1., 1., 1.), - max_shape=None, - wh_ratio_clip=16 / 1000, - clip_border=True, - add_ctr_clamp=False, - ctr_clamp=32): - """Rewrite `delta2bbox` for ncnn backend. - - Batch dimension is not supported by ncnn, but supported by pytorch. - NCNN regards the lowest two dimensions as continuous address with byte - alignment, so the lowest two dimensions are not absolutely independent. - Reshape operator with -1 arguments should operates ncnn::Mat with - dimension >= 3. - - Args: - ctx (ContextCaller): The context with additional information. - rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4) - deltas (Tensor): Encoded offsets with respect to each roi. - Has shape (B, N, num_classes * 4) or (B, N, 4) or - (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H - when rois is a grid of anchors.Offset encoding follows [1]_. 
- means (Sequence[float]): Denormalizing means for delta coordinates - stds (Sequence[float]): Denormalizing standard deviation for delta - coordinates - max_shape (Sequence[int] or torch.Tensor or Sequence[ - Sequence[int]],optional): Maximum bounds for boxes, specifies - (H, W, C) or (H, W). If rois shape is (B, N, 4), then - the max_shape should be a Sequence[Sequence[int]] - and the length of max_shape should also be B. - wh_ratio_clip (float): Maximum aspect ratio for boxes. - clip_border (bool, optional): Whether clip the objects outside the - border of the image. Defaults to True. - add_ctr_clamp (bool): Whether to add center clamp, when added, the - predicted box is clamped is its center is too far away from - the original anchor's center. Only used by YOLOF. Default False. - ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. - Default 32. - - Return: - bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4) - or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y, - br_x, br_y. 
- """ - means = deltas.new_tensor(means).view(1, 1, - -1).repeat(1, deltas.size(-2), - deltas.size(-1) // 4).data - stds = deltas.new_tensor(stds).view(1, 1, - -1).repeat(1, deltas.size(-2), - deltas.size(-1) // 4).data - denorm_deltas = deltas * stds + means - if denorm_deltas.shape[-1] == 4: - dx = denorm_deltas[..., 0:1] - dy = denorm_deltas[..., 1:2] - dw = denorm_deltas[..., 2:3] - dh = denorm_deltas[..., 3:4] - else: - dx = denorm_deltas[..., 0::4] - dy = denorm_deltas[..., 1::4] - dw = denorm_deltas[..., 2::4] - dh = denorm_deltas[..., 3::4] - - x1, y1 = rois[..., 0:1], rois[..., 1:2] - x2, y2 = rois[..., 2:3], rois[..., 3:4] - - # Compute center of each roi - px = (x1 + x2) * 0.5 - py = (y1 + y2) * 0.5 - # Compute width/height of each roi - pw = x2 - x1 - ph = y2 - y1 - - # do not use expand unless necessary - # since expand is a custom ops - if px.shape[-1] != 4: - px = px.expand_as(dx) - if py.shape[-1] != 4: - py = py.expand_as(dy) - if pw.shape[-1] != 4: - pw = pw.expand_as(dw) - if px.shape[-1] != 4: - ph = ph.expand_as(dh) - - dx_width = pw * dx - dy_height = ph * dy - - max_ratio = np.abs(np.log(wh_ratio_clip)) - if add_ctr_clamp: - dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) - dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) - dw = torch.clamp(dw, max=max_ratio) - dh = torch.clamp(dh, max=max_ratio) - else: - dw = dw.clamp(min=-max_ratio, max=max_ratio) - dh = dh.clamp(min=-max_ratio, max=max_ratio) - # Use exp(network energy) to enlarge/shrink each roi - gw = pw * dw.exp() - gh = ph * dh.exp() - # Use network energy to shift the center of each roi - gx = px + dx_width - gy = py + dy_height - # Convert center-xy/width/height to top-left, bottom-right - x1 = gx - gw * 0.5 - y1 = gy - gh * 0.5 - x2 = gx + gw * 0.5 - y2 = gy + gh * 0.5 - - if clip_border and max_shape is not None: - from mmdeploy.codebase.mmdet.deploy import clip_bboxes - x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape) - - bboxes = 
torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) - return bboxes diff --git a/mmdeploy/codebase/mmdet/core/bbox/tblr_bbox_coder.py b/mmdeploy/codebase/mmdet/core/bbox/tblr_bbox_coder.py index 7e8d8c17e..51ce64acf 100644 --- a/mmdeploy/codebase/mmdet/core/bbox/tblr_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/core/bbox/tblr_bbox_coder.py @@ -71,75 +71,3 @@ def tblr2bboxes(ctx, bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size()) return bboxes - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes', - backend='ncnn') -def tblr2bboxes__ncnn(ctx, - priors, - tblr, - normalizer=4.0, - normalize_by_wh=True, - max_shape=None, - clip_border=True): - """Rewrite `tblr2bboxes` for ncnn backend. - - Batch dimension is not supported by ncnn, but supported by pytorch. - The negative value of axis in torch.cat is rewritten as corresponding - positive value to avoid axis shift. - - Args: - ctx (ContextCaller): The context with additional information. - priors (Tensor): Prior boxes in point form (x0, y0, x1, y1) - Shape: (N,4) or (B, N, 4). - tblr (Tensor): Coords of network output in tblr form - Shape: (N, 4) or (B, N, 4). - normalizer (Sequence[float] | float): Normalization parameter of - encoded boxes. By list, it represents the normalization factors at - tblr dims. By float, it is the unified normalization factor at all - dims. Default: 4.0 - normalize_by_wh (bool): Whether the tblr coordinates have been - normalized by the side length (wh) of prior bboxes. - max_shape (Sequence[int] or torch.Tensor or Sequence[ - Sequence[int]],optional): Maximum bounds for boxes, specifies - (H, W, C) or (H, W). If priors shape is (B, N, 4), then - the max_shape should be a Sequence[Sequence[int]] - and the length of max_shape should also be B. - clip_border (bool, optional): Whether clip the objects outside the - border of the image. Defaults to True. 
- - Return: - bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4) - """ - assert priors.size(0) == tblr.size(0) - if priors.ndim == 3: - assert priors.size(1) == tblr.size(1) - - loc_decode = tblr * normalizer - prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2 - if normalize_by_wh: - w = priors[..., 2:3] - priors[..., 0:1] - h = priors[..., 3:4] - priors[..., 1:2] - _h = h.unsqueeze(0).unsqueeze(-1) - _loc_h = loc_decode[..., 0:2].unsqueeze(0).unsqueeze(-1) - _w = w.unsqueeze(0).unsqueeze(-1) - _loc_w = loc_decode[..., 2:4].unsqueeze(0).unsqueeze(-1) - th = (_h * _loc_h).reshape(1, -1, 2) - tw = (_w * _loc_w).reshape(1, -1, 2) - loc_decode = torch.cat([th, tw], dim=2) - top = loc_decode[..., 0:1] - bottom = loc_decode[..., 1:2] - left = loc_decode[..., 2:3] - right = loc_decode[..., 3:4] - xmin = prior_centers[..., 0].unsqueeze(-1) - left - xmax = prior_centers[..., 0].unsqueeze(-1) + right - ymin = prior_centers[..., 1].unsqueeze(-1) - top - ymax = prior_centers[..., 1].unsqueeze(-1) + bottom - - if clip_border and max_shape is not None: - from mmdeploy.codebase.mmdet.deploy import clip_bboxes - xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape) - bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size()) - - return bboxes diff --git a/mmdeploy/codebase/mmdet/core/ops/detection_output.py b/mmdeploy/codebase/mmdet/core/ops/detection_output.py index 7d9a558fa..48d9f8441 100644 --- a/mmdeploy/codebase/mmdet/core/ops/detection_output.py +++ b/mmdeploy/codebase/mmdet/core/ops/detection_output.py @@ -37,7 +37,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function): nms_threshold=0.45, nms_top_k=100, keep_top_k=100, - num_class=81): + num_class=81, + target_stds=[0.1, 0.1, 0.2, 0.2]): """Symbolic function of dummy onnx DetectionOutput op for ncnn.""" return g.op( 'mmdeploy::DetectionOutput', @@ -49,6 +50,7 @@ class NcnnDetectionOutputOp(torch.autograd.Function): nms_top_k_i=nms_top_k, keep_top_k_i=keep_top_k, 
num_class_i=num_class, + vars_f=target_stds, outputs=1) @staticmethod @@ -60,7 +62,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function): nms_threshold=0.45, nms_top_k=100, keep_top_k=100, - num_class=81): + num_class=81, + target_stds=[0.1, 0.1, 0.2, 0.2]): """Forward function of dummy onnx DetectionOutput op for ncnn.""" return torch.rand(1, 100, 6) diff --git a/mmdeploy/codebase/mmdet/core/ops/prior_box.py b/mmdeploy/codebase/mmdet/core/ops/prior_box.py index e9685f14b..24efb02de 100644 --- a/mmdeploy/codebase/mmdet/core/ops/prior_box.py +++ b/mmdeploy/codebase/mmdet/core/ops/prior_box.py @@ -40,6 +40,8 @@ class NcnnPriorBoxOp(torch.autograd.Function): aspect_ratios=[2, 3], image_height=300, image_width=300, + step_height=300, + step_width=300, max_sizes=[300], min_sizes=[285], offset=0.5, @@ -51,6 +53,8 @@ class NcnnPriorBoxOp(torch.autograd.Function): aspect_ratios_f=aspect_ratios, image_height_i=image_height, image_width_i=image_width, + step_height_f=step_height, + step_width_f=step_width, max_sizes_f=max_sizes, min_sizes_f=min_sizes, offset_f=offset, @@ -63,6 +67,8 @@ class NcnnPriorBoxOp(torch.autograd.Function): aspect_ratios=[2, 3], image_height=300, image_width=300, + step_height=300, + step_width=300, max_sizes=[300], min_sizes=[285], offset=0.5, diff --git a/mmdeploy/codebase/mmdet/models/__init__.py b/mmdeploy/codebase/mmdet/models/__init__.py index 6a7c31956..aa95dce80 100644 --- a/mmdeploy/codebase/mmdet/models/__init__.py +++ b/mmdeploy/codebase/mmdet/models/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .dense_heads import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 from .roi_heads import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 7d1a4f717..5c9638620 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .anchor_head import anchor_head__get_bboxes__ncnn -from .base_dense_head import base_dense_head__get_bbox +from .base_dense_head import (base_dense_head__get_bbox, + base_dense_head__get_bboxes__ncnn) from .fcos_head import fcos_head__get_bboxes__ncnn from .fovea_head import fovea_head__get_bboxes from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn +from .ssd_head import ssd_head__get_bboxes__ncnn from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn from .yolox_head import yolox_head__get_bboxes __all__ = [ - 'anchor_head__get_bboxes__ncnn', 'fcos_head__get_bboxes__ncnn', - 'rpn_head__get_bboxes', 'rpn_head__get_bboxes__ncnn', - 'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn', - 'yolox_head__get_bboxes', 'base_dense_head__get_bbox', - 'fovea_head__get_bboxes' + 'fcos_head__get_bboxes__ncnn', 'rpn_head__get_bboxes', + 'rpn_head__get_bboxes__ncnn', 'yolov3_head__get_bboxes', + 'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes', + 'base_dense_head__get_bbox', 'fovea_head__get_bboxes', + 'base_dense_head__get_bboxes__ncnn', 'ssd_head__get_bboxes__ncnn' ] diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index eef8f0de0..6f608e5c9 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -1,7 +1,10 @@ import torch +from 
mmdet.core.bbox.coder.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder +from mmdet.core.bbox.coder.tblr_bbox_coder import TBLRBBoxCoder from mmdeploy.codebase.mmdet import (get_post_processing_params, multiclass_nms, pad_with_value) +from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import Backend, get_backend, is_dynamic_shape @@ -188,3 +191,205 @@ def base_dense_head__get_bbox(ctx, score_threshold=score_threshold, pre_top_k=pre_top_k, keep_top_k=keep_top_k) + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead' + '.get_bboxes', + backend='ncnn') +def base_dense_head__get_bboxes__ncnn(ctx, + self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Rewrite `get_bboxes` of AnchorHead for NCNN backend. + + Shape node and batch inference is not supported by ncnn. This function + transform dynamic shape to constant shape and remove batch inference. + + Args: + ctx (ContextCaller): The context with additional information. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. 
+ + + Returns: + output__ncnn (Tensor): outputs, shape is [N, num_det, 6]. + """ + assert len(cls_scores) == len(bbox_preds) + deploy_cfg = ctx.cfg + assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\ + only supports static shape.' + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. + with_score_factors = True + assert len(cls_scores) == len(score_factors) + batch_size = cls_scores[0].shape[0] + assert batch_size == 1, f'ncnn deployment requires batch size 1, \ + got {batch_size}.' + + num_levels = len(cls_scores) + if with_score_factors: + score_factor_list = score_factors + else: + score_factor_list = [None for _ in range(num_levels)] + + if isinstance(self.bbox_coder, DeltaXYWHBBoxCoder): + vars = torch.tensor(self.bbox_coder.stds) + elif isinstance(self.bbox_coder, TBLRBBoxCoder): + normalizer = self.bbox_coder.normalizer + if isinstance(normalizer, float): + vars = torch.tensor([normalizer, normalizer, 1, 1], + dtype=torch.float32) + else: + assert len(normalizer) == 4, f'normalizer of tblr must be 4,\ + got {len(normalizer)}' + + assert (normalizer[0] == normalizer[1] and normalizer[2] + == normalizer[3]), 'normalizer between top \ + and bottom, left and right must be the same value, or \ + we can not transform it to delta_xywh format.' 
+ + vars = torch.tensor([normalizer[0], normalizer[2], 1, 1], + dtype=torch.float32) + else: + vars = None + if isinstance(img_metas[0]['img_shape'][0], int): + assert isinstance(img_metas[0]['img_shape'][1], int) + img_height = img_metas[0]['img_shape'][0] + img_width = img_metas[0]['img_shape'][1] + else: + img_height = img_metas[0]['img_shape'][0].item() + img_width = img_metas[0]['img_shape'][1].item() + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, device=cls_scores[0].device) + batch_mlvl_priors = [] + for i in range(num_levels): + _priors = mlvl_priors[i].reshape(1, -1, mlvl_priors[i].shape[-1]) + x1 = _priors[:, :, 0:1] / img_width + y1 = _priors[:, :, 1:2] / img_height + x2 = _priors[:, :, 2:3] / img_width + y2 = _priors[:, :, 3:4] / img_height + priors = torch.cat([x1, y1, x2, y2], dim=2).data + batch_mlvl_priors.append(priors) + + cfg = self.test_cfg if cfg is None else cfg + + batch_mlvl_bboxes = [] + batch_mlvl_scores = [] + batch_mlvl_score_factors = [] + + for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ + enumerate(zip(cls_scores, bbox_preds, + score_factor_list, batch_mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + # NCNN needs 3 dimensions to reshape when including -1 parameter in + # width or height dimension. + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) + if with_score_factors: + score_factor = score_factor.permute(0, 2, 3, 1).\ + reshape(batch_size, -1, 1).sigmoid() + cls_score = cls_score.permute(0, 2, 3, 1).\ + reshape(batch_size, -1, self.cls_out_channels) + # NCNN DetectionOutput op needs num_class + 1 classes. So if sigmoid + # score, we should padding background class according to mmdetection + # num_class definition. 
+ if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + dummy_background_score = torch.zeros( + batch_size, cls_score.shape[1], 1, device=cls_score.device) + scores = torch.cat([scores, dummy_background_score], dim=2) + else: + scores = cls_score.softmax(-1) + batch_mlvl_bboxes.append(bbox_pred) + batch_mlvl_scores.append(scores) + batch_mlvl_score_factors.append(score_factor) + + batch_mlvl_priors = torch.cat(batch_mlvl_priors, dim=1) + batch_mlvl_scores = torch.cat(batch_mlvl_scores, dim=1) + batch_mlvl_bboxes = torch.cat(batch_mlvl_bboxes, dim=1) + batch_mlvl_scores = torch.cat([ + batch_mlvl_scores[:, :, self.num_classes:], + batch_mlvl_scores[:, :, 0:self.num_classes] + ], + dim=2) + if isinstance(self.bbox_coder, TBLRBBoxCoder): + batch_mlvl_bboxes = _tblr_pred_to_delta_xywh_pred( + batch_mlvl_bboxes, vars[0:2]) + # flatten for ncnn DetectionOutput op inputs. + batch_mlvl_vars = vars.expand_as(batch_mlvl_priors) + batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, 1, -1) + batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1) + batch_mlvl_priors = batch_mlvl_priors.reshape(batch_size, 1, -1) + batch_mlvl_vars = batch_mlvl_vars.reshape(batch_size, 1, -1) + batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)\ + .data + + post_params = get_post_processing_params(ctx.cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + output__ncnn = ncnn_detection_output_forward( + batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors, + score_threshold, iou_threshold, pre_top_k, keep_top_k, + self.num_classes + 1, + vars.cpu().detach().numpy()) + + return output__ncnn + + +def _tblr_pred_to_delta_xywh_pred(bbox_pred: torch.Tensor, + normalizer: torch.Tensor) -> torch.Tensor: + """Transform tblr format bbox prediction to delta_xywh format 
for ncnn. + + An internal function for transforming tblr format bbox prediction to + delta_xywh format. NCNN DetectionOutput layer needs delta_xywh format + bbox_pred as input. + + Args: + bbox_pred (Tensor): The bbox prediction of tblr format, has shape + (N, num_det, 4). + normalizer (Tensor): The normalizer scale of bbox horizon and + vertical coordinates, has shape (2,). + + Returns: + Tensor: The delta_xywh format bbox predictions. + """ + top = bbox_pred[:, :, 0:1] + bottom = bbox_pred[:, :, 1:2] + left = bbox_pred[:, :, 2:3] + right = bbox_pred[:, :, 3:4] + h = (top + bottom) * normalizer[0] + w = (left + right) * normalizer[1] + + _dwh = torch.cat([w, h], dim=2) + assert torch.all(_dwh >= 0), 'wh must be positive before log.' + dwh = torch.log(_dwh) + + return torch.cat([(right - left) / 2, (bottom - top) / 2, dwh], dim=2) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py new file mode 100644 index 000000000..3b732398b --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.core.ops import (ncnn_detection_output_forward, + ncnn_prior_box_forward) +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.SSDHead.get_bboxes', backend='ncnn') +def ssd_head__get_bboxes__ncnn(ctx, + self, + cls_scores, + bbox_preds, + img_metas, + with_nms=True, + cfg=None, + **kwargs): + """Rewrite `get_bboxes` of SSDHead for NCNN backend. + + This rewriter using ncnn PriorBox and DetectionOutput layer to + support dynamic deployment, and has higher speed. + + Args: + ctx (ContextCaller): The context with additional information. 
+ cls_scores (list[Tensor]): Box scores for each level in the + feature pyramid, has shape + (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each + level in the feature pyramid, has shape + (N, num_anchors * 4, H, W). + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + with_nms (bool): If True, do nms before return boxes. + Default: True. + cfg (mmcv.Config | None): Test / postprocessing configuration, + if None, test_cfg would be used. + Default: None. + + Returns: + Tensor: outputs, shape is [N, num_det, 6]. + """ + assert len(cls_scores) == len(bbox_preds) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + aspect_ratio = [ + ratio[ratio > 1].detach().cpu().numpy() + for ratio in self.anchor_generator.ratios + ] + strides = self.anchor_generator.strides + min_sizes = self.anchor_generator.base_sizes + if is_dynamic_flag: + max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1].tolist() + img_height = img_metas[0]['img_shape'][0].item() + img_width = img_metas[0]['img_shape'][1].item() + else: + max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1] + img_height = img_metas[0]['img_shape'][0] + img_width = img_metas[0]['img_shape'][1] + + # if no reshape, concat will be error in ncnn. 
+ mlvl_anchors = [ + ncnn_prior_box_forward(cls_scores[i], aspect_ratio[i], img_height, + img_width, strides[i][0], strides[i][1], + max_sizes[i:i + 1], + min_sizes[i:i + 1]).reshape(1, 2, -1) + for i in range(num_levels) + ] + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + + cfg = self.test_cfg if cfg is None else cfg + assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors) + batch_size = 1 + + mlvl_valid_bboxes = [] + mlvl_scores = [] + for level_id, cls_score, bbox_pred in zip( + range(num_levels), mlvl_cls_scores, mlvl_bbox_preds): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(batch_size, -1, + self.cls_out_channels) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) + + mlvl_valid_bboxes.append(bbox_pred) + mlvl_scores.append(cls_score) + + # NCNN DetectionOutput layer uses background class at 0 position, but + # in mmdetection, background class is at self.num_classes position. + # We should adapt for ncnn. 
+ batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + if self.use_sigmoid_cls: + batch_mlvl_scores = batch_mlvl_scores.sigmoid() + else: + batch_mlvl_scores = batch_mlvl_scores.softmax(-1) + batch_mlvl_anchors = torch.cat(mlvl_anchors, dim=2) + batch_mlvl_scores = torch.cat([ + batch_mlvl_scores[:, :, self.num_classes:], + batch_mlvl_scores[:, :, 0:self.num_classes] + ], + dim=2) + batch_mlvl_valid_bboxes = batch_mlvl_valid_bboxes.reshape( + batch_size, 1, -1) + batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1) + batch_mlvl_anchors = batch_mlvl_anchors.reshape(batch_size, 2, -1) + + post_params = get_post_processing_params(deploy_cfg) + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + + output__ncnn = ncnn_detection_output_forward( + batch_mlvl_valid_bboxes, batch_mlvl_scores, batch_mlvl_anchors, + score_threshold, iou_threshold, pre_top_k, keep_top_k, + self.num_classes + 1) + + return output__ncnn diff --git a/mmdeploy/codebase/mmdet/models/necks.py b/mmdeploy/codebase/mmdet/models/necks.py new file mode 100644 index 000000000..0e3ecbcff --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/necks.py @@ -0,0 +1,10 @@ +import torch + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.necks.ssd_neck.L2Norm.forward') +def l2norm__forward__default(ctx, self, x): + return torch.nn.functional.normalize( + x, dim=1) * self.weight[None, :, None, None] diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 34690c147..28cc4e4c8 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -139,114 +139,6 @@ def 
get_single_roi_extractor(): return model -@pytest.mark.parametrize('backend_type', [Backend.NCNN]) -@pytest.mark.parametrize('is_ssd', [True, False]) -def test_anchor_head_get_bboxes(backend_type: Backend, is_ssd: bool): - """Test get_bboxes rewrite of anchor head.""" - check_backend(backend_type) - if is_ssd: - anchor_head = get_ssd_head_model() - else: - anchor_head = get_anchor_head_model() - anchor_head.cpu().eval() - s = 128 - img_metas = [{ - 'scale_factor': np.ones(4), - 'pad_shape': (s, s, 3), - 'img_shape': (s, s, 3) - }] - if is_ssd: - output_names = ['output'] - else: - output_names = ['dets', 'labels'] - deploy_cfg = mmcv.Config( - dict( - backend_config=dict(type=backend_type.value), - onnx_config=dict(output_names=output_names, input_shape=None), - codebase_config=dict( - type='mmdet', - task='ObjectDetection', - post_processing=dict( - score_threshold=0.05, - iou_threshold=0.5, - max_output_boxes_per_class=200, - pre_top_k=5000, - keep_top_k=100, - background_label_id=-1, - )))) - - if not is_ssd: - # For the general anchor_head: - # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16), - # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2). 
- # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16), - # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2) - seed_everything(1234) - cls_score = [ - torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1) - ] - seed_everything(5678) - bboxes = [ - torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1) - ] - else: - # For the ssd_head: - # the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10), - # (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1) - # the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10), - # (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1) - feat_shape = [20, 10, 5, 3, 2, 1] - num_prior = 6 - seed_everything(1234) - cls_score = [ - torch.rand(1, 30, feat_shape[i], feat_shape[i]) - for i in range(num_prior) - ] - seed_everything(5678) - bboxes = [ - torch.rand(1, 24, feat_shape[i], feat_shape[i]) - for i in range(num_prior) - ] - - # to get outputs of pytorch model - model_inputs = { - 'cls_scores': cls_score, - 'bbox_preds': bboxes, - 'img_metas': img_metas - } - model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs) - - # to get outputs of onnx model after rewrite - img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32) - wrapped_model = WrapModel( - anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True) - rewrite_inputs = { - 'cls_scores': cls_score, - 'bbox_preds': bboxes, - } - rewrite_outputs, is_backend_output = get_rewrite_outputs( - wrapped_model=wrapped_model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg) - - if is_backend_output: - if isinstance(rewrite_outputs, dict): - rewrite_outputs = convert_to_list(rewrite_outputs, output_names) - for model_output, rewrite_output in zip(model_outputs[0], - rewrite_outputs): - model_output = model_output.squeeze().cpu().numpy() - rewrite_output = rewrite_output.squeeze().cpu().numpy() - # hard code to make two tensors with the same shape - # rewrite and original codes applied different nms strategy - assert 
np.allclose( - model_output[:rewrite_output.shape[0]], - rewrite_output, - rtol=1e-03, - atol=1e-05) - else: - assert rewrite_outputs is not None - - @pytest.mark.parametrize('backend_type', [Backend.NCNN]) def test_get_bboxes_of_fcos_head(backend_type: Backend): check_backend(backend_type) @@ -1126,8 +1018,8 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend): @pytest.mark.parametrize('backend_type', - [Backend.ONNXRUNTIME, Backend.OPENVINO]) -def test_get_bboxes_of_base_dense_head(backend_type: Backend): + [Backend.ONNXRUNTIME, Backend.NCNN, Backend.OPENVINO]) +def test_base_dense_head_get_bboxes(backend_type: Backend): """Test get_bboxes rewrite of base dense head.""" check_backend(backend_type) anchor_head = get_anchor_head_model() @@ -1139,7 +1031,10 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend): 'img_shape': (s, s, 3) }] - output_names = ['dets', 'labels'] + if backend_type != Backend.NCNN: + output_names = ['dets', 'labels'] + else: + output_names = ['output'] deploy_cfg = mmcv.Config( dict( backend_config=dict(type=backend_type.value), @@ -1204,3 +1099,89 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend): atol=1e-05) else: assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.NCNN]) +def test_ssd_head_get_bboxes(backend_type: Backend): + """Test get_bboxes rewrite of ssd head.""" + check_backend(backend_type) + ssd_head = get_ssd_head_model() + ssd_head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + output_names = ['output'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, 
+ )))) + + # For the ssd_head: + # the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10), + # (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1) + # the bboxes' size: (1, 24, 20, 20), (1, 24, 10, 10), + # (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1) + feat_shape = [20, 10, 5, 3, 2, 1] + num_prior = 6 + seed_everything(1234) + cls_score = [ + torch.rand(1, 30, feat_shape[i], feat_shape[i]) + for i in range(num_prior) + ] + seed_everything(5678) + bboxes = [ + torch.rand(1, 24, feat_shape[i], feat_shape[i]) + for i in range(num_prior) + ] + + # to get outputs of pytorch model + model_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + 'img_metas': img_metas + } + model_outputs = get_model_outputs(ssd_head, 'get_bboxes', model_inputs) + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32) + wrapped_model = WrapModel( + ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + if isinstance(rewrite_outputs, dict): + rewrite_outputs = convert_to_list(rewrite_outputs, output_names) + for model_output, rewrite_output in zip(model_outputs[0], + rewrite_outputs): + model_output = model_output.squeeze().cpu().numpy() + rewrite_output = rewrite_output.squeeze().cpu().numpy() + # hard-coded slice so that both tensors have the same shape: + # the rewritten and original code apply different NMS strategies + assert np.allclose( + model_output[:rewrite_output.shape[0]], + rewrite_output, + rtol=1e-03, + atol=1e-05) + else: + assert rewrite_outputs is not None