[Feature]Support skip nms (#1552)

* skip nms

* judge at beginning

* add test

* remove else

* add more details in docstr including version not

* fix unitest

* fix doc

* fix doc

* fix typo

* resove conversation

* fix link
pull/1585/head
Shilong Zhang 2021-12-14 13:18:30 +08:00 committed by GitHub
parent 88e017337a
commit 43b2f0981c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 5 deletions

View File

@ -260,20 +260,25 @@ def soft_nms(boxes,
def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
r"""Performs non-maximum suppression in a batched fashion.
Modified from https://github.com/pytorch/vision/blob\
/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
Modified from
https://github.com/pytorch/vision/blob/505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
In order to perform NMS independently per class, we add an offset to all
the boxes. The offset is dependent only on the class idx, and is large
enough so that boxes from different classes do not overlap.
Note:
In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and
returns sorted raw results when `nms_cfg` is None.
Args:
boxes (torch.Tensor): boxes in shape (N, 4).
scores (torch.Tensor): scores in shape (N, ).
idxs (torch.Tensor): each index value correspond to a bbox cluster,
and NMS will not be applied between elements of different idxs,
shape (N, ).
nms_cfg (dict): specify nms type and other parameters like iou_thr.
Possible keys includes the following.
nms_cfg (dict | None): Supports skipping the nms when `nms_cfg`
is None, otherwise it should specify nms type and other
parameters like `iou_thr`. Possible keys includes the following.
- iou_thr (float): IoU threshold used for NMS.
- split_thr (float): threshold number of boxes. In some cases the
@ -288,7 +293,19 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
Returns:
tuple: kept dets and indice.
- boxes (Tensor): Bboxes with score after nms, has shape
(num_bboxes, 5). last dimension 5 arrange as
(x1, y1, x2, y2, score)
- keep (Tensor): The indices of remaining boxes in input
boxes.
"""
# skip nms when nms_cfg is None
if nms_cfg is None:
scores, inds = scores.sort(descending=True)
boxes = boxes[inds]
return torch.cat([boxes, scores[:, None]], -1), inds
nms_cfg_ = nms_cfg.copy()
class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
if class_agnostic:
@ -333,7 +350,8 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
boxes = boxes[:max_num]
scores = scores[:max_num]
return torch.cat([boxes, scores[:, None]], -1), keep
boxes = torch.cat([boxes, scores[:, None]], -1)
return boxes, keep
def nms_match(dets, iou_threshold):

View File

@ -182,3 +182,14 @@ class Testnms(object):
assert torch.equal(keep, seq_keep)
assert torch.equal(boxes, seq_boxes)
# test skip nms when `nms_cfg` is None
seq_boxes, seq_keep = batched_nms(
torch.from_numpy(results['boxes']),
torch.from_numpy(results['scores']),
torch.from_numpy(results['idxs']),
None,
class_agnostic=False)
assert len(seq_keep) == len(results['boxes'])
# assert score is descending order
assert ((seq_boxes[:, -1][1:] - seq_boxes[:, -1][:-1]) < 0).all()