diff --git a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py index 73e3db584..2580d37ef 100644 --- a/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py @@ -5,6 +5,7 @@ from torch import Tensor import mmdeploy from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop +from mmdeploy.utils import is_dynamic_batch def select_nms_index(scores: torch.Tensor, @@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor, keep_top_k: int = -1): """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. - This function helps exporting to onnx with batch and multiclass NMS op. - It only supports class-agnostic detection results. That is, the scores - is of shape (N, num_bboxes, num_classes) and the boxes is of shape - (N, num_boxes, 4). - - Args: - boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. - scores (Tensor): The detection scores of shape - [N, num_boxes, num_classes]. - max_output_boxes_per_class (int): Maximum number of output - boxes per class of nms. Defaults to 1000. - iou_threshold (float): IOU threshold of nms. Defaults to 0.5. - score_threshold (float): score threshold of nms. - Defaults to 0.05. - pre_top_k (int): Number of top K boxes to keep before nms. - Defaults to -1. - keep_top_k (int): Number of top K boxes to keep after nms. - Defaults to -1. - - Returns: - tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] - and `labels` of shape [N, num_det]. + This function helps exporting to onnx with batch and multiclass NMS op. It + only supports class-agnostic detection results. That is, the scores is of + shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes, + 4). """ max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) @@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor, return dets, labels +def _multiclass_nms_single(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): + """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. + + Single batch nms could be optimized. + """ + max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) + iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) + score_threshold = torch.tensor([score_threshold], dtype=torch.float32) + + # pre topk + if pre_top_k > 0: + max_scores, _ = scores.max(-1) + _, topk_inds = max_scores.squeeze(0).topk(pre_top_k) + boxes = boxes[:, topk_inds, :] + scores = scores[:, topk_inds, :] + + scores = scores.permute(0, 2, 1) + selected_indices = ONNXNMSop.apply(boxes, scores, + max_output_boxes_per_class, + iou_threshold, score_threshold) + + cls_inds = selected_indices[:, 1] + box_inds = selected_indices[:, 2] + + scores = scores[:, cls_inds, box_inds].unsqueeze(2) + boxes = boxes[:, box_inds, ...] + dets = torch.cat([boxes, scores], dim=2) + labels = cls_inds.unsqueeze(0) + + # pad + dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1) + labels = torch.cat((labels, labels.new_zeros((1, 1))), 1) + + # topk or sort + is_use_topk = keep_top_k > 0 and \ + (torch.onnx.is_in_onnx_export() or keep_top_k < dets.shape[1]) + if is_use_topk: + _, topk_inds = dets[:, :, -1].topk(keep_top_k, dim=1) + else: + _, topk_inds = dets[:, :, -1].sort(dim=1, descending=True) + topk_inds = topk_inds.squeeze(0) + dets = dets[:, topk_inds, ...] + labels = labels[:, topk_inds, ...] + + return dets, labels + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms') +def multiclass_nms__default(ctx, + boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): + """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. + + This function helps exporting to onnx with batch and multiclass NMS op. + It only supports class-agnostic detection results. That is, the scores + is of shape (N, num_bboxes, num_classes) and the boxes is of shape + (N, num_boxes, 4). + + Args: + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + max_output_boxes_per_class (int): Maximum number of output + boxes per class of nms. Defaults to 1000. + iou_threshold (float): IOU threshold of nms. Defaults to 0.5. + score_threshold (float): score threshold of nms. + Defaults to 0.05. + pre_top_k (int): Number of top K boxes to keep before nms. + Defaults to -1. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] + and `labels` of shape [N, num_det]. + """ + deploy_cfg = ctx.cfg + batch_size = boxes.size(0) + if not is_dynamic_batch(deploy_cfg) and batch_size != 1: + return _multiclass_nms_single( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) + else: + return _multiclass_nms( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) + + @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms', backend='tensorrt')