simplify non batch nms (#99)

This commit is contained in:
q.yao 2022-01-26 19:04:24 +08:00 committed by GitHub
parent a543d41159
commit f2d0b15341
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,7 @@ from torch import Tensor
import mmdeploy import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
from mmdeploy.utils import is_dynamic_batch
def select_nms_index(scores: torch.Tensor, def select_nms_index(scores: torch.Tensor,
@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
keep_top_k: int = -1): keep_top_k: int = -1):
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
This function helps exporting to onnx with batch and multiclass NMS op. This function helps exporting to onnx with batch and multiclass NMS op. It
It only supports class-agnostic detection results. That is, the scores only supports class-agnostic detection results. That is, the scores is of
is of shape (N, num_bboxes, num_classes) and the boxes is of shape shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
(N, num_boxes, 4). 4).
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
""" """
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
return dets, labels return dets, labels
def _multiclass_nms_single(boxes: Tensor,
                           scores: Tensor,
                           max_output_boxes_per_class: int = 1000,
                           iou_threshold: float = 0.5,
                           score_threshold: float = 0.05,
                           pre_top_k: int = -1,
                           keep_top_k: int = -1):
    """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

    Single batch nms could be optimized. This variant hard-codes a batch
    dimension of 1: it squeezes/unsqueezes dim 0 and pads the results with
    tensors whose leading dimension is 1, so it must only be called with a
    single-image batch.

    Args:
        boxes (Tensor): The bounding boxes, indexed as [1, num_boxes, 4].
        scores (Tensor): The detection scores, indexed as
            [1, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Values <= 0 skip the pre-selection. Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape
            [1, num_det, 5] (box coordinates followed by the score) and
            `labels` of shape [1, num_det], ordered by descending score.
    """
    # Wrap the scalar settings in one-element tensors before handing them
    # to ONNXNMSop.
    max_boxes_t = torch.LongTensor([max_output_boxes_per_class])
    iou_thr_t = torch.tensor([iou_threshold], dtype=torch.float32)
    score_thr_t = torch.tensor([score_threshold], dtype=torch.float32)

    if pre_top_k > 0:
        # Rank every box by its best class score and keep the leading ones.
        best_per_box = scores.max(-1)[0]
        pre_inds = best_per_box.squeeze(0).topk(pre_top_k)[1]
        boxes = boxes[:, pre_inds, :]
        scores = scores[:, pre_inds, :]

    # The NMS op receives the scores as [1, num_classes, num_boxes].
    class_major_scores = scores.permute(0, 2, 1)
    nms_hits = ONNXNMSop.apply(boxes, class_major_scores, max_boxes_t,
                               iou_thr_t, score_thr_t)

    # Column 1 of the selection is the class index, column 2 the box index.
    class_ids = nms_hits[:, 1]
    kept_box_ids = nms_hits[:, 2]

    kept_scores = class_major_scores[:, class_ids, kept_box_ids].unsqueeze(2)
    kept_boxes = boxes[:, kept_box_ids, ...]
    dets = torch.cat([kept_boxes, kept_scores], dim=2)
    labels = class_ids.unsqueeze(0)

    # Append one all-zero detection/label (presumably so that the following
    # topk/sort never operates on an empty tensor -- TODO confirm).
    dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
    labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

    # Order by score: topk when a positive keep_top_k applies (always during
    # ONNX export, otherwise only if it is smaller than the number of
    # detections), a full descending sort in every other case.
    det_scores = dets[:, :, -1]
    if keep_top_k > 0 and (torch.onnx.is_in_onnx_export()
                           or keep_top_k < dets.shape[1]):
        order = det_scores.topk(keep_top_k, dim=1)[1]
    else:
        order = det_scores.sort(dim=1, descending=True)[1]
    order = order.squeeze(0)

    return dets[:, order, ...], labels[:, order, ...]
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms')
def multiclass_nms__default(ctx,
                            boxes: Tensor,
                            scores: Tensor,
                            max_output_boxes_per_class: int = 1000,
                            iou_threshold: float = 0.5,
                            score_threshold: float = 0.05,
                            pre_top_k: int = -1,
                            keep_top_k: int = -1):
    """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

    This function helps exporting to onnx with batch and multiclass NMS op.
    It only supports class-agnostic detection results. That is, the scores
    is of shape (N, num_bboxes, num_classes) and the boxes is of shape
    (N, num_boxes, 4).

    When the deploy config declares a static batch and the batch size is
    exactly 1, the simplified single-image implementation is used; every
    other case goes through the generic batched implementation.

    Args:
        ctx: Rewriter context; only `ctx.cfg` (the deploy config) is read.
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
            and `labels` of shape [N, num_det].
    """
    deploy_cfg = ctx.cfg
    batch_size = boxes.size(0)
    # `_multiclass_nms_single` hard-codes a leading dimension of 1
    # (squeeze(0) / unsqueeze(0) / padding of shape (1, 1, 5)), so it is only
    # valid for a static batch of exactly one image. The previous
    # `batch_size != 1` test was inverted and routed multi-image batches
    # into the single-image path.
    if not is_dynamic_batch(deploy_cfg) and batch_size == 1:
        return _multiclass_nms_single(
            boxes,
            scores,
            max_output_boxes_per_class=max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=pre_top_k,
            keep_top_k=keep_top_k)
    else:
        return _multiclass_nms(
            boxes,
            scores,
            max_output_boxes_per_class=max_output_boxes_per_class,
            iou_threshold=iou_threshold,
            score_threshold=score_threshold,
            pre_top_k=pre_top_k,
            keep_top_k=keep_top_k)
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms', func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
backend='tensorrt') backend='tensorrt')