mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
simplify non batch nms (#99)
This commit is contained in:
parent
a543d41159
commit
f2d0b15341
@ -5,6 +5,7 @@ from torch import Tensor
|
||||
import mmdeploy
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
|
||||
from mmdeploy.utils import is_dynamic_batch
|
||||
|
||||
|
||||
def select_nms_index(scores: torch.Tensor,
|
||||
@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
|
||||
keep_top_k: int = -1):
|
||||
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
|
||||
|
||||
This function helps exporting to onnx with batch and multiclass NMS op.
|
||||
It only supports class-agnostic detection results. That is, the scores
|
||||
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
|
||||
(N, num_boxes, 4).
|
||||
|
||||
Args:
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
scores (Tensor): The detection scores of shape
|
||||
[N, num_boxes, num_classes].
|
||||
max_output_boxes_per_class (int): Maximum number of output
|
||||
boxes per class of nms. Defaults to 1000.
|
||||
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
|
||||
score_threshold (float): score threshold of nms.
|
||||
Defaults to 0.05.
|
||||
pre_top_k (int): Number of top K boxes to keep before nms.
|
||||
Defaults to -1.
|
||||
keep_top_k (int): Number of top K boxes to keep after nms.
|
||||
Defaults to -1.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
|
||||
and `labels` of shape [N, num_det].
|
||||
This function helps exporting to onnx with batch and multiclass NMS op. It
|
||||
only supports class-agnostic detection results. That is, the scores is of
|
||||
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
|
||||
4).
|
||||
"""
|
||||
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
|
||||
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
|
||||
@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
|
||||
return dets, labels
|
||||
|
||||
|
||||
def _multiclass_nms_single(boxes: Tensor,
                           scores: Tensor,
                           max_output_boxes_per_class: int = 1000,
                           iou_threshold: float = 0.5,
                           score_threshold: float = 0.05,
                           pre_top_k: int = -1,
                           keep_top_k: int = -1):
    """Multiclass NMS through ``ONNXNMSop`` for a batch of exactly one image.

    The leading batch dimension is assumed to be 1 throughout: index tensors
    are squeezed on dim 0 and the padding rows are created with a leading
    size of 1, so gathering can be done with plain 1-D index tensors.

    Args:
        boxes (Tensor): Bounding boxes, indexed as [1, num_boxes, 4].
        scores (Tensor): Detection scores, indexed as
            [1, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output boxes
            per class, forwarded to the NMS op. Defaults to 1000.
        iou_threshold (float): IOU threshold forwarded to the NMS op.
            Defaults to 0.5.
        score_threshold (float): Score threshold forwarded to the NMS op.
            Defaults to 0.05.
        pre_top_k (int): If positive, only the `pre_top_k` boxes with the
            highest per-box maximum score are passed to NMS.
            Defaults to -1.
        keep_top_k (int): If positive, number of detections selected with
            top-k after NMS; otherwise all detections are sorted by score.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape
            [1, num_det, 5] (box coordinates followed by the score) and
            `labels` of shape [1, num_det].
    """
    # The NMS op takes its scalar settings as one-element tensors.
    nms_max_out = torch.LongTensor([max_output_boxes_per_class])
    nms_iou_thr = torch.tensor([iou_threshold], dtype=torch.float32)
    nms_score_thr = torch.tensor([score_threshold], dtype=torch.float32)

    # Optional pre-filter: rank boxes by their best class score.
    if pre_top_k > 0:
        best_per_box, _ = scores.max(-1)
        _, candidate_ids = best_per_box.squeeze(0).topk(pre_top_k)
        boxes = boxes[:, candidate_ids, :]
        scores = scores[:, candidate_ids, :]

    # Class axis goes before the box axis for the NMS op.
    scores = scores.permute(0, 2, 1)
    nms_out = ONNXNMSop.apply(boxes, scores, nms_max_out, nms_iou_thr,
                              nms_score_thr)

    # Columns 1 and 2 of each selected row are the class and box index
    # (column 0 is presumably the batch index, unused here).
    cls_inds = nms_out[:, 1]
    box_inds = nms_out[:, 2]

    kept_scores = scores[:, cls_inds, box_inds].unsqueeze(2)
    kept_boxes = boxes[:, box_inds, ...]
    dets = torch.cat([kept_boxes, kept_scores], dim=2)
    labels = cls_inds.unsqueeze(0)

    # Append one all-zero detection (presumably so the top-k / sort below
    # always has at least one element to work on).
    dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
    labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

    # Rank detections by score: top-k when requested (always during ONNX
    # export, otherwise only if there are more than `keep_top_k`
    # detections), else a full descending sort.
    det_scores = dets[:, :, -1]
    if keep_top_k > 0 and (torch.onnx.is_in_onnx_export()
                           or keep_top_k < dets.shape[1]):
        _, order = det_scores.topk(keep_top_k, dim=1)
    else:
        _, order = det_scores.sort(dim=1, descending=True)
    order = order.squeeze(0)

    return dets[:, order, ...], labels[:, order, ...]
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms')
def multiclass_nms__default(ctx,
                            boxes: Tensor,
                            scores: Tensor,
                            max_output_boxes_per_class: int = 1000,
                            iou_threshold: float = 0.5,
                            score_threshold: float = 0.05,
                            pre_top_k: int = -1,
                            keep_top_k: int = -1):
    """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

    This function helps exporting to onnx with batch and multiclass NMS op.
    It only supports class-agnostic detection results. That is, the scores
    are of shape (N, num_boxes, num_classes) and the boxes are of shape
    (N, num_boxes, 4).

    When the deploy config does not request a dynamic batch and the input
    holds exactly one image, the work is delegated to
    `_multiclass_nms_single`; every other case goes through the batched
    `_multiclass_nms`.

    Args:
        ctx: Rewriter context; its `cfg` attribute is read as the deploy
            config.
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
            and `labels` of shape [N, num_det].
    """
    deploy_cfg = ctx.cfg
    batch_size = boxes.size(0)

    # `_multiclass_nms_single` hard-codes a batch dimension of 1 (it squeezes
    # dim 0 and pads with tensors whose leading size is 1), so it is only
    # valid for a static batch of exactly one image. The previous condition
    # (`batch_size != 1`) was inverted and routed multi-image batches to it.
    use_single = not is_dynamic_batch(deploy_cfg) and batch_size == 1
    nms_impl = _multiclass_nms_single if use_single else _multiclass_nms

    return nms_impl(
        boxes,
        scores,
        max_output_boxes_per_class=max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
|
||||
backend='tensorrt')
|
||||
|
Loading…
x
Reference in New Issue
Block a user