[Feature] Support end2end mmdet2.19 retina mobilessd (#286)

* support end2end mmdet2.19 retina mobilessd

* fix yapf

* add end2end fsaf

* fix lint

* fix comments

* fix lint

* add static configs

* fix docformatter

* move ssdhead

* add rewrite for l2norm

* fix ncnn ssd

* fix isort

* rename config

* add ssd_head_ut

* fix string

* align ssd

* remove unused bbox rewriter

Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: maningsheng <mnsheng@yeah.net>
This commit is contained in:
hanrui1sensetime 2021-12-17 10:46:54 +08:00 committed by GitHub
parent 31e8aed862
commit 3e8237d8bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 468 additions and 313 deletions

View File

@ -0,0 +1,4 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[300, 300])

View File

@ -0,0 +1,4 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[1344, 800])

View File

@ -3765,11 +3765,16 @@ int main(int argc, char** argv) {
int nms_top_k = get_node_attr_i(node, "nms_top_k");
int keep_top_k = get_node_attr_i(node, "keep_top_k");
int num_class = get_node_attr_i(node, "num_class");
std::vector<float> vars = get_node_attr_af(node, "vars");
fprintf(pp, " 0=%d", num_class);
fprintf(pp, " 1=%f", nms_threshold);
fprintf(pp, " 2=%d", nms_top_k);
fprintf(pp, " 3=%d", keep_top_k);
fprintf(pp, " 4=%f", score_threshold);
fprintf(pp, " 5=%f", vars[0]);
fprintf(pp, " 6=%f", vars[1]);
fprintf(pp, " 7=%f", vars[2]);
fprintf(pp, " 8=%f", vars[3]);
} else if (op == "Div") {
int op_type = 3;
fprintf(pp, " 0=%d", op_type);
@ -4660,10 +4665,14 @@ int main(int argc, char** argv) {
}
int image_width = get_node_attr_i(node, "image_width");
int image_height = get_node_attr_i(node, "image_height");
float step_width = get_node_attr_f(node, "step_width");
float step_height = get_node_attr_f(node, "step_height");
float offset = get_node_attr_f(node, "offset");
int step_mmdetection = get_node_attr_i(node, "step_mmdetection");
fprintf(pp, " 9=%d", image_width);
fprintf(pp, " 10=%d", image_height);
fprintf(pp, " 11=%f", step_width);
fprintf(pp, " 12=%f", step_height);
fprintf(pp, " 13=%f", offset);
fprintf(pp, " 14=%d", step_mmdetection);
} else if (op == "PixelShuffle") {

View File

@ -141,124 +141,3 @@ def delta2bbox(ctx,
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
return bboxes
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox', # noqa
backend='ncnn')
def delta2bbox__ncnn(ctx,
rois,
deltas,
means=(0., 0., 0., 0.),
stds=(1., 1., 1., 1.),
max_shape=None,
wh_ratio_clip=16 / 1000,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Rewrite `delta2bbox` for ncnn backend.
Batch dimension is not supported by ncnn, but supported by pytorch.
NCNN regards the lowest two dimensions as continuous address with byte
alignment, so the lowest two dimensions are not absolutely independent.
Reshape operator with -1 arguments should operates ncnn::Mat with
dimension >= 3.
Args:
ctx (ContextCaller): The context with additional information.
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
deltas (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
when rois is a grid of anchors.Offset encoding follows [1]_.
means (Sequence[float]): Denormalizing means for delta coordinates
stds (Sequence[float]): Denormalizing standard deviation for delta
coordinates
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
wh_ratio_clip (float): Maximum aspect ratio for boxes.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
add_ctr_clamp (bool): Whether to add center clamp, when added, the
predicted box is clamped is its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Return:
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
br_x, br_y.
"""
means = deltas.new_tensor(means).view(1, 1,
-1).repeat(1, deltas.size(-2),
deltas.size(-1) // 4).data
stds = deltas.new_tensor(stds).view(1, 1,
-1).repeat(1, deltas.size(-2),
deltas.size(-1) // 4).data
denorm_deltas = deltas * stds + means
if denorm_deltas.shape[-1] == 4:
dx = denorm_deltas[..., 0:1]
dy = denorm_deltas[..., 1:2]
dw = denorm_deltas[..., 2:3]
dh = denorm_deltas[..., 3:4]
else:
dx = denorm_deltas[..., 0::4]
dy = denorm_deltas[..., 1::4]
dw = denorm_deltas[..., 2::4]
dh = denorm_deltas[..., 3::4]
x1, y1 = rois[..., 0:1], rois[..., 1:2]
x2, y2 = rois[..., 2:3], rois[..., 3:4]
# Compute center of each roi
px = (x1 + x2) * 0.5
py = (y1 + y2) * 0.5
# Compute width/height of each roi
pw = x2 - x1
ph = y2 - y1
# do not use expand unless necessary
# since expand is a custom ops
if px.shape[-1] != 4:
px = px.expand_as(dx)
if py.shape[-1] != 4:
py = py.expand_as(dy)
if pw.shape[-1] != 4:
pw = pw.expand_as(dw)
if px.shape[-1] != 4:
ph = ph.expand_as(dh)
dx_width = pw * dx
dy_height = ph * dy
max_ratio = np.abs(np.log(wh_ratio_clip))
if add_ctr_clamp:
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
dw = torch.clamp(dw, max=max_ratio)
dh = torch.clamp(dh, max=max_ratio)
else:
dw = dw.clamp(min=-max_ratio, max=max_ratio)
dh = dh.clamp(min=-max_ratio, max=max_ratio)
# Use exp(network energy) to enlarge/shrink each roi
gw = pw * dw.exp()
gh = ph * dh.exp()
# Use network energy to shift the center of each roi
gx = px + dx_width
gy = py + dy_height
# Convert center-xy/width/height to top-left, bottom-right
x1 = gx - gw * 0.5
y1 = gy - gh * 0.5
x2 = gx + gw * 0.5
y2 = gy + gh * 0.5
if clip_border and max_shape is not None:
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
return bboxes

View File

@ -71,75 +71,3 @@ def tblr2bboxes(ctx,
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
return bboxes
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
backend='ncnn')
def tblr2bboxes__ncnn(ctx,
priors,
tblr,
normalizer=4.0,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite `tblr2bboxes` for ncnn backend.
Batch dimension is not supported by ncnn, but supported by pytorch.
The negative value of axis in torch.cat is rewritten as corresponding
positive value to avoid axis shift.
Args:
ctx (ContextCaller): The context with additional information.
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
Shape: (N,4) or (B, N, 4).
tblr (Tensor): Coords of network output in tblr form
Shape: (N, 4) or (B, N, 4).
normalizer (Sequence[float] | float): Normalization parameter of
encoded boxes. By list, it represents the normalization factors at
tblr dims. By float, it is the unified normalization factor at all
dims. Default: 4.0
normalize_by_wh (bool): Whether the tblr coordinates have been
normalized by the side length (wh) of prior bboxes.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]]
and the length of max_shape should also be B.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
Return:
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
"""
assert priors.size(0) == tblr.size(0)
if priors.ndim == 3:
assert priors.size(1) == tblr.size(1)
loc_decode = tblr * normalizer
prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
if normalize_by_wh:
w = priors[..., 2:3] - priors[..., 0:1]
h = priors[..., 3:4] - priors[..., 1:2]
_h = h.unsqueeze(0).unsqueeze(-1)
_loc_h = loc_decode[..., 0:2].unsqueeze(0).unsqueeze(-1)
_w = w.unsqueeze(0).unsqueeze(-1)
_loc_w = loc_decode[..., 2:4].unsqueeze(0).unsqueeze(-1)
th = (_h * _loc_h).reshape(1, -1, 2)
tw = (_w * _loc_w).reshape(1, -1, 2)
loc_decode = torch.cat([th, tw], dim=2)
top = loc_decode[..., 0:1]
bottom = loc_decode[..., 1:2]
left = loc_decode[..., 2:3]
right = loc_decode[..., 3:4]
xmin = prior_centers[..., 0].unsqueeze(-1) - left
xmax = prior_centers[..., 0].unsqueeze(-1) + right
ymin = prior_centers[..., 1].unsqueeze(-1) - top
ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
if clip_border and max_shape is not None:
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
return bboxes

View File

@ -37,7 +37,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
nms_threshold=0.45,
nms_top_k=100,
keep_top_k=100,
num_class=81):
num_class=81,
target_stds=[0.1, 0.1, 0.2, 0.2]):
"""Symbolic function of dummy onnx DetectionOutput op for ncnn."""
return g.op(
'mmdeploy::DetectionOutput',
@ -49,6 +50,7 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
nms_top_k_i=nms_top_k,
keep_top_k_i=keep_top_k,
num_class_i=num_class,
vars_f=target_stds,
outputs=1)
@staticmethod
@ -60,7 +62,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
nms_threshold=0.45,
nms_top_k=100,
keep_top_k=100,
num_class=81):
num_class=81,
target_stds=[0.1, 0.1, 0.2, 0.2]):
"""Forward function of dummy onnx DetectionOutput op for ncnn."""
return torch.rand(1, 100, 6)

View File

@ -40,6 +40,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
aspect_ratios=[2, 3],
image_height=300,
image_width=300,
step_height=300,
step_width=300,
max_sizes=[300],
min_sizes=[285],
offset=0.5,
@ -51,6 +53,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
aspect_ratios_f=aspect_ratios,
image_height_i=image_height,
image_width_i=image_width,
step_height_f=step_height,
step_width_f=step_width,
max_sizes_f=max_sizes,
min_sizes_f=min_sizes,
offset_f=offset,
@ -63,6 +67,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
aspect_ratios=[2, 3],
image_height=300,
image_width=300,
step_height=300,
step_width=300,
max_sizes=[300],
min_sizes=[285],
offset=0.5,

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403

View File

@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor_head import anchor_head__get_bboxes__ncnn
from .base_dense_head import base_dense_head__get_bbox
from .base_dense_head import (base_dense_head__get_bbox,
base_dense_head__get_bboxes__ncnn)
from .fcos_head import fcos_head__get_bboxes__ncnn
from .fovea_head import fovea_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
from .ssd_head import ssd_head__get_bboxes__ncnn
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
from .yolox_head import yolox_head__get_bboxes
__all__ = [
'anchor_head__get_bboxes__ncnn', 'fcos_head__get_bboxes__ncnn',
'rpn_head__get_bboxes', 'rpn_head__get_bboxes__ncnn',
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
'fovea_head__get_bboxes'
'fcos_head__get_bboxes__ncnn', 'rpn_head__get_bboxes',
'rpn_head__get_bboxes__ncnn', 'yolov3_head__get_bboxes',
'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes',
'base_dense_head__get_bbox', 'fovea_head__get_bboxes',
'base_dense_head__get_bboxes__ncnn', 'ssd_head__get_bboxes__ncnn'
]

View File

@ -1,7 +1,10 @@
import torch
from mmdet.core.bbox.coder.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from mmdet.core.bbox.coder.tblr_bbox_coder import TBLRBBoxCoder
from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
@ -188,3 +191,205 @@ def base_dense_head__get_bbox(ctx,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
'.get_bboxes',
backend='ncnn')
def base_dense_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
score_factors=None,
img_metas=None,
cfg=None,
rescale=False,
with_nms=True,
**kwargs):
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
Shape node and batch inference is not supported by ncnn. This function
transform dynamic shape to constant shape and remove batch inference.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Default None.
img_metas (list[dict], Optional): Image meta info. Default None.
cfg (mmcv.Config, Optional): Test / postprocessing configuration,
if None, test_cfg would be used. Default None.
rescale (bool): If True, return boxes in original image space.
Default False.
with_nms (bool): If True, do nms before return boxes.
Default True.
Returns:
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\
only supports static shape.'
if score_factors is None:
# e.g. Retina, FreeAnchor, Foveabox, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
with_score_factors = True
assert len(cls_scores) == len(score_factors)
batch_size = cls_scores[0].shape[0]
assert batch_size == 1, f'ncnn deployment requires batch size 1, \
got {batch_size}.'
num_levels = len(cls_scores)
if with_score_factors:
score_factor_list = score_factors
else:
score_factor_list = [None for _ in range(num_levels)]
if isinstance(self.bbox_coder, DeltaXYWHBBoxCoder):
vars = torch.tensor(self.bbox_coder.stds)
elif isinstance(self.bbox_coder, TBLRBBoxCoder):
normalizer = self.bbox_coder.normalizer
if isinstance(normalizer, float):
vars = torch.tensor([normalizer, normalizer, 1, 1],
dtype=torch.float32)
else:
assert len(normalizer) == 4, f'normalizer of tblr must be 4,\
got {len(normalizer)}'
assert (normalizer[0] == normalizer[1] and normalizer[2]
== normalizer[3]), 'normalizer between top \
and bottom, left and right must be the same value, or \
we can not transform it to delta_xywh format.'
vars = torch.tensor([normalizer[0], normalizer[2], 1, 1],
dtype=torch.float32)
else:
vars = None
if isinstance(img_metas[0]['img_shape'][0], int):
assert isinstance(img_metas[0]['img_shape'][1], int)
img_height = img_metas[0]['img_shape'][0]
img_width = img_metas[0]['img_shape'][1]
else:
img_height = img_metas[0]['img_shape'][0].item()
img_width = img_metas[0]['img_shape'][1].item()
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=cls_scores[0].device)
batch_mlvl_priors = []
for i in range(num_levels):
_priors = mlvl_priors[i].reshape(1, -1, mlvl_priors[i].shape[-1])
x1 = _priors[:, :, 0:1] / img_width
y1 = _priors[:, :, 1:2] / img_height
x2 = _priors[:, :, 2:3] / img_width
y2 = _priors[:, :, 3:4] / img_height
priors = torch.cat([x1, y1, x2, y2], dim=2).data
batch_mlvl_priors.append(priors)
cfg = self.test_cfg if cfg is None else cfg
batch_mlvl_bboxes = []
batch_mlvl_scores = []
batch_mlvl_score_factors = []
for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
enumerate(zip(cls_scores, bbox_preds,
score_factor_list, batch_mlvl_priors)):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
# NCNN needs 3 dimensions to reshape when including -1 parameter in
# width or height dimension.
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if with_score_factors:
score_factor = score_factor.permute(0, 2, 3, 1).\
reshape(batch_size, -1, 1).sigmoid()
cls_score = cls_score.permute(0, 2, 3, 1).\
reshape(batch_size, -1, self.cls_out_channels)
# NCNN DetectionOutput op needs num_class + 1 classes. So if sigmoid
# score, we should padding background class according to mmdetection
# num_class definition.
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
dummy_background_score = torch.zeros(
batch_size, cls_score.shape[1], 1, device=cls_score.device)
scores = torch.cat([scores, dummy_background_score], dim=2)
else:
scores = cls_score.softmax(-1)
batch_mlvl_bboxes.append(bbox_pred)
batch_mlvl_scores.append(scores)
batch_mlvl_score_factors.append(score_factor)
batch_mlvl_priors = torch.cat(batch_mlvl_priors, dim=1)
batch_mlvl_scores = torch.cat(batch_mlvl_scores, dim=1)
batch_mlvl_bboxes = torch.cat(batch_mlvl_bboxes, dim=1)
batch_mlvl_scores = torch.cat([
batch_mlvl_scores[:, :, self.num_classes:],
batch_mlvl_scores[:, :, 0:self.num_classes]
],
dim=2)
if isinstance(self.bbox_coder, TBLRBBoxCoder):
batch_mlvl_bboxes = _tblr_pred_to_delta_xywh_pred(
batch_mlvl_bboxes, vars[0:2])
# flatten for ncnn DetectionOutput op inputs.
batch_mlvl_vars = vars.expand_as(batch_mlvl_priors)
batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, 1, -1)
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
batch_mlvl_priors = batch_mlvl_priors.reshape(batch_size, 1, -1)
batch_mlvl_vars = batch_mlvl_vars.reshape(batch_size, 1, -1)
batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)\
.data
post_params = get_post_processing_params(ctx.cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
output__ncnn = ncnn_detection_output_forward(
batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
score_threshold, iou_threshold, pre_top_k, keep_top_k,
self.num_classes + 1,
vars.cpu().detach().numpy())
return output__ncnn
def _tblr_pred_to_delta_xywh_pred(bbox_pred: torch.Tensor,
normalizer: torch.Tensor) -> torch.Tensor:
"""Transform tblr format bbox prediction to delta_xywh format for ncnn.
An internal function for transforming tblr format bbox prediction to
delta_xywh format. NCNN DetectionOutput layer needs delta_xywh format
bbox_pred as input.
Args:
bbox_pred (Tensor): The bbox prediction of tblr format, has shape
(N, num_det, 4).
normalizer (Tensor): The normalizer scale of bbox horizon and
vertical coordinates, has shape (2,).
Returns:
Tensor: The delta_xywh format bbox predictions.
"""
top = bbox_pred[:, :, 0:1]
bottom = bbox_pred[:, :, 1:2]
left = bbox_pred[:, :, 2:3]
right = bbox_pred[:, :, 3:4]
h = (top + bottom) * normalizer[0]
w = (left + right) * normalizer[1]
_dwh = torch.cat([w, h], dim=2)
assert torch.all(_dwh >= 0), 'wh must be positive before log.'
dwh = torch.log(_dwh)
return torch.cat([(right - left) / 2, (bottom - top) / 2, dwh], dim=2)

View File

@ -0,0 +1,124 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.codebase.mmdet.core.ops import (ncnn_detection_output_forward,
ncnn_prior_box_forward)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.SSDHead.get_bboxes', backend='ncnn')
def ssd_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
bbox_preds,
img_metas,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` of SSDHead for NCNN backend.
This rewriter using ncnn PriorBox and DetectionOutput layer to
support dynamic deployment, and has higher speed.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
Tensor: outputs, shape is [N, num_det, 6].
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
aspect_ratio = [
ratio[ratio > 1].detach().cpu().numpy()
for ratio in self.anchor_generator.ratios
]
strides = self.anchor_generator.strides
min_sizes = self.anchor_generator.base_sizes
if is_dynamic_flag:
max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1].tolist()
img_height = img_metas[0]['img_shape'][0].item()
img_width = img_metas[0]['img_shape'][1].item()
else:
max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1]
img_height = img_metas[0]['img_shape'][0]
img_width = img_metas[0]['img_shape'][1]
# if no reshape, concat will be error in ncnn.
mlvl_anchors = [
ncnn_prior_box_forward(cls_scores[i], aspect_ratio[i], img_height,
img_width, strides[i][0], strides[i][1],
max_sizes[i:i + 1],
min_sizes[i:i + 1]).reshape(1, 2, -1)
for i in range(num_levels)
]
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
cfg = self.test_cfg if cfg is None else cfg
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
batch_size = 1
mlvl_valid_bboxes = []
mlvl_scores = []
for level_id, cls_score, bbox_pred in zip(
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
self.cls_out_channels)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_scores.append(cls_score)
# NCNN DetectionOutput layer uses background class at 0 position, but
# in mmdetection, background class is at self.num_classes position.
# We should adapt for ncnn.
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
if self.use_sigmoid_cls:
batch_mlvl_scores = batch_mlvl_scores.sigmoid()
else:
batch_mlvl_scores = batch_mlvl_scores.softmax(-1)
batch_mlvl_anchors = torch.cat(mlvl_anchors, dim=2)
batch_mlvl_scores = torch.cat([
batch_mlvl_scores[:, :, self.num_classes:],
batch_mlvl_scores[:, :, 0:self.num_classes]
],
dim=2)
batch_mlvl_valid_bboxes = batch_mlvl_valid_bboxes.reshape(
batch_size, 1, -1)
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
batch_mlvl_anchors = batch_mlvl_anchors.reshape(batch_size, 2, -1)
post_params = get_post_processing_params(deploy_cfg)
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
output__ncnn = ncnn_detection_output_forward(
batch_mlvl_valid_bboxes, batch_mlvl_scores, batch_mlvl_anchors,
score_threshold, iou_threshold, pre_top_k, keep_top_k,
self.num_classes + 1)
return output__ncnn

View File

@ -0,0 +1,10 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward')
def l2norm__forward__default(ctx, self, x):
return torch.nn.functional.normalize(
x, dim=1) * self.weight[None, :, None, None]

View File

@ -139,114 +139,6 @@ def get_single_roi_extractor():
return model
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
@pytest.mark.parametrize('is_ssd', [True, False])
def test_anchor_head_get_bboxes(backend_type: Backend, is_ssd: bool):
"""Test get_bboxes rewrite of anchor head."""
check_backend(backend_type)
if is_ssd:
anchor_head = get_ssd_head_model()
else:
anchor_head = get_anchor_head_model()
anchor_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
if is_ssd:
output_names = ['output']
else:
output_names = ['dets', 'labels']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))))
if not is_ssd:
# For the general anchor_head:
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
else:
# For the ssd_head:
# the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10),
# (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1)
# the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10),
# (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1)
feat_shape = [20, 10, 5, 3, 2, 1]
num_prior = 6
seed_everything(1234)
cls_score = [
torch.rand(1, 30, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
seed_everything(5678)
bboxes = [
torch.rand(1, 24, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
wrapped_model = WrapModel(
anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze().cpu().numpy()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_get_bboxes_of_fcos_head(backend_type: Backend):
check_backend(backend_type)
@ -1126,8 +1018,8 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend):
@pytest.mark.parametrize('backend_type',
[Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_get_bboxes_of_base_dense_head(backend_type: Backend):
[Backend.ONNXRUNTIME, Backend.NCNN, Backend.OPENVINO])
def test_base_dense_head_get_bboxes(backend_type: Backend):
"""Test get_bboxes rewrite of base dense head."""
check_backend(backend_type)
anchor_head = get_anchor_head_model()
@ -1139,7 +1031,10 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend):
'img_shape': (s, s, 3)
}]
output_names = ['dets', 'labels']
if backend_type != Backend.NCNN:
output_names = ['dets', 'labels']
else:
output_names = ['output']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
@ -1204,3 +1099,89 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend):
atol=1e-05)
else:
assert rewrite_outputs is not None
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_ssd_head_get_bboxes(backend_type: Backend):
"""Test get_bboxes rewrite of anchor head."""
check_backend(backend_type)
ssd_head = get_ssd_head_model()
ssd_head.cpu().eval()
s = 128
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['output']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))))
# For the ssd_head:
# the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10),
# (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1)
# the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10),
# (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1)
feat_shape = [20, 10, 5, 3, 2, 1]
num_prior = 6
seed_everything(1234)
cls_score = [
torch.rand(1, 30, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
seed_everything(5678)
bboxes = [
torch.rand(1, 24, feat_shape[i], feat_shape[i])
for i in range(num_prior)
]
# to get outputs of pytorch model
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'img_metas': img_metas
}
model_outputs = get_model_outputs(ssd_head, 'get_bboxes', model_inputs)
# to get outputs of onnx model after rewrite
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
wrapped_model = WrapModel(
ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze().cpu().numpy()
# hard code to make two tensors with the same shape
# rewrite and original codes applied different nms strategy
assert np.allclose(
model_output[:rewrite_output.shape[0]],
rewrite_output,
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None