EasyCV/easycv/models/detection/utils/boxes.py

# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
from distutils.version import LooseVersion
import numpy as np
import torch
import torchvision
from torchvision.ops.boxes import box_area, nms
from easycv.models.detection.utils.misc import fp16_clamp
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
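    """Compute the pairwise IoU matrix between two sets of boxes.

    Args:
        bboxes_a (Tensor): shape (N, 4).
        bboxes_b (Tensor): shape (M, 4).
        xyxy (bool): if True, boxes are given as (x1, y1, x2, y2);
            otherwise as (cx, cy, w, h).

    Returns:
        Tensor: IoU matrix of shape (N, M).

    Example:
        >>> a = torch.Tensor([[0., 0., 10., 10.]])
        >>> b = torch.Tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
        >>> ious = bboxes_iou(a, b)
        >>> assert ious.shape == (1, 2)
    """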
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
raise IndexError
if xyxy:
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
else:
tl = torch.max(
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
)
br = torch.min(
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
)
area_a = torch.prod(bboxes_a[:, 2:], 1)
area_b = torch.prod(bboxes_b[:, 2:], 1)
en = (tl < br).type(tl.type()).prod(dim=2)
area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
return area_i / (area_a[:, None] + area_b - area_i)
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
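    """Filter raw detector outputs by confidence and apply NMS.

    Boxes are expected in (cx, cy, w, h) format and are converted to
    (x1, y1, x2, y2) in place. Candidates whose obj_conf * class_conf falls
    below ``conf_thre`` are dropped, then NMS with ``nms_thre`` is applied
    (per class for torchvision >= 0.8.0, class-agnostic otherwise).

    Args:
        prediction (Tensor): shape (batch, num_boxes, 5 + num_classes).
        num_classes (int): number of classes.
        conf_thre (float): confidence threshold.
        nms_thre (float): IoU threshold for NMS.

    Returns:
        list[Tensor | None]: per-image detections with columns
            (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
    """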
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
for i, image_pred in enumerate(prediction):
# If none are remaining => process next image
if not image_pred.numel():
continue
# Get score and class with highest confidence
class_conf, class_pred = torch.max(
image_pred[:, 5:5 + num_classes], 1, keepdim=True)
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
conf_thre).squeeze()
# Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
detections = torch.cat(
(image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
if not detections.numel():
continue
if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
nms_out_index = torchvision.ops.batched_nms(
detections[:, :4], detections[:, 4] * detections[:, 5],
detections[:, 6], nms_thre)
else:
nms_out_index = torchvision.ops.nms(
detections[:, :4], detections[:, 4] * detections[:, 5],
nms_thre)
detections = detections[nms_out_index]
if output[i] is None:
output[i] = detections
else:
output[i] = torch.cat((output[i], detections))
return output
def bbox2result(bboxes, labels, num_classes):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (torch.Tensor | np.ndarray): shape (n, 5)
labels (torch.Tensor | np.ndarray): shape (n, )
num_classes (int): class number, including background class
Returns:
list(ndarray): bbox results of each class
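
    Example:
        >>> bboxes = torch.Tensor([[0., 0., 10., 10., 0.9]])
        >>> labels = torch.LongTensor([2])
        >>> result = bbox2result(bboxes, labels, num_classes=3)
        >>> assert len(result) == 3 and result[2].shape == (1, 5)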
"""
if bboxes.shape[0] == 0:
return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
else:
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
return [bboxes[labels == i, :] for i in range(num_classes)]
def box_cxcywh_to_xyxy(x):
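    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) format.

    Example:
        >>> box = torch.Tensor([[5., 5., 10., 10.]])
        >>> assert torch.equal(box_cxcywh_to_xyxy(box),
        ...                    torch.Tensor([[0., 0., 10., 10.]]))
    """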
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
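    """Convert boxes from (x1, y1, x2, y2) to (cx, cy, w, h) format.

    Inverse of ``box_cxcywh_to_xyxy``.
    """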
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
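    """Pairwise IoU for boxes in (x1, y1, x2, y2) format.

    Args:
        boxes1 (Tensor): shape (N, 4).
        boxes2 (Tensor): shape (M, 4).

    Returns:
        Tuple[Tensor, Tensor]: the [N, M] IoU matrix and the
            [N, M] union areas.
    """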
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
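
    Example (GIoU is negative for disjoint boxes):
        >>> a = torch.Tensor([[0., 0., 10., 10.]])
        >>> b = torch.Tensor([[20., 0., 30., 10.]])
        >>> giou = generalized_box_iou(a, b)
        >>> assert giou.shape == (1, 1) and giou.item() < 0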
"""
    # degenerate boxes give inf / nan results,
    # so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
"""Calculate overlap between two set of bboxes.
FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889
Note:
        Assume bboxes1 is M x 4 and bboxes2 is N x 4. When mode is 'iou',
        the following intermediate variables are generated when calculating
        IoU with the bbox_overlaps function:
1) is_aligned is False
area1: M x 1
area2: N x 1
lt: M x N x 2
rb: M x N x 2
wh: M x N x 2
overlap: M x N x 1
union: M x N x 1
ious: M x N x 1
Total memory:
S = (9 x N x M + N + M) * 4 Byte,
When using FP16, we can reduce:
R = (9 x N x M + N + M) * 4 / 2 Byte
                R is larger than (N + M) * 4 * 2 whenever N >= 1 and M >= 1,
                since N + M <= N * M < 3 * N * M when N >= 2 and M >= 2,
                and N + 1 < 3 * N when N or M is 1.

            Given M = 40 (ground truths) and N = 400000 (three anchor
            boxes per grid cell, FPN, R-CNNs),
                R = 275 MB (for a single call).

            In a special case (dense detection) with M = 512 ground truths,
                R = 3516 MB = 3.43 GB.

            When the batch size is B, the reduction becomes B x R.
            Therefore, CUDA memory frequently runs out.
Experiments on GeForce RTX 2080Ti (11019 MiB):
| dtype | M | N | Use | Real | Ideal |
|:----:|:----:|:----:|:----:|:----:|:----:|
| FP32 | 512 | 400000 | 8020 MiB | -- | -- |
| FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
| FP32 | 40 | 400000 | 1540 MiB | -- | -- |
| FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB |
2) is_aligned is True
area1: N x 1
area2: N x 1
lt: N x 2
rb: N x 2
wh: N x 2
overlap: N x 1
union: N x 1
ious: N x 1
Total memory:
S = 11 x N * 4 Byte
When using FP16, we can reduce:
R = 11 x N * 4 / 2 Byte
            The same holds for 'giou' (which needs more memory than 'iou').

        Time-wise, FP16 is generally faster than FP32.

        When gpu_assign_thr is not -1, assignment runs on the CPU, which
        takes more time but does not reduce memory.
        Therefore, with FP16 we can halve the memory while keeping the speed.
If ``is_aligned`` is ``False``, then calculate the overlaps between each
bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
pair of bboxes1 and bboxes2.
Args:
bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
B indicates the batch dim, in shape (B1, B2, ..., Bn).
If ``is_aligned`` is ``True``, then m and n must be equal.
mode (str): "iou" (intersection over union), "iof" (intersection over
foreground) or "giou" (generalized intersection over union).
Default "iou".
is_aligned (bool, optional): If True, then m and n must be equal.
Default False.
eps (float, optional): A value added to the denominator for numerical
stability. Default 1e-6.
Returns:
Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
Example:
>>> bboxes1 = torch.FloatTensor([
>>> [0, 0, 10, 10],
>>> [10, 10, 20, 20],
>>> [32, 32, 38, 42],
>>> ])
>>> bboxes2 = torch.FloatTensor([
>>> [0, 0, 10, 20],
>>> [0, 10, 10, 19],
>>> [10, 10, 20, 20],
>>> ])
>>> overlaps = bbox_overlaps(bboxes1, bboxes2)
>>> assert overlaps.shape == (3, 3)
>>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
>>> assert overlaps.shape == (3, )
Example:
>>> empty = torch.empty(0, 4)
>>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
>>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
>>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
>>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
"""
assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
# Either the boxes are empty or the length of boxes' last dimension is 4
assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
# Batch dim must be the same
# Batch dim: (B1, B2, ... Bn)
assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
batch_shape = bboxes1.shape[:-2]
rows = bboxes1.size(-2)
cols = bboxes2.size(-2)
if is_aligned:
assert rows == cols
if rows * cols == 0:
if is_aligned:
return bboxes1.new(batch_shape + (rows, ))
else:
return bboxes1.new(batch_shape + (rows, cols))
area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
bboxes1[..., 3] - bboxes1[..., 1])
area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
bboxes2[..., 3] - bboxes2[..., 1])
if is_aligned:
lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
wh = fp16_clamp(rb - lt, min=0)
overlap = wh[..., 0] * wh[..., 1]
if mode in ['iou', 'giou']:
union = area1 + area2 - overlap
else:
union = area1
if mode == 'giou':
enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
else:
lt = torch.max(bboxes1[..., :, None, :2],
bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
rb = torch.min(bboxes1[..., :, None, 2:],
bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
wh = fp16_clamp(rb - lt, min=0)
overlap = wh[..., 0] * wh[..., 1]
if mode in ['iou', 'giou']:
union = area1[..., None] + area2[..., None, :] - overlap
else:
union = area1[..., None]
if mode == 'giou':
enclosed_lt = torch.min(bboxes1[..., :, None, :2],
bboxes2[..., None, :, :2])
enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
bboxes2[..., None, :, 2:])
eps = union.new_tensor([eps])
union = torch.max(union, eps)
ious = overlap / union
if mode in ['iou', 'iof']:
return ious
# calculate gious
enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
enclose_area = torch.max(enclose_area, eps)
gious = ious - (enclose_area - union) / enclose_area
return gious
def bbox2distance(points, bbox, max_dis=None, eps=0.1):
"""Decode bounding box based on distances.
Args:
points (Tensor): Shape (n, 2), [x, y].
bbox (Tensor): Shape (n, 4), "xyxy" format
max_dis (float): Upper bound of the distance.
eps (float): a small value to ensure target < max_dis, instead <=
Returns:
Tensor: Decoded distances.
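
    Example:
        >>> points = torch.Tensor([[5., 5.]])
        >>> bbox = torch.Tensor([[0., 0., 10., 10.]])
        >>> assert torch.equal(bbox2distance(points, bbox),
        ...                    torch.Tensor([[5., 5., 5., 5.]]))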
"""
left = points[:, 0] - bbox[:, 0]
top = points[:, 1] - bbox[:, 1]
right = bbox[:, 2] - points[:, 0]
bottom = bbox[:, 3] - points[:, 1]
if max_dis is not None:
left = left.clamp(min=0, max=max_dis - eps)
top = top.clamp(min=0, max=max_dis - eps)
right = right.clamp(min=0, max=max_dis - eps)
bottom = bottom.clamp(min=0, max=max_dis - eps)
return torch.stack([left, top, right, bottom], -1)
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (B, N, 2) or (N, 2).
distance (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specified
            as (H, W, C) or (H, W). If points has shape (B, N, 2), then
            max_shape should be a Sequence[Sequence[int]]
            whose length is B.
Returns:
Tensor: Boxes with shape (N, 4) or (B, N, 4)
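
    Example:
        >>> points = torch.Tensor([[10., 10.]])
        >>> distance = torch.Tensor([[2., 3., 4., 5.]])
        >>> assert torch.equal(distance2bbox(points, distance),
        ...                    torch.Tensor([[8., 7., 14., 15.]]))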
"""
x1 = points[..., 0] - distance[..., 0]
y1 = points[..., 1] - distance[..., 1]
x2 = points[..., 0] + distance[..., 2]
y2 = points[..., 1] + distance[..., 3]
bboxes = torch.stack([x1, y1, x2, y2], -1)
if max_shape is not None:
if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
# speed up
bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
return bboxes
if not isinstance(max_shape, torch.Tensor):
max_shape = x1.new_tensor(max_shape)
max_shape = max_shape[..., :2].type_as(x1)
if max_shape.ndim == 2:
assert bboxes.ndim == 3
assert max_shape.size(0) == bboxes.size(0)
min_xy = x1.new_tensor(0)
max_xy = torch.cat([max_shape, max_shape],
dim=-1).flip(-1).unsqueeze(-2)
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
return bboxes
def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
r"""Performs non-maximum suppression in a batched fashion.
Modified from `torchvision/ops/boxes.py#L39
<https://github.com/pytorch/vision/blob/
505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39>`_.
In order to perform NMS independently per class, we add an offset to all
the boxes. The offset is dependent only on the class idx, and is large
enough so that boxes from different classes do not overlap.
Note:
In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and
returns sorted raw results when `nms_cfg` is None.
Args:
boxes (torch.Tensor): boxes in shape (N, 4).
scores (torch.Tensor): scores in shape (N, ).
        idxs (torch.Tensor): each index value corresponds to a bbox cluster,
and NMS will not be applied between elements of different idxs,
shape (N, ).
nms_cfg (dict | None): Supports skipping the nms when `nms_cfg`
is None, otherwise it should specify nms type and other
            parameters like `iou_thr`. Possible keys include the following.
- iou_thr (float): IoU threshold used for NMS.
- split_thr (float): threshold number of boxes. In some cases the
number of boxes is large (e.g., 200k). To avoid OOM during
training, the users could set `split_thr` to a small value.
If the number of boxes is greater than the threshold, it will
perform NMS on each group of boxes separately and sequentially.
Defaults to 10000.
class_agnostic (bool): if true, nms is class agnostic,
i.e. IoU thresholding happens over all boxes,
regardless of the predicted class.
Returns:
        tuple: kept dets and indices.
- boxes (Tensor): Bboxes with score after nms, has shape
(num_bboxes, 5). last dimension 5 arrange as
(x1, y1, x2, y2, score)
- keep (Tensor): The indices of remaining boxes in input
boxes.
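
    Example (the remaining keys of ``nms_cfg`` are passed straight to
    torchvision ``nms``, hence ``iou_threshold``):
        >>> boxes = torch.Tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
        >>> scores = torch.Tensor([0.9, 0.8])
        >>> idxs = torch.LongTensor([0, 1])
        >>> dets, keep = batched_nms(boxes, scores, idxs,
        ...                          dict(iou_threshold=0.5))
        >>> # identical boxes of different classes are both kept
        >>> assert dets.shape == (2, 5) and keep.numel() == 2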
"""
# skip nms when nms_cfg is None
if nms_cfg is None:
scores, inds = scores.sort(descending=True)
boxes = boxes[inds]
return torch.cat([boxes, scores[:, None]], -1), inds
nms_cfg_ = nms_cfg.copy()
class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
if class_agnostic:
boxes_for_nms = boxes
else:
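        # Offset the boxes of each class by a class-dependent shift so that
        # boxes from different classes never overlap; a single NMS call then
        # behaves like per-class NMS.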
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
nms_type = nms_cfg_.pop('type', 'nms')
nms_op = eval(nms_type)
split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx
if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
keep = nms(boxes_for_nms, scores, **nms_cfg_)
boxes = boxes[keep]
        # `keep` holds indices into the original (unshifted) boxes, so both
        # the boxes and their scores are gathered with it.
scores = scores[keep]
else:
max_num = nms_cfg_.pop('max_num', -1)
total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
        # Some NMS variants, such as SoftNMS, reweight the scores.
scores_after_nms = scores.new_zeros(scores.size())
for id in torch.unique(idxs):
mask = (idxs == id).nonzero(as_tuple=False).view(-1)
keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_)
total_mask[mask[keep]] = True
            # `keep` is relative to the masked subset, so map it back to
            # global positions via `mask` before gathering scores.
            scores_after_nms[mask[keep]] = scores[mask[keep]]
keep = total_mask.nonzero(as_tuple=False).view(-1)
scores, inds = scores_after_nms[keep].sort(descending=True)
keep = keep[inds]
boxes = boxes[keep]
if max_num > 0:
keep = keep[:max_num]
boxes = boxes[:max_num]
scores = scores[:max_num]
boxes = torch.cat([boxes, scores[:, None]], -1)
return boxes, keep