[Feature] Support FCOS and FSAF for mmdetection ncnn deployment (#63)
* support ncnn mmdet fcos.py * support fsaf * fix_lint * fix yapf * fix clang-format * Delete output_ncnn.jpg * Delete output_pytorch.jpg * remove comments and fix typo * fix blank line * Fix typo * Remove unnessisary comments * Add comment * Add commentspull/12/head
parent
10793f488e
commit
aba6ad5da7
|
@ -1 +1,2 @@
|
|||
from .delta_xywh_bbox_coder import * # noqa: F401,F403
|
||||
from .tblr_bbox_coder import * # noqa: F401, F403
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
|
||||
backend='default')
|
||||
def tblr2bboxes(ctx,
|
||||
priors,
|
||||
tblr,
|
||||
normalizer=4.0,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
if not isinstance(normalizer, float):
|
||||
normalizer = torch.tensor(normalizer, device=priors.device)
|
||||
assert len(normalizer) == 4, 'Normalizer must have length = 4'
|
||||
assert priors.size(0) == tblr.size(0)
|
||||
if priors.ndim == 3:
|
||||
assert priors.size(1) == tblr.size(1)
|
||||
|
||||
loc_decode = tblr * normalizer
|
||||
prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
|
||||
if normalize_by_wh:
|
||||
wh = priors[..., 2:4] - priors[..., 0:2]
|
||||
|
||||
w, h = torch.split(wh, 1, dim=-1)
|
||||
# Inplace operation with slice would fail for exporting to ONNX
|
||||
th = h * loc_decode[..., :2] # tb
|
||||
tw = w * loc_decode[..., 2:] # lr
|
||||
loc_decode = torch.cat([th, tw], dim=-1)
|
||||
top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
|
||||
xmin = prior_centers[..., 0].unsqueeze(-1) - left
|
||||
xmax = prior_centers[..., 0].unsqueeze(-1) + right
|
||||
ymin = prior_centers[..., 1].unsqueeze(-1) - top
|
||||
ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.mmdet.export import clip_bboxes
|
||||
xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)
|
||||
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
|
||||
|
||||
return bboxes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
|
||||
backend='ncnn')
|
||||
def delta2bbox_ncnn(ctx,
|
||||
priors,
|
||||
tblr,
|
||||
normalizer=4.0,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
assert priors.size(0) == tblr.size(0)
|
||||
if priors.ndim == 3:
|
||||
assert priors.size(1) == tblr.size(1)
|
||||
|
||||
loc_decode = tblr * normalizer
|
||||
prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
|
||||
if normalize_by_wh:
|
||||
w = priors[..., 2:3] - priors[..., 0:1]
|
||||
h = priors[..., 3:4] - priors[..., 1:2]
|
||||
_h = h.unsqueeze(0).unsqueeze(-1)
|
||||
_loc_h = loc_decode[..., 0:2].unsqueeze(0).unsqueeze(-1)
|
||||
_w = w.unsqueeze(0).unsqueeze(-1)
|
||||
_loc_w = loc_decode[..., 2:4].unsqueeze(0).unsqueeze(-1)
|
||||
th = (_h * _loc_h).reshape(1, -1, 2)
|
||||
tw = (_w * _loc_w).reshape(1, -1, 2)
|
||||
loc_decode = torch.cat([th, tw], dim=2)
|
||||
top = loc_decode[..., 0:1]
|
||||
bottom = loc_decode[..., 1:2]
|
||||
left = loc_decode[..., 2:3]
|
||||
right = loc_decode[..., 3:4]
|
||||
xmin = prior_centers[..., 0].unsqueeze(-1) - left
|
||||
xmax = prior_centers[..., 0].unsqueeze(-1) + right
|
||||
ymin = prior_centers[..., 1].unsqueeze(-1) - top
|
||||
ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.mmdet.export import clip_bboxes
|
||||
xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)
|
||||
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
|
||||
|
||||
return bboxes
|
|
@ -105,3 +105,95 @@ def get_bboxes_of_fcos_head(ctx,
|
|||
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
||||
max_output_boxes_per_class, iou_threshold,
|
||||
score_threshold, nms_pre, cfg.max_per_img)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.FCOSHead.get_bboxes', backend='ncnn')
|
||||
def get_bboxes_of_fcos_head_ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
centernesses,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
||||
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
|
||||
points_list = self.get_points(featmap_sizes, bbox_preds[0].dtype,
|
||||
bbox_preds[0].device)
|
||||
|
||||
cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
|
||||
bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)]
|
||||
centerness_pred_list = [
|
||||
centernesses[i].detach() for i in range(num_levels)
|
||||
]
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
assert len(cls_scores) == len(bbox_preds) == len(points_list)
|
||||
batch_size = 1
|
||||
pre_topk = cfg.get('nms_pre', -1)
|
||||
|
||||
# loop over features, decode boxes
|
||||
mlvl_bboxes = []
|
||||
mlvl_scores = []
|
||||
mlvl_centerness = []
|
||||
mlvl_points = []
|
||||
for level_id, cls_score, bbox_pred, centerness, points in zip(
|
||||
range(num_levels), cls_score_list, bbox_pred_list,
|
||||
centerness_pred_list, points_list):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
scores = cls_score.permute(0, 2, 3,
|
||||
1).reshape(batch_size, -1,
|
||||
self.cls_out_channels).sigmoid()
|
||||
centerness = centerness.permute(0, 2, 3, 1).reshape(batch_size, -1,
|
||||
1).sigmoid()
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
points = points.expand(1, -1, 2).data
|
||||
if pre_topk > 0:
|
||||
|
||||
_scores = scores.reshape(batch_size, -1, self.cls_out_channels, 1)
|
||||
_centerness = centerness.reshape(batch_size, -1, 1, 1)
|
||||
max_scores, _ = (_scores * _centerness). \
|
||||
reshape(batch_size, -1, self.cls_out_channels).max(-1)
|
||||
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
|
||||
topk_inds = topk_inds.view(-1)
|
||||
|
||||
points = points[:, topk_inds, :]
|
||||
bbox_pred = bbox_pred[:, topk_inds, :]
|
||||
scores = scores[:, topk_inds, :]
|
||||
centerness = centerness[:, topk_inds, :]
|
||||
mlvl_points.append(points)
|
||||
mlvl_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(scores)
|
||||
mlvl_centerness.append(centerness)
|
||||
|
||||
batch_mlvl_points = torch.cat(mlvl_points, dim=1)
|
||||
batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
|
||||
batch_mlvl_bboxes = distance2bbox(
|
||||
batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])
|
||||
|
||||
if not with_nms:
|
||||
return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness
|
||||
|
||||
_batch_mlvl_scores = batch_mlvl_scores.unsqueeze(3)
|
||||
_batch_mlvl_centerness = batch_mlvl_centerness.unsqueeze(3)
|
||||
batch_mlvl_scores = (_batch_mlvl_scores * _batch_mlvl_centerness). \
|
||||
reshape(batch_mlvl_scores.shape)
|
||||
batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4)
|
||||
post_params = deploy_cfg.post_processing
|
||||
max_output_boxes_per_class = post_params.max_output_boxes_per_class
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
nms_pre = cfg.get('deploy_nms_pre', -1)
|
||||
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
|
||||
max_output_boxes_per_class, iou_threshold,
|
||||
score_threshold, nms_pre, cfg.max_per_img)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .getattribute import getattribute_static
|
||||
from .group_norm import group_norm_ncnn
|
||||
from .interpolate import interpolate_static
|
||||
from .linear import linear_ncnn
|
||||
from .repeat import repeat_static
|
||||
|
@ -6,6 +7,7 @@ from .size import size_of_tensor_static
|
|||
from .topk import topk_dynamic, topk_static
|
||||
|
||||
__all__ = [
|
||||
'getattribute_static', 'interpolate_static', 'linear_ncnn',
|
||||
'repeat_static', 'size_of_tensor_static', 'topk_static', 'topk_dynamic'
|
||||
'getattribute_static', 'group_norm_ncnn', 'interpolate_static',
|
||||
'linear_ncnn', 'repeat_static', 'size_of_tensor_static', 'topk_static',
|
||||
'topk_dynamic'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.group_norm', backend='ncnn')
|
||||
def group_norm_ncnn(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
num_groups: int,
|
||||
weight: Union[torch.Tensor, torch.NoneType] = None,
|
||||
bias: Union[torch.Tensor, torch.NoneType] = None,
|
||||
eps: float = 1e-05,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input.shape
|
||||
batch_size = input_shape[0]
|
||||
# We cannot use input.reshape(batch_size, num_groups, -1, 1)
|
||||
# instead, or we will meet bug on ncnn Reshape ops.
|
||||
input_reshaped = input.reshape(batch_size, num_groups, -1)
|
||||
input_reshaped = input_reshaped.unsqueeze(3)
|
||||
# the weight_'s size is not the same as weight's size
|
||||
# we only use groupnorm to calculate instancenorm, but the
|
||||
# input parameters may not be the same, and need to transform.
|
||||
weight_ = torch.tensor([1.] * num_groups).type_as(input)
|
||||
bias_ = torch.tensor([0.] * num_groups).type_as(input)
|
||||
|
||||
norm_reshaped = torch.nn.functional.instance_norm(
|
||||
input_reshaped, weight=weight_, bias=bias_, eps=eps)
|
||||
|
||||
norm = norm_reshaped.reshape(*input_shape)
|
||||
if weight is None:
|
||||
weight = torch.tensor([1.]).type_as(input)
|
||||
if bias is None:
|
||||
bias = torch.tensor([0.]).type_as(input)
|
||||
weight = weight.reshape(1, -1, 1, 1)
|
||||
bias = bias.reshape(1, -1, 1, 1)
|
||||
|
||||
return norm * weight + bias
|
Loading…
Reference in New Issue