# mmdeploy/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmengine import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet.deploy import (gather_topk,
                                            get_post_processing_params,
                                            pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops import multiclass_nms
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.dense_heads.rpn_head.'
    'RPNHead.predict_by_feat')
def rpn_head__predict_by_feat(self,
                              cls_scores: List[Tensor],
                              bbox_preds: List[Tensor],
                              score_factors: Optional[List[Tensor]] = None,
                              batch_img_metas: Optional[List[dict]] = None,
                              cfg: Optional[ConfigDict] = None,
                              rescale: bool = False,
                              with_nms: bool = True,
                              **kwargs):
"""Rewrite `predict_by_feat` of `RPNHead` for default backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
    ctx = FUNCTION_REWRITER.get_context()
    img_metas = batch_img_metas
    assert len(cls_scores) == len(bbox_preds)
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    num_levels = len(cls_scores)

    device = cls_scores[0].device
    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
    mlvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device=device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)

    cfg = self.test_cfg if cfg is None else cfg
    batch_size = mlvl_cls_scores[0].shape[0]
    pre_topk = cfg.get('nms_pre', -1)
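    # `nms_pre` limits how many top-scoring priors are kept per level before
    # NMS; the default of -1 keeps all of them.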

    # loop over features, decode boxes
    mlvl_valid_bboxes = []
    mlvl_scores = []
    mlvl_valid_anchors = []
    for level_id, cls_score, bbox_pred, anchors in zip(
            range(num_levels), mlvl_cls_scores, mlvl_bbox_preds,
            mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        cls_score = cls_score.permute(0, 2, 3, 1)
        if self.use_sigmoid_cls:
            cls_score = cls_score.reshape(batch_size, -1)
            scores = cls_score.sigmoid()
        else:
            cls_score = cls_score.reshape(batch_size, -1, 2)
            # We set FG labels to [0, num_class-1] and BG label to
            # num_class in RPN head since mmdet v2.5, which is unified to
            # be consistent with other heads since mmdet v2.0. In mmdet
            # v2.0 to v2.4 we keep BG label as 0 and FG label as 1 in
            # the RPN head.
            scores = cls_score.softmax(-1)[..., 0]
        scores = scores.reshape(batch_size, -1, 1)
        dim = self.bbox_coder.encode_size
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, dim)

        # use static anchors if the input shape is static
        if not is_dynamic_flag:
            anchors = anchors.data
        anchors = anchors.unsqueeze(0)

        # topk in TensorRT does not support shape < k,
        # so concatenate zeros to enable topk
        scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
        bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
        anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)
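        # Scores are padded with zeros, so padded entries never outrank
        # real proposals in the following topk.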

        if pre_topk > 0:
            _, topk_inds = scores.squeeze(2).topk(pre_topk)
            bbox_pred, scores = gather_topk(
                bbox_pred,
                scores,
                inds=topk_inds,
                batch_size=batch_size,
                is_batched=True)
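            # Anchors carry a broadcast batch dim of 1 and are shared by the
            # whole batch, so they are gathered without per-sample batching.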
            anchors = gather_topk(
                anchors,
                inds=topk_inds,
                batch_size=batch_size,
                is_batched=False)
        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_scores.append(scores)
        mlvl_valid_anchors.append(anchors)

    batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
    batch_mlvl_bboxes = self.bbox_coder.decode(
        batch_mlvl_anchors,
        batch_mlvl_bboxes,
        max_shape=img_metas[0]['img_shape'])
    # ignore background class
    if not self.use_sigmoid_cls:
        batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores

    post_params = get_post_processing_params(deploy_cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    # only one class in rpn
    max_output_boxes_per_class = keep_top_k
    nms_type = cfg.nms.get('type')
    return multiclass_nms(
        batch_mlvl_bboxes,
        batch_mlvl_scores,
        max_output_boxes_per_class,
        nms_type=nms_type,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)


# TODO: Fix for 1.x
@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
def rpn_head__get_bboxes__ncnn(self,
                               cls_scores,
                               bbox_preds,
                               img_metas,
                               with_nms=True,
                               cfg=None,
                               **kwargs):
"""Rewrite `get_bboxes` of `RPNHead` for ncnn backend.
Shape node and batch inference is not supported by ncnn. This function
transform dynamic shape to constant shape and remove batch inference.
Args:
ctx (ContextCaller): The context with additional information.
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
with_nms (bool): If True, do nms before return boxes.
Default: True.
cfg (mmengine.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
Default: None.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
    ctx = FUNCTION_REWRITER.get_context()
    assert len(cls_scores) == len(bbox_preds)
    deploy_cfg = ctx.cfg
    assert not is_dynamic_shape(deploy_cfg)
    num_levels = len(cls_scores)

    device = cls_scores[0].device
    featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
    mlvl_anchors = self.anchor_generator.grid_anchors(
        featmap_sizes, device=device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)

    cfg = self.test_cfg if cfg is None else cfg
    batch_size = 1
    pre_topk = cfg.get('nms_pre', -1)

    # loop over features, decode boxes
    mlvl_valid_bboxes = []
    mlvl_scores = []
    mlvl_valid_anchors = []
    for level_id, cls_score, bbox_pred, anchors in zip(
            range(num_levels), mlvl_cls_scores, mlvl_bbox_preds,
            mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        cls_score = cls_score.permute(0, 2, 3, 1)
        if self.use_sigmoid_cls:
            cls_score = cls_score.reshape(batch_size, -1)
            scores = cls_score.sigmoid()
        else:
            cls_score = cls_score.reshape(batch_size, -1, 2)
            # We set FG labels to [0, num_class-1] and BG label to
            # num_class in RPN head since mmdet v2.5, which is unified to
            # be consistent with other heads since mmdet v2.0. In mmdet
            # v2.0 to v2.4 we keep BG label as 0 and FG label as 1 in
            # the RPN head.
            scores = cls_score.softmax(-1)[..., 0]
        scores = scores.reshape(batch_size, -1, 1)
        dim = self.bbox_coder.encode_size
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, dim)
        anchors = anchors.expand_as(bbox_pred).data
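        # Expanding and detaching (`.data`) bakes the anchors into the graph
        # as constants; this is safe because the ncnn export is asserted to
        # be static-shape above.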

        if pre_topk > 0:
            _, topk_inds = scores.squeeze(2).topk(pre_topk)
            topk_inds = topk_inds.view(-1)
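            # batch_size is fixed to 1, so the top-k indices can be flattened
            # and used for plain indexing instead of a batched gather, which
            # ncnn does not support.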
            anchors = anchors[:, topk_inds, :]
            bbox_pred = bbox_pred[:, topk_inds, :]
            scores = scores[:, topk_inds, :]
        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_scores.append(scores)
        mlvl_valid_anchors.append(anchors)

    batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
    batch_mlvl_bboxes = self.bbox_coder.decode(
        batch_mlvl_anchors,
        batch_mlvl_bboxes,
        max_shape=img_metas[0]['img_shape'])
    # ignore background class
    if not self.use_sigmoid_cls:
        batch_mlvl_scores = batch_mlvl_scores[..., :self.num_classes]
    if not with_nms:
        return batch_mlvl_bboxes, batch_mlvl_scores

    post_params = get_post_processing_params(deploy_cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    # only one class in rpn
    max_output_boxes_per_class = keep_top_k
    nms_type = cfg.nms.get('type')
    return multiclass_nms(
        batch_mlvl_bboxes,
        batch_mlvl_scores,
        max_output_boxes_per_class,
        nms_type=nms_type,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
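

# Usage note: these rewriters are not called directly. They take effect when
# the detector runs inside mmdeploy's rewriter context during export. A
# minimal sketch, assuming an already-built `model`, a `deploy_cfg` dict and
# dummy `inputs`:
#
#   from mmdeploy.core import RewriterContext
#   with RewriterContext(cfg=deploy_cfg, backend='onnxruntime'):
#       torch.onnx.export(model, inputs, 'end2end.onnx')
#
# Inside the context, `RPNHead.predict_by_feat` is dispatched to
# `rpn_head__predict_by_feat` above; outside it, the original mmdet
# implementation runs.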