mirror of https://github.com/alibaba/EasyCV.git
214 lines
7.2 KiB
Python
214 lines
7.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
from packaging import version
|
|
from torch import Tensor
|
|
from torch.autograd import Function
|
|
|
|
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
|
from torchvision.ops import _new_empty_tensor
|
|
from torchvision.ops.misc import _output_size
|
|
|
|
|
|
@torch.no_grad()
def accuracy(output, target, topk=(1, )):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output (Tensor): Prediction scores. The top-k is taken along dim 1,
            so presumably shape (N, num_classes) — confirm with callers.
        target (Tensor): Ground-truth class indices, one per sample (N
            elements).
        topk (tuple[int]): The k values to evaluate. Default: (1, ).

    Returns:
        list[Tensor]: One scalar tensor per entry of ``topk``. Each holds
        the percentage of samples whose target is among the top-k
        predictions. When ``target`` is empty, every entry is a zero
        scalar.
    """
    if target.numel() == 0:
        # One entry per requested k, so that callers unpacking the result
        # (e.g. ``top1, top5 = accuracy(...)``) keep working.
        return [torch.zeros([], device=output.device) for _ in topk]

    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # (maxk, N): row j holds every sample's (j+1)-th best class.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, so ``view(-1)`` can raise. ``reshape`` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
|
|
|
|
|
def interpolate(input,
                size=None,
                scale_factor=None,
                mode='nearest',
                align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """Equivalent to ``nn.functional.interpolate`` with support for empty
    batch sizes on old torchvision releases.

    For torchvision < 0.7, non-empty inputs go straight to
    ``torch.nn.functional.interpolate``. Empty inputs instead get an empty
    tensor of the would-be output shape, built with torchvision's private
    helpers. For newer torchvision, ``torchvision.ops.misc.interpolate`` is
    used when present; otherwise ``torch.nn.functional.interpolate`` is
    used. This function can go away once empty-batch support can be
    assumed everywhere.

    Args:
        input (Tensor): Tensor to resize. The last two dims are resized.
        size (list[int], optional): Target spatial size.
        scale_factor (float, optional): Spatial multiplier.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool, optional): Forwarded unchanged.

    Returns:
        Tensor: The resized tensor.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(input, size, scale_factor,
                                                   mode, align_corners)

        # Empty input: compute the two spatial output dims and keep all
        # leading dims, then allocate an empty result of that shape.
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)

    # Some torchvision releases may not provide
    # ``torchvision.ops.misc.interpolate``. Fall back to the PyTorch
    # function rather than failing with AttributeError.
    interp_fn = getattr(torchvision.ops.misc, 'interpolate',
                        torch.nn.functional.interpolate)
    return interp_fn(input, size, scale_factor, mode, align_corners)
|
|
|
|
|
|
def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Pick one image's slice out of every level of a multi-level batch.

    ``detach`` defaults to True because proposals coming out of the first
    stage of a two-stage detector (e.g. Cascade Mask R-CNN) must not carry
    gradients during training.

    Args:
        mlvl_tensors (list[Tensor] | tuple[Tensor]): One batched tensor per
            scale level, each a 4D tensor.
        batch_id (int): Index of the image to extract along dim 0.
        detach (bool): Whether to detach the extracted slices from the
            autograd graph. Default: True.

    Returns:
        list[Tensor]: The per-level slices for image ``batch_id``.
    """
    assert isinstance(mlvl_tensors, (list, tuple))

    per_level = [level_tensor[batch_id] for level_tensor in mlvl_tensors]
    if detach:
        per_level = [item.detach() for item in per_level]
    return per_level
|
|
|
|
|
|
def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Keep the highest-scoring (box, class) pairs above a threshold.

    Every element of ``scores`` strictly greater than ``score_thr`` is a
    candidate. Candidates are ranked by score in descending order, and at
    most ``topk`` of them are kept.

    Args:
        scores (Tensor): Score matrix of shape (num_bboxes, K).
        score_thr (float): Elements must exceed this value to be kept.
        topk (int): Maximum number of candidates to keep.
        results (dict | list | Tensor, optional): Per-box data to filter
            with the kept box indices. Each item has shape (num_bboxes, N).

    Returns:
        tuple: ``(scores, labels, keep_idxs, filtered_results)``.
        ``scores``, ``labels`` (column index into the input scores) and
        ``keep_idxs`` (row index) all have shape (num_kept, ).
        ``filtered_results`` has the same container type as ``results``
        and is ``None`` when ``results`` is ``None``.

    Raises:
        NotImplementedError: If ``results`` is not a dict, list or Tensor.
    """
    above_thr = scores > score_thr
    candidate_scores = scores[above_thr]
    # (row, column) coordinates of each candidate, in the same row-major
    # order as ``candidate_scores``.
    candidate_coords = torch.nonzero(above_thr)

    # torch.sort is used instead of .topk since it is faster, at least on
    # GPUs.
    ranked_scores, order = candidate_scores.sort(descending=True)
    keep_count = min(topk, candidate_coords.size(0))
    selected_coords = candidate_coords[order[:keep_count]]
    keep_idxs, labels = selected_coords.unbind(dim=1)
    kept_scores = ranked_scores[:keep_count]

    if results is None:
        filtered_results = None
    elif isinstance(results, dict):
        filtered_results = {
            name: value[keep_idxs]
            for name, value in results.items()
        }
    elif isinstance(results, list):
        filtered_results = [entry[keep_idxs] for entry in results]
    elif isinstance(results, torch.Tensor):
        filtered_results = results[keep_idxs]
    else:
        raise NotImplementedError(f'Only supports dict or list or Tensor, '
                                  f'but get {type(results)}.')
    return kept_scores, labels, keep_idxs, filtered_results
|
|
|
|
|
|
def output_postprocess(outputs, img_metas=None):
    """Convert per-image detection tensors into numpy result lists.

    Args:
        outputs (list[Tensor | None]): One entry per image. Each non-None
            entry is indexed as a 2D tensor: columns 0:4 are taken as boxes,
            columns 4 and 5 are multiplied to form the score, and column 6
            is the class id (presumably shape (num_dets, 7) — confirm with
            the producing head).
        img_metas (list[dict], optional): Per-image meta info aligned with
            ``outputs``. When truthy, boxes are divided by
            ``meta['scale_factor'][0]`` and the metas are echoed back.

    Returns:
        dict: ``detection_boxes``, ``detection_scores`` and
        ``detection_classes`` are lists of numpy arrays, with ``None`` for
        images whose output is ``None``. ``img_metas`` is the list of metas,
        or empty when ``img_metas`` is falsy. The input tensors are not
        modified.
    """
    detection_boxes = []
    detection_scores = []
    detection_classes = []
    img_metas_list = []

    for i, output in enumerate(outputs):
        if img_metas:
            img_metas_list.append(img_metas[i])

        if output is None:
            detection_boxes.append(None)
            detection_scores.append(None)
            detection_classes.append(None)
            continue

        bboxes = output[:, 0:4]
        if img_metas:
            # Out-of-place division: ``bboxes`` is a view into ``output``,
            # so an in-place ``/=`` would rescale the caller's tensor too.
            bboxes = bboxes / img_metas[i]['scale_factor'][0]
        detection_boxes.append(bboxes.cpu().numpy())
        detection_scores.append(
            (output[:, 4] * output[:, 5]).cpu().numpy())
        detection_classes.append(
            output[:, 6].cpu().numpy().astype(np.int32))

    test_outputs = {
        'detection_boxes': detection_boxes,
        'detection_scores': detection_scores,
        'detection_classes': detection_classes,
        'img_metas': img_metas_list
    }
    return test_outputs
|
|
|
|
|
|
def fp16_clamp(x, min=None, max=None):
    """Clamp ``x`` into ``[min, max]``, working around CPU float16.

    CPU float16 tensors are clamped in float32 and then cast back to half,
    because (per the original note) CPU half tensors had no clamp
    implementation. All other tensors are clamped directly.

    Args:
        x (Tensor): Input tensor.
        min (Number, optional): Lower bound. ``None`` means unbounded.
        max (Number, optional): Upper bound. ``None`` means unbounded.

    Returns:
        Tensor: The clamped tensor, with the same dtype as ``x``.
    """
    is_cpu_half = x.dtype == torch.float16 and not x.is_cuda
    if not is_cpu_half:
        return x.clamp(min, max)
    upcast = x.float()
    return upcast.clamp(min, max).half()
|
|
|
|
|
|
def inverse_sigmoid(x, eps=1e-3):
    """Numerically safe logit, i.e. the inverse of ``sigmoid``.

    ``x`` is first clipped to [0, 1]. Both ``x`` and ``1 - x`` are then
    floored at ``eps`` so the ratio and the log stay finite.

    Args:
        x (Tensor): Values expected to lie in [0, 1].
        eps (float): Lower bound applied to numerator and denominator.
            Default: 1e-3.

    Returns:
        Tensor: ``log(x / (1 - x))`` computed on the clipped values.
    """
    bounded = x.clamp(min=0, max=1)
    numerator = bounded.clamp(min=eps)
    denominator = (1 - bounded).clamp(min=eps)
    ratio = numerator / denominator
    return torch.log(ratio)
|
|
|
|
|
|
class SigmoidGeometricMean(Function):
    """Geometric mean of two sigmoids with an analytical backward pass.

    The forward pass computes ``sqrt(sigmoid(x) * sigmoid(y))``. The
    gradient is written by hand instead of letting autograd differentiate
    ``(x.sigmoid() * y.sigmoid()).sqrt()``. That naive version produces NaN
    gradients during backpropagation when both ``x`` and ``y`` are very
    small values, because the product underflows to zero where ``sqrt``
    has an infinite derivative.
    """

    @staticmethod
    def forward(ctx, x, y):
        sig_x, sig_y = x.sigmoid(), y.sigmoid()
        mean = torch.sqrt(sig_x * sig_y)
        ctx.save_for_backward(sig_x, sig_y, mean)
        return mean

    @staticmethod
    def backward(ctx, grad_output):
        sig_x, sig_y, mean = ctx.saved_tensors
        # d/dx sqrt(sx * sy) = mean * (1 - sx) / 2, using
        # sigmoid'(x) = sx * (1 - sx). Symmetrically for y.
        scaled = grad_output * mean
        return scaled * (1 - sig_x) / 2, scaled * (1 - sig_y) / 2


# Functional entry point: call as ``sigmoid_geometric_mean(x, y)``.
sigmoid_geometric_mean = SigmoidGeometricMean.apply
|