[Feature] Support FCOS and FSAF for mmdetection ncnn deployment (#63)

* support ncnn mmdet fcos.py

* support fsaf

* fix_lint

* fix yapf

* fix clang-format

* Delete output_ncnn.jpg

* Delete output_pytorch.jpg

* remove comments and fix typo

* fix blank line

* Fix typo

* Remove unnecessary comments

* Add comment

* Add comments
pull/12/head
hanrui1sensetime 2021-09-14 20:10:18 +08:00 committed by GitHub
parent 10793f488e
commit aba6ad5da7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 2 deletions

0
backend_ops/ncnn/ops/topk/topk.h 100755 → 100644
View File

View File

@ -1 +1,2 @@
from .delta_xywh_bbox_coder import * # noqa: F401,F403
from .tblr_bbox_coder import * # noqa: F401, F403

View File

@ -0,0 +1,87 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
    backend='default')
def tblr2bboxes(ctx,
                priors,
                tblr,
                normalizer=4.0,
                normalize_by_wh=True,
                max_shape=None,
                clip_border=True):
    """Decode (top, bottom, left, right) offsets into corner-format boxes.

    Registered as the ``default``-backend rewrite of mmdet's
    ``tblr2bboxes``. The offsets are scaled by ``normalizer`` (and
    optionally by the prior's height/width) and measured outwards from the
    center of each prior box.

    Args:
        ctx: Context object passed in by the rewriter; not used here.
        priors (torch.Tensor): Prior boxes. The last dimension is read as
            (x1, y1, x2, y2). Its first dimension must match ``tblr``; when
            it is 3-D the second dimension must match as well.
        tblr (torch.Tensor): Offsets whose last dimension is split into
            (top, bottom, left, right).
        normalizer (int | float | Sequence): Scalar, or a sequence of
            length 4, that ``tblr`` is multiplied by.
        normalize_by_wh (bool): When true, top/bottom are additionally
            multiplied by the prior height and left/right by the prior
            width.
        max_shape: Forwarded unchanged to ``clip_bboxes``; only used when
            ``clip_border`` is true.
        clip_border (bool): Whether to clip the decoded coordinates with
            ``clip_bboxes`` (requires ``max_shape`` to be given).

    Returns:
        torch.Tensor: Decoded (xmin, ymin, xmax, ymax) boxes, viewed with
        the same size as ``priors``.
    """
    # Any Python scalar is used as-is. Previously only ``float`` was treated
    # as a scalar, so an ``int`` normalizer became a 0-d tensor and the
    # ``len()`` check below raised.
    if not isinstance(normalizer, (int, float)):
        normalizer = torch.tensor(normalizer, device=priors.device)
        assert len(normalizer) == 4, 'Normalizer must have length = 4'
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    loc_decode = tblr * normalizer
    prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
    if normalize_by_wh:
        wh = priors[..., 2:4] - priors[..., 0:2]
        w, h = torch.split(wh, 1, dim=-1)
        # Inplace operation with slice would fail for exporting to ONNX
        th = h * loc_decode[..., :2]  # tb
        tw = w * loc_decode[..., 2:]  # lr
        loc_decode = torch.cat([th, tw], dim=-1)

    top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
    xmin = prior_centers[..., 0].unsqueeze(-1) - left
    xmax = prior_centers[..., 0].unsqueeze(-1) + right
    ymin = prior_centers[..., 1].unsqueeze(-1) - top
    ymax = prior_centers[..., 1].unsqueeze(-1) + bottom

    if clip_border and max_shape is not None:
        # Imported lazily, as in the original code.
        from mmdeploy.mmdet.export import clip_bboxes
        xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)

    bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
    return bboxes
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
    backend='ncnn')
def delta2bbox_ncnn(ctx,
                    priors,
                    tblr,
                    normalizer=4.0,
                    normalize_by_wh=True,
                    max_shape=None,
                    clip_border=True):
    """ncnn-backend rewrite of mmdet's ``tblr2bboxes``.

    Turns (top, bottom, left, right) offsets, measured from the center of
    each prior box, into (xmin, ymin, xmax, ymax) boxes.

    Args:
        ctx: Context object passed in by the rewriter; not used here.
        priors (torch.Tensor): Prior boxes; the last dimension is read as
            (x1, y1, x2, y2). The first dimension must equal that of
            ``tblr``, and for 3-D input the second one as well.
        tblr (torch.Tensor): Offsets; the last dimension is sliced as
            (top, bottom, left, right).
        normalizer: Factor that ``tblr`` is multiplied by.
        normalize_by_wh (bool): When true, top/bottom are scaled by the
            prior height and left/right by the prior width.
        max_shape: Forwarded unchanged to ``clip_bboxes``; only used when
            ``clip_border`` is true.
        clip_border (bool): Whether to clip the decoded coordinates with
            ``clip_bboxes`` (requires ``max_shape`` to be given).

    Returns:
        torch.Tensor: Decoded boxes viewed with the same size as ``priors``.
    """
    assert priors.size(0) == tblr.size(0)
    if priors.ndim == 3:
        assert priors.size(1) == tblr.size(1)

    offsets = tblr * normalizer
    centers = (priors[..., 0:2] + priors[..., 2:4]) / 2

    if normalize_by_wh:
        box_w = priors[..., 2:3] - priors[..., 0:1]
        box_h = priors[..., 3:4] - priors[..., 1:2]
        # NOTE(review): the unsqueeze / reshape round trip below is
        # presumably a workaround for ncnn export -- confirm before
        # simplifying. The reshape fixes the leading dimension to 1.
        h_expanded = box_h.unsqueeze(0).unsqueeze(-1)
        tb_expanded = offsets[..., 0:2].unsqueeze(0).unsqueeze(-1)
        w_expanded = box_w.unsqueeze(0).unsqueeze(-1)
        lr_expanded = offsets[..., 2:4].unsqueeze(0).unsqueeze(-1)
        scaled_tb = (h_expanded * tb_expanded).reshape(1, -1, 2)
        scaled_lr = (w_expanded * lr_expanded).reshape(1, -1, 2)
        offsets = torch.cat([scaled_tb, scaled_lr], dim=2)

    # One-element slices keep the trailing dimension for the final concat.
    top, bottom, left, right = (offsets[..., i:i + 1] for i in range(4))

    xmin = centers[..., 0].unsqueeze(-1) - left
    xmax = centers[..., 0].unsqueeze(-1) + right
    ymin = centers[..., 1].unsqueeze(-1) - top
    ymax = centers[..., 1].unsqueeze(-1) + bottom

    if clip_border and max_shape is not None:
        from mmdeploy.mmdet.export import clip_bboxes
        xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)

    decoded = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
    return decoded.view(priors.size())

View File

@ -105,3 +105,95 @@ def get_bboxes_of_fcos_head(ctx,
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class, iou_threshold,
score_threshold, nms_pre, cfg.max_per_img)
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.FCOSHead.get_bboxes', backend='ncnn')
def get_bboxes_of_fcos_head_ncnn(ctx,
                                 self,
                                 cls_scores,
                                 bbox_preds,
                                 centernesses,
                                 img_metas,
                                 with_nms=True,
                                 cfg=None,
                                 **kwargs):
    """ncnn-backend rewrite of ``FCOSHead.get_bboxes``.

    Flattens the per-level head outputs, optionally keeps the ``nms_pre``
    highest-scoring points of every level, decodes the boxes with
    ``distance2bbox`` and, unless ``with_nms`` is false, runs
    ``multiclass_nms``. Only static shapes are supported and the batch size
    is fixed to 1.

    Args:
        ctx: Rewriter context; ``ctx.cfg`` is the deploy config.
        self: The head instance. ``get_points``, ``cls_out_channels`` and
            ``test_cfg`` are read from it.
        cls_scores (Sequence[torch.Tensor]): Per-level classification maps,
            4-D with the channel axis at index 1.
        bbox_preds (Sequence[torch.Tensor]): Per-level box predictions,
            4-D with the channel axis at index 1 (reshaped to 4 values per
            point).
        centernesses (Sequence[torch.Tensor]): Per-level centerness maps,
            4-D with the channel axis at index 1 (reshaped to 1 value per
            point).
        img_metas: Mapping providing ``'img_shape'``, which is passed to
            ``distance2bbox`` as ``max_shape``.
        with_nms (bool): When false, return the decoded boxes, scores and
            centerness without running NMS.
        cfg: Test config; falls back to ``self.test_cfg`` when ``None``.
        **kwargs: Accepted for signature compatibility; ignored.

    Returns:
        Either the tuple ``(bboxes, scores, centerness)`` when ``with_nms``
        is false, or whatever ``multiclass_nms`` returns.
    """
    # All three head outputs must describe the same pyramid levels.
    assert len(cls_scores) == len(bbox_preds) == len(centernesses)
    deploy_cfg = ctx.cfg
    assert not is_dynamic_shape(deploy_cfg)

    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    points_list = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                  bbox_preds[0].device)
    assert len(cls_scores) == len(points_list)

    cls_score_list = [cls_score.detach() for cls_score in cls_scores]
    bbox_pred_list = [bbox_pred.detach() for bbox_pred in bbox_preds]
    centerness_pred_list = [
        centerness.detach() for centerness in centernesses
    ]

    cfg = self.test_cfg if cfg is None else cfg
    batch_size = 1
    pre_topk = cfg.get('nms_pre', -1)

    # loop over features, decode boxes
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    mlvl_points = []
    for cls_score, bbox_pred, centerness, points in zip(
            cls_score_list, bbox_pred_list, centerness_pred_list,
            points_list):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        # Move channels last and flatten the spatial dimensions.
        scores = cls_score.permute(0, 2, 3,
                                   1).reshape(batch_size, -1,
                                              self.cls_out_channels).sigmoid()
        centerness = centerness.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                                            1).sigmoid()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        points = points.expand(1, -1, 2).data

        if pre_topk > 0:
            # The product is formed on 4-D views and reshaped back before
            # taking the per-point maximum over classes.
            _scores = scores.reshape(batch_size, -1, self.cls_out_channels, 1)
            _centerness = centerness.reshape(batch_size, -1, 1, 1)
            max_scores, _ = (_scores * _centerness). \
                reshape(batch_size, -1, self.cls_out_channels).max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            # Batch size is 1, so the indices can be flattened.
            topk_inds = topk_inds.view(-1)
            points = points[:, topk_inds, :]
            bbox_pred = bbox_pred[:, topk_inds, :]
            scores = scores[:, topk_inds, :]
            centerness = centerness[:, topk_inds, :]

        mlvl_points.append(points)
        mlvl_bboxes.append(bbox_pred)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)

    batch_mlvl_points = torch.cat(mlvl_points, dim=1)
    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
    batch_mlvl_bboxes = distance2bbox(
        batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])

    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness

    # Weight the class scores by centerness before NMS.
    _batch_mlvl_scores = batch_mlvl_scores.unsqueeze(3)
    _batch_mlvl_centerness = batch_mlvl_centerness.unsqueeze(3)
    batch_mlvl_scores = (_batch_mlvl_scores * _batch_mlvl_centerness). \
        reshape(batch_mlvl_scores.shape)
    batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4)

    # Values in the model's test cfg take precedence over the deploy
    # config's post-processing defaults.
    post_params = deploy_cfg.post_processing
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    nms_pre = cfg.get('deploy_nms_pre', -1)
    return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
                          max_output_boxes_per_class, iou_threshold,
                          score_threshold, nms_pre, cfg.max_per_img)

View File

@ -1,4 +1,5 @@
from .getattribute import getattribute_static
from .group_norm import group_norm_ncnn
from .interpolate import interpolate_static
from .linear import linear_ncnn
from .repeat import repeat_static
@ -6,6 +7,7 @@ from .size import size_of_tensor_static
from .topk import topk_dynamic, topk_static
__all__ = [
'getattribute_static', 'interpolate_static', 'linear_ncnn',
'repeat_static', 'size_of_tensor_static', 'topk_static', 'topk_dynamic'
'getattribute_static', 'group_norm_ncnn', 'interpolate_static',
'linear_ncnn', 'repeat_static', 'size_of_tensor_static', 'topk_static',
'topk_dynamic'
]

View File

@ -0,0 +1,41 @@
from typing import Union
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.nn.functional.group_norm', backend='ncnn')
def group_norm_ncnn(
    ctx,
    input: torch.Tensor,
    num_groups: int,
    weight: Union[torch.Tensor, None] = None,
    bias: Union[torch.Tensor, None] = None,
    eps: float = 1e-05,
) -> torch.Tensor:
    """ncnn-backend rewrite of ``torch.nn.functional.group_norm``.

    The input is viewed as ``(batch, num_groups, -1, 1)`` and normalized
    with ``instance_norm`` using an identity affine transform; the caller's
    ``weight`` and ``bias`` are then applied by broadcasting along
    dimension 1 of the original shape.

    Args:
        ctx: Context object passed in by the rewriter; not used here.
        input (torch.Tensor): Tensor to normalize; dimension 0 is taken as
            the batch dimension.
        num_groups (int): Number of groups the remaining elements are
            divided into.
        weight (torch.Tensor | None): Scale applied after normalization;
            defaults to 1 when ``None``.
        bias (torch.Tensor | None): Offset applied after normalization;
            defaults to 0 when ``None``.
        eps (float): Passed to ``instance_norm`` as ``eps``.

    Returns:
        torch.Tensor: Normalized tensor with the same shape as ``input``.
    """
    input_shape = input.shape
    batch_size = input_shape[0]
    # We cannot use input.reshape(batch_size, num_groups, -1, 1)
    # instead, or we will hit a bug in the ncnn Reshape op.
    input_reshaped = input.reshape(batch_size, num_groups, -1)
    input_reshaped = input_reshaped.unsqueeze(3)
    # instance_norm is used to compute the per-group normalization. Its
    # affine parameters must have ``num_groups`` entries, which differs from
    # the size of the caller's ``weight``/``bias``, so identity parameters
    # are used here and the real affine transform is applied afterwards.
    weight_ = torch.tensor([1.] * num_groups).type_as(input)
    bias_ = torch.tensor([0.] * num_groups).type_as(input)

    norm_reshaped = torch.nn.functional.instance_norm(
        input_reshaped, weight=weight_, bias=bias_, eps=eps)
    norm = norm_reshaped.reshape(*input_shape)

    if weight is None:
        weight = torch.tensor([1.]).type_as(input)
    if bias is None:
        bias = torch.tensor([0.]).type_as(input)
    # Broadcast along dim 1 for any input rank; for 4-D input this is the
    # previously hard-coded (1, -1, 1, 1).
    affine_shape = [1, -1] + [1] * (input.dim() - 2)
    weight = weight.reshape(affine_shape)
    bias = bias.reshape(affine_shape)

    return norm * weight + bias