EasyCV/easycv/models/detection/utils/misc.py

214 lines
7.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import numpy as np
import torch
import torchvision
from packaging import version
from torch import Tensor
from torch.autograd import Function
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
@torch.no_grad()
def accuracy(output, target, topk=(1, )):
    """Compute the precision@k for the specified values of k.

    Args:
        output (Tensor): Prediction scores. The top-k search runs along
            dim 1, so the layout is (batch_size, num_classes).
        target (Tensor): Ground-truth class indices, one per sample.
        topk (tuple[int]): The values of k to evaluate.

    Returns:
        list[Tensor]: One scalar tensor per entry of ``topk`` holding the
        percentage (0-100) of samples whose target is among the k
        highest-scoring predictions. When ``target`` is empty, a list
        with a single zero scalar is returned regardless of
        ``len(topk)``.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]

    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to
    # (maxk, batch_size) so that row r holds the rank-r prediction.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and can be
        # non-contiguous; ``view(-1)`` raises on such tensors, whereas
        # ``reshape(-1)`` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def interpolate(input,
                size=None,
                scale_factor=None,
                mode='nearest',
                align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """Drop-in for ``nn.functional.interpolate`` that tolerates empty inputs.

    On torchvision >= 0.7 the call is forwarded to
    ``torchvision.ops.misc.interpolate``. On older releases, non-empty
    inputs go through ``torch.nn.functional.interpolate`` and empty inputs
    yield an empty tensor whose last two dimensions are replaced by the
    computed output size.

    This helper can be removed once empty batches are handled natively by
    every supported PyTorch/torchvision version.
    """
    legacy_torchvision = (
        version.parse(torchvision.__version__) < version.parse('0.7'))

    if not legacy_torchvision:
        return torchvision.ops.misc.interpolate(input, size, scale_factor,
                                                mode, align_corners)

    if input.numel() > 0:
        return torch.nn.functional.interpolate(input, size, scale_factor,
                                               mode, align_corners)

    # Empty input on an old torchvision: only the result shape is needed.
    # NOTE(review): the leading ``2`` is presumably the number of spatial
    # dims expected by ``_output_size`` - confirm against torchvision < 0.7.
    spatial_shape = _output_size(2, input, size, scale_factor)
    leading_shape = list(input.shape[:-2])
    return _new_empty_tensor(input, leading_shape + list(spatial_shape))
def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Pick one image's tensors out of a multi-level batch.

    Detaching is the default because proposal gradients must not flow back
    while training two-stage models such as Cascade Mask R-CNN.

    Args:
        mlvl_tensors (list[Tensor] | tuple[Tensor]): Batched tensors, one
            per scale level; each is indexed along its first dimension.
        batch_id (int): Index of the image inside the batch.
        detach (bool): Whether to detach the selected tensors from the
            autograd graph. Default True.

    Returns:
        list[Tensor]: The per-level tensors belonging to image ``batch_id``.
    """
    assert isinstance(mlvl_tensors, (list, tuple))

    per_level = [level[batch_id] for level in mlvl_tensors]
    if not detach:
        return per_level
    return [single.detach() for single in per_level]
def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Keep the best-scoring candidates above a threshold.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): Scores must be strictly greater than this value
            to survive.
        topk (int): Maximum number of candidates to keep.
        results (dict or list or Tensor, Optional): Extra data to filter
            with the same row selection. The shape of each item is
            (num_bboxes, N).

    Returns:
        tuple: Filtered results

            - scores (Tensor): The kept scores in descending order,
              shape (num_bboxes_filtered, ).
            - labels (Tensor): The class labels (column index of each
              kept score), shape (num_bboxes_filtered, ).
            - anchor_idxs (Tensor): The anchor indexes (row index of each
              kept score), shape (num_bboxes_filtered, ).
            - filtered_results (dict or list or Tensor, Optional): The
              filtered results, or None when ``results`` is None. The
              shape of each item is (num_bboxes_filtered, N).

    Raises:
        NotImplementedError: If ``results`` is neither a dict, a list nor
            a Tensor.
    """
    above_thr = scores > score_thr
    candidate_scores = scores[above_thr]
    # (row, column) coordinates of every surviving score.
    candidate_coords = torch.nonzero(above_thr)

    limit = min(topk, candidate_coords.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    ranked_scores, order = candidate_scores.sort(descending=True)
    top_scores = ranked_scores[:limit]
    keep_idxs, labels = candidate_coords[order[:limit]].unbind(dim=1)

    if results is None:
        return top_scores, labels, keep_idxs, None

    if isinstance(results, dict):
        filtered_results = {
            name: value[keep_idxs]
            for name, value in results.items()
        }
    elif isinstance(results, list):
        filtered_results = [item[keep_idxs] for item in results]
    elif isinstance(results, torch.Tensor):
        filtered_results = results[keep_idxs]
    else:
        raise NotImplementedError(f'Only supports dict or list or Tensor, '
                                  f'but get {type(results)}.')
    return top_scores, labels, keep_idxs, filtered_results
def output_postprocess(outputs, img_metas=None):
    """Convert per-image detection tensors into numpy result lists.

    Each non-None entry of ``outputs`` is indexed as a 2D tensor: columns
    0-3 are taken as the box, the score is the product of columns 4 and 5,
    and column 6 is cast to an int32 class id.
    NOTE(review): columns 4 and 5 are presumably objectness and class
    confidence - confirm against the head that produces ``outputs``.

    Args:
        outputs (Sequence[Tensor | None]): Per-image detections, or None
            for an image without detections.
        img_metas (Sequence[dict], Optional): Per-image meta information.
            When given (and non-empty), boxes are divided by
            ``img_metas[i]['scale_factor'][0]`` and the metas are echoed
            back in the result. Only the first scale factor entry is
            used, i.e. the rescaling is assumed to be isotropic.

    Returns:
        dict: Has the keys ``detection_boxes``, ``detection_scores``,
        ``detection_classes`` (lists with one numpy array or None per
        image) and ``img_metas`` (the collected metas, empty when
        ``img_metas`` is not given).
    """
    detection_boxes = []
    detection_scores = []
    detection_classes = []
    img_metas_list = []

    for i, output in enumerate(outputs):
        if img_metas:
            img_metas_list.append(img_metas[i])

        if output is None:
            detection_boxes.append(None)
            detection_scores.append(None)
            detection_classes.append(None)
            continue

        bboxes = output[:, 0:4]
        if img_metas:
            # Divide out of place: ``bboxes`` is a view of ``output``, so
            # an in-place ``/=`` would silently rescale the caller's
            # tensor (and rescale it again on every repeated call).
            bboxes = bboxes / img_metas[i]['scale_factor'][0]

        detection_boxes.append(bboxes.cpu().numpy())
        detection_scores.append((output[:, 4] * output[:, 5]).cpu().numpy())
        detection_classes.append(output[:, 6].cpu().numpy().astype(
            np.int32))

    test_outputs = {
        'detection_boxes': detection_boxes,
        'detection_scores': detection_scores,
        'detection_classes': detection_classes,
        'img_metas': img_metas_list
    }
    return test_outputs
def fp16_clamp(x, min=None, max=None):
    """Clamp ``x`` to ``[min, max]``, including float16 tensors on CPU.

    Args:
        x (Tensor): Tensor to clamp.
        min (Number, Optional): Lower bound passed to ``Tensor.clamp``.
        max (Number, Optional): Upper bound passed to ``Tensor.clamp``.

    Returns:
        Tensor: The clamped tensor. Non-CUDA float16 inputs are clamped in
        float32 and converted back to float16.
    """
    if x.is_cuda or x.dtype != torch.float16:
        return x.clamp(min, max)

    # clamp for cpu float16, tensor fp16 has no clamp implementation
    upcast = x.float()
    return upcast.clamp(min, max).half()
def inverse_sigmoid(x, eps=1e-3):
    """Compute the logit ``log(x / (1 - x))`` with clamping for stability.

    Args:
        x (Tensor): Input values; they are clamped to [0, 1] first.
        eps (float): Lower bound applied to both ``x`` and ``1 - x`` before
            the division, which keeps the result finite. Default 1e-3.

    Returns:
        Tensor: The inverse sigmoid of ``x``, same shape as ``x``.
    """
    probs = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(probs, min=eps)
    denominator = torch.clamp(1 - probs, min=eps)
    return (numerator / denominator).log()
class SigmoidGeometricMean(Function):
    """Geometric mean of two sigmoids with a hand-written backward pass.

    Computes ``z = sqrt(sigmoid(x) * sigmoid(y))``. The analytical gradient
    replaces what autograd would derive for
    ``(x.sigmoid() * y.sigmoid()).sqrt()``, which produces NaN during
    gradient backpropagation when both ``x`` and ``y`` are very small
    values.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Return ``sqrt(sigmoid(x) * sigmoid(y))`` and cache the terms
        needed by ``backward``."""
        sig_x = x.sigmoid()
        sig_y = y.sigmoid()
        geo_mean = (sig_x * sig_y).sqrt()
        ctx.save_for_backward(sig_x, sig_y, geo_mean)
        return geo_mean

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``x`` and ``y``.

        Uses ``dz/dx = z * (1 - sigmoid(x)) / 2`` (and likewise for
        ``y``), so no division by ``z`` is required.
        """
        sig_x, sig_y, geo_mean = ctx.saved_tensors
        upstream = grad_output * geo_mean
        grad_x = upstream * (1 - sig_x) / 2
        grad_y = upstream * (1 - sig_y) / 2
        return grad_x, grad_y
# Functional alias: ``sigmoid_geometric_mean(x, y)`` runs
# SigmoidGeometricMean through its ``apply`` entry point.
sigmoid_geometric_mean = SigmoidGeometricMean.apply