1
0
mirror of https://github.com/alibaba/EasyCV.git synced 2025-06-03 14:49:00 +08:00

411 lines
15 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
2022-04-02 20:01:06 +08:00
from distutils.version import LooseVersion
import numpy as np
2022-04-02 20:01:06 +08:00
import torch
import torchvision
from torchvision.ops.boxes import box_area
2022-04-02 20:01:06 +08:00
from easycv.models.detection.utils.misc import fp16_clamp
2022-04-02 20:01:06 +08:00
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise IoU matrix between two sets of boxes.

    Args:
        bboxes_a (torch.Tensor): boxes whose second dimension has size 4.
        bboxes_b (torch.Tensor): boxes whose second dimension has size 4.
        xyxy (bool): if True, the four columns are read as the two corner
            points (columns 0-1 and 2-3). Otherwise columns 0-1 are read as
            the box centre and columns 2-3 as its width/height.

    Returns:
        torch.Tensor: matrix of shape (len(bboxes_a), len(bboxes_b)) where
        entry (i, j) is the IoU of ``bboxes_a[i]`` and ``bboxes_b[j]``.

    Raises:
        IndexError: if either input does not have 4 columns.
    """
    if not (bboxes_a.shape[1] == 4 and bboxes_b.shape[1] == 4):
        raise IndexError

    if xyxy:
        min_a, max_a = bboxes_a[:, :2], bboxes_a[:, 2:]
        min_b, max_b = bboxes_b[:, :2], bboxes_b[:, 2:]
        area_a = torch.prod(max_a - min_a, 1)
        area_b = torch.prod(max_b - min_b, 1)
    else:
        # Centre/size layout: derive the corners from centre -/+ half size.
        half_a = bboxes_a[:, 2:] / 2
        half_b = bboxes_b[:, 2:] / 2
        min_a, max_a = bboxes_a[:, :2] - half_a, bboxes_a[:, :2] + half_a
        min_b, max_b = bboxes_b[:, :2] - half_b, bboxes_b[:, :2] + half_b
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # Broadcast the first set against the second: results are [A, B, 2].
    tl = torch.max(min_a[:, None], min_b)
    br = torch.min(max_a[:, None], max_b)

    # 1 where the boxes overlap along both axes, 0 otherwise; this zeroes
    # out the (possibly positive) product of two negative extents.
    overlaps = (tl < br).type_as(tl).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * overlaps
    return area_i / (area_a[:, None] + area_b - area_i)
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Filter raw detections by confidence and apply NMS per image.

    Args:
        prediction (torch.Tensor): 3-D tensor indexed as
            ``[image, candidate, channel]``. Channels 0-3 are read as
            (cx, cy, w, h), channel 4 as the objectness score and channels
            ``5:5 + num_classes`` as the per-class scores.
            NOTE: the first four channels are overwritten in place with the
            corner form (x1, y1, x2, y2).
        num_classes (int): number of class-score channels to read.
        conf_thre (float): minimum ``objectness * best class score`` for a
            candidate to be kept.
        nms_thre (float): IoU threshold passed to torchvision NMS.

    Returns:
        list: one entry per image. ``None`` when nothing survives,
        otherwise a tensor of shape (k, 7) with columns
        (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
    """
    # Centre/size -> corners. Both corner tensors are computed before any
    # write so the in-place update never reads already-converted values.
    half_wh = prediction[..., 2:4] / 2
    top_left = prediction[..., :2] - half_wh
    bottom_right = prediction[..., :2] + half_wh
    prediction[..., :2] = top_left
    prediction[..., 2:4] = bottom_right

    output = [None] * len(prediction)
    for i, image_pred in enumerate(prediction):
        # Nothing to process for this image.
        if not image_pred.numel():
            continue

        # Best class score and its index for every candidate, kept as
        # column vectors so they can be concatenated below.
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Index the single column explicitly instead of squeeze(): squeeze()
        # collapses to a 0-dim tensor when there is exactly one candidate,
        # and a 0-dim boolean mask adds a dimension when used as an index.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre

        # Detections ordered as
        # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue

        scores = detections[:, 4] * detections[:, 5]
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            # Class-aware NMS: boxes of different classes never suppress
            # each other.
            keep = torchvision.ops.batched_nms(detections[:, :4], scores,
                                               detections[:, 6], nms_thre)
        else:
            keep = torchvision.ops.nms(detections[:, :4], scores, nms_thre)

        output[i] = detections[keep]

    return output
def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): bbox results of each class; entry ``i`` holds the
        rows of ``bboxes`` whose label equals ``i``.
    """
    if bboxes.shape[0] == 0:
        return [
            np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)
        ]

    # Convert each argument on its own so that a tensor/ndarray mix works
    # too, instead of assuming both inputs share the same type.
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    return [bboxes[labels == i, :] for i in range(num_classes)]
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1).

    Args:
        x (torch.Tensor): boxes whose last dimension has size 4.

    Returns:
        torch.Tensor: tensor of the same shape holding the corner form.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h).

    Args:
        x (torch.Tensor): boxes whose last dimension has size 4.

    Returns:
        torch.Tensor: tensor of the same shape holding the centre/size form.
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack((center_x, center_y, width, height), dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Compute pairwise IoU and union area between two sets of boxes.

    Both inputs are indexed as ``[:, :2]`` (min corner) and ``[:, 2:]``
    (max corner).

    Args:
        boxes1 (torch.Tensor): first set of boxes, N rows.
        boxes2 (torch.Tensor): second set of boxes, M rows.

    Returns:
        tuple: ``(iou, union)``, both of shape [N, M].
    """
    areas_first = box_area(boxes1)
    areas_second = box_area(boxes2)

    # Corners of the intersection rectangle for every pair: [N, M, 2].
    inter_min = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    inter_max = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

    # Negative extents mean "no overlap" and are clamped to zero.
    inter_wh = (inter_max - inter_min).clamp(min=0)
    inter = inter_wh[..., 0] * inter_wh[..., 1]  # [N, M]

    union = areas_first[:, None] + areas_second - inter
    return inter / union, union
def generalized_box_iou(boxes1, boxes2):
    """Compute the pairwise Generalized IoU (https://giou.stanford.edu/).

    The boxes should be in [x0, y0, x1, y1] format.

    Args:
        boxes1 (torch.Tensor): first set of boxes, N rows.
        boxes2 (torch.Tensor): second set of boxes, M rows.

    Returns:
        torch.Tensor: [N, M] pairwise matrix, where N = len(boxes1)
        and M = len(boxes2).

    Raises:
        AssertionError: if any box has a max corner smaller than its min
            corner.
    """
    # Degenerate boxes give inf / nan results, so reject them up front.
    for candidate in (boxes1, boxes2):
        assert (candidate[:, 2:] >= candidate[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # Smallest rectangle enclosing each pair of boxes: [N, M, 2].
    hull_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = (hull_max - hull_min).clamp(min=0)
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]

    return iou - (hull_area - union) / hull_area
def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 contributed by https://github.com/open-mmlab/mmdetection/pull/4889

    Note:
        Assume bboxes1 is M x 4 and bboxes2 is N x 4. When mode is 'iou',
        the following intermediate variables are created while computing
        the IoU:

        1) is_aligned is False
            area1: M x 1
            area2: N x 1
            lt: M x N x 2
            rb: M x N x 2
            wh: M x N x 2
            overlap: M x N x 1
            union: M x N x 1
            ious: M x N x 1

            Total memory:
                S = (9 x N x M + N + M) * 4 Byte,

            When using FP16, we can reduce:
                R = (9 x N x M + N + M) * 4 / 2 Byte
                R larger than (N + M) * 4 * 2 is always true when
                N and M >= 1.
                Obviously, N + M <= N * M < 3 * N * M, when N >= 2 and
                M >= 2, and N + 1 < 3 * N, when N or M is 1.

            Given M = 40 (ground truth), N = 400000 (three anchor boxes
            per grid cell, FPN, R-CNNs),
                R = 275 MB (one time)

            A special case (dense detection), M = 512 (ground truth),
                R = 3516 MB = 3.43 GB

            When the batch size is B, reduce:
                B x R

            Therefore, CUDA memory runs out frequently.

            Experiments on GeForce RTX 2080Ti (11019 MiB):

            |   dtype   |   M   |   N   |   Use    |   Real   |   Ideal   |
            |:----:|:----:|:----:|:----:|:----:|:----:|
            |   FP32   |   512 | 400000 | 8020 MiB |   --   |   --   |
            |   FP16   |   512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
            |   FP32   |   40 | 400000 | 1540 MiB |   --   |   --   |
            |   FP16   |   40 | 400000 | 1264 MiB |   276MiB   | 275 MiB |

        2) is_aligned is True
            area1: N x 1
            area2: N x 1
            lt: N x 2
            rb: N x 2
            wh: N x 2
            overlap: N x 1
            union: N x 1
            ious: N x 1

            Total memory:
                S = 11 x N * 4 Byte

            When using FP16, we can reduce:
                R = 11 x N * 4 / 2 Byte

        The same holds for 'giou' (which needs more memory than 'iou').

        Time-wise, FP16 is generally faster than FP32.

        When gpu_assign_thr is not -1, it takes more time on cpu
        but does not reduce memory.
        With FP16 we can halve the memory and keep the speed.

    If ``is_aligned`` is ``False``, then calculate the overlaps between each
    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
    pair of bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): A lower bound applied to the denominator for
            numerical stability. Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m).
        The batch dims are absent when the inputs have none.

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    # Empty input: return an (empty) tensor of the right shape right away.
    if rows * cols == 0:
        out_dims = (rows, ) if is_aligned else (rows, cols)
        return bboxes1.new(batch_shape + out_dims)

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        # Element-wise pairing: operands are [B, rows, ...].
        first, second = bboxes1, bboxes2
        first_area, second_area = area1, area2
    else:
        # Pairwise: insert singleton axes so every op below broadcasts to
        # [B, rows, cols, ...].
        first = bboxes1[..., :, None, :]
        second = bboxes2[..., None, :, :]
        first_area = area1[..., None]
        second_area = area2[..., None, :]

    lt = torch.max(first[..., :2], second[..., :2])
    rb = torch.min(first[..., 2:], second[..., 2:])
    wh = fp16_clamp(rb - lt, min=0)
    overlap = wh[..., 0] * wh[..., 1]

    if mode in ['iou', 'giou']:
        union = first_area + second_area - overlap
    else:
        # 'iof': normalise by the area of the first (foreground) box only.
        union = first_area

    # Keep the denominator at or above eps.
    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode != 'giou':
        return ious

    # GIoU: subtract the fraction of the enclosing box not covered by the
    # union.
    enclosed_lt = torch.min(first[..., :2], second[..., :2])
    enclosed_rb = torch.max(first[..., 2:], second[..., 2:])
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    return ious - (enclose_area - union) / enclose_area
def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode boxes as distances from points to the four box sides.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        bbox (Tensor): Shape (n, 4), "xyxy" format
        max_dis (float): Upper bound of the distance. When given, each
            distance is clamped to ``[0, max_dis - eps]``.
        eps (float): a small value to ensure target < max_dis, instead <=

    Returns:
        Tensor: distances ordered (left, top, right, bottom) along the last
        dimension.
    """
    px, py = points[:, 0], points[:, 1]
    sides = [
        px - bbox[:, 0],  # left
        py - bbox[:, 1],  # top
        bbox[:, 2] - px,  # right
        bbox[:, 3] - py,  # bottom
    ]
    if max_dis is not None:
        upper = max_dis - eps
        sides = [side.clamp(min=0, max=upper) for side in sides]
    return torch.stack(sides, -1)
def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (B, N, 2) or (N, 2).
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.

    Returns:
        Tensor: Boxes with shape (N, 4) or (B, N, 4), in
        (x1, y1, x2, y2) order and clipped to ``max_shape`` when given.
    """
    px, py = points[..., 0], points[..., 1]
    x1 = px - distance[..., 0]
    y1 = py - distance[..., 1]
    x2 = px + distance[..., 2]
    y2 = py + distance[..., 3]
    bboxes = torch.stack([x1, y1, x2, y2], -1)

    if max_shape is None:
        return bboxes

    if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
        # speed up: clamp x columns to the width and y columns to the
        # height, in place.
        bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
        bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
        return bboxes

    if not isinstance(max_shape, torch.Tensor):
        max_shape = x1.new_tensor(max_shape)
    # Keep only (H, W) and match the dtype of the boxes.
    max_shape = max_shape[..., :2].type_as(x1)
    if max_shape.ndim == 2:
        # One (H, W) pair per batch element.
        assert bboxes.ndim == 3
        assert max_shape.size(0) == bboxes.size(0)

    lower = x1.new_tensor(0)
    # (H, W, H, W) flipped to (W, H, W, H) lines up with (x1, y1, x2, y2);
    # the extra axis broadcasts over the boxes.
    upper = torch.cat([max_shape, max_shape], dim=-1).flip(-1).unsqueeze(-2)
    bboxes = torch.where(bboxes < lower, lower, bboxes)
    bboxes = torch.where(bboxes > upper, upper, bboxes)
    return bboxes