# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
from distutils.version import LooseVersion

import numpy as np
import torch
import torchvision
from torchvision.ops.boxes import box_area

from easycv.models.detection.utils.misc import fp16_clamp


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise IoU between two sets of boxes.

    Args:
        bboxes_a (torch.Tensor): boxes of shape (N, 4).
        bboxes_b (torch.Tensor): boxes of shape (M, 4).
        xyxy (bool): if True, boxes are (x1, y1, x2, y2); otherwise they are
            (cx, cy, w, h).

    Returns:
        torch.Tensor: IoU matrix of shape (N, M).

    Raises:
        IndexError: if either input does not have exactly 4 columns.
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError(
            'bboxes_iou expects boxes with 4 columns, got shapes '
            f'{tuple(bboxes_a.shape)} and {tuple(bboxes_b.shape)}')

    if xyxy:
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        # Convert (cx, cy, w, h) to corners on the fly.
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)
    # `en` is 1 only where the pair overlaps on both axes, so disjoint pairs
    # get zero intersection instead of a product of negative extents.
    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en
    return area_i / (area_a[:, None] + area_b - area_i)


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Decode raw predictions, filter by confidence and apply NMS per image.

    Args:
        prediction (torch.Tensor): raw predictions; the indexing below assumes
            shape (B, N, 5 + num_classes) laid out as
            (cx, cy, w, h, obj_conf, class scores...). NOTE: the first four
            columns of this tensor are overwritten IN PLACE with
            (x1, y1, x2, y2).
        num_classes (int): number of class-score columns after obj_conf.
        conf_thre (float): a prediction is kept when
            obj_conf * max class score >= conf_thre.
        nms_thre (float): IoU threshold passed to NMS.

    Returns:
        list: one entry per image; ``None`` if nothing survives filtering,
        otherwise a (K, 7) tensor of
        (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
    """
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    # batched_nms does class-aware NMS; older torchvision falls back to
    # class-agnostic nms. The version cannot change between images, so
    # check it once instead of once per image.
    use_batched_nms = LooseVersion(
        torchvision.__version__) >= LooseVersion('0.8.0')

    output = [None] * len(prediction)
    for i, image_pred in enumerate(prediction):
        # If none are remaining => process next image
        if not image_pred.numel():
            continue

        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Index the explicit column instead of calling squeeze(): squeeze()
        # turns a single-row input into a 0-d mask, and indexing with a 0-d
        # boolean adds a spurious leading axis to `detections`.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre

        # Detections ordered as
        # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue

        scores = detections[:, 4] * detections[:, 5]
        if use_batched_nms:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], scores, detections[:, 6], nms_thre)
        else:
            nms_out_index = torchvision.ops.nms(detections[:, :4], scores,
                                                nms_thre)

        # Each image index is visited exactly once, so assign directly.
        output[i] = detections[nms_out_index]

    return output


def bbox2result(bboxes, labels, num_classes):
    """Convert detection results to a list of numpy arrays.

    Args:
        bboxes (torch.Tensor | np.ndarray): shape (n, 5)
        labels (torch.Tensor | np.ndarray): shape (n, )
        num_classes (int): class number, including background class

    Returns:
        list(ndarray): bbox results of each class
    """
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
    else:
        if isinstance(bboxes, torch.Tensor):
            bboxes = bboxes.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
        return [bboxes[labels == i, :] for i in range(num_classes)]


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) on the last dim."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x1, y1, x2, y2) to (cx, cy, w, h) on the last dim."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU of (x1, y1, x2, y2) boxes.

    Returns:
        tuple: (iou, union), both of shape [N, M].
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest box enclosing each pair.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two set of bboxes.

    FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889
    Note:
        Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou',
        there are some new generated variable when calculating IOU
        using bbox_overlaps function:

        1) is_aligned is False
            area1: M x 1
            area2: N x 1
            lt: M x N x 2
            rb: M x N x 2
            wh: M x N x 2
            overlap: M x N x 1
            union: M x N x 1
            ious: M x N x 1

            Total memory:
                S = (9 x N x M + N + M) * 4 Byte,

            When using FP16, we can reduce:
                R = (9 x N x M + N + M) * 4 / 2 Byte
                R larger than (N + M) * 4 * 2 is always true when N and M >= 1.
                Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2,
                           N + 1 < 3 * N, when N or M is 1.

            Given M = 40 (ground truth), N = 400000 (three anchor boxes
            in per grid, FPN, R-CNNs),
                R = 275 MB (one times)

            A special case (dense detection), M = 512 (ground truth),
                R = 3516 MB = 3.43 GB

            When the batch size is B, reduce:
                B x R

            Therefore, CUDA memory runs out frequently.

            Experiments on GeForce RTX 2080Ti (11019 MiB):

            | dtype |  M  |   N    |   Use    |   Real   |  Ideal   |
            |:-----:|:---:|:------:|:--------:|:--------:|:--------:|
            | FP32  | 512 | 400000 | 8020 MiB |    --    |    --    |
            | FP16  | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
            | FP32  | 40  | 400000 | 1540 MiB |    --    |    --    |
            | FP16  | 40  | 400000 | 1264 MiB |  276MiB  | 275 MiB  |

        2) is_aligned is True
            area1: N x 1
            area2: N x 1
            lt: N x 2
            rb: N x 2
            wh: N x 2
            overlap: N x 1
            union: N x 1
            ious: N x 1

            Total memory:
                S = 11 x N * 4 Byte

            When using FP16, we can reduce:
                R = 11 x N * 4 / 2 Byte

        So do the 'giou' (larger than 'iou').

        Time-wise, FP16 is generally faster than FP32.

        When gpu_assign_thr is not -1, it takes more time on cpu
        but not reduce memory.
        Therefore, we can reduce half the memory and keep the speed.

    If ``is_aligned`` is ``False``, then calculate the overlaps between each
    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
    pair of bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): A value added to the denominator for numerical
            stability. Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m)
        (no leading batch dims when the inputs are 2-D).

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """

    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    # At least one side is empty: return an (empty) result of the right shape.
    if rows * cols == 0:
        if is_aligned:
            return bboxes1.new(batch_shape + (rows, ))
        else:
            return bboxes1.new(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
        rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
            enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
    else:
        lt = torch.max(bboxes1[..., :, None, :2],
                       bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
        rb = torch.min(bboxes1[..., :, None, 2:],
                       bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        else:
            union = area1[..., None]
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :2],
                                    bboxes2[..., None, :, :2])
            enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
                                    bboxes2[..., None, :, 2:])

    # Floor the denominator at eps to avoid division by zero.
    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious
    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode bounding boxes as distances from points to the four box sides.

    This is the inverse of ``distance2bbox``.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        bbox (Tensor): Shape (n, 4), "xyxy" format
        max_dis (float): Upper bound of the distance. When given, each
            distance is also clamped to be >= 0.
        eps (float): a small value to ensure target < max_dis, instead <=

    Returns:
        Tensor: Distances (left, top, right, bottom) with shape (n, 4).
    """
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)


def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (B, N, 2) or (N, 2).
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]],optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.

    Returns:
        Tensor: Boxes with shape (N, 4) or (B, N, 4)
    """
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]

    bboxes = torch.stack([x1, y1, x2, y2], -1)

    if max_shape is not None:
        if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
            # speed up: clamp x columns to [0, W] and y columns to [0, H]
            # in place (strided slices are views of `bboxes`).
            bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
            bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
            return bboxes

        # Tensor-based clipping via torch.where (also the path taken during
        # ONNX export and for batched boxes).
        if not isinstance(max_shape, torch.Tensor):
            max_shape = x1.new_tensor(max_shape)
        max_shape = max_shape[..., :2].type_as(x1)
        if max_shape.ndim == 2:
            assert bboxes.ndim == 3
            assert max_shape.size(0) == bboxes.size(0)

        min_xy = x1.new_tensor(0)
        # (H, W) -> (H, W, H, W) -> flipped to (W, H, W, H) so it lines up
        # with (x1, y1, x2, y2).
        max_xy = torch.cat([max_shape, max_shape],
                           dim=-1).flip(-1).unsqueeze(-2)
        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

    return bboxes