mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
110 lines
3.9 KiB
Python
110 lines
3.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from functools import partial
|
|
|
|
import torch
|
|
|
|
from easycv.framework.errors import NotImplementedError
|
|
|
|
|
|
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` element-wise over parallel argument lists.

    ``func`` is called once per position of the iterables in ``args``
    (stopping at the shortest one), with ``kwargs`` forwarded to every call.
    Each call is expected to return a sequence of outputs. Those outputs are
    then regrouped by position, so the i-th returned list collects the i-th
    output of every call.

    Note:
        If no call is made (the argument iterables are empty), the result is
        an empty tuple. Callers that unpack a fixed number of lists must
        handle that case themselves.

    Args:
        func (Function): A function that will be applied to a list of
            arguments.
        *args: Iterables whose elements are passed positionally to ``func``.
        **kwargs: Keyword arguments forwarded unchanged to every call.

    Returns:
        tuple(list): A tuple containing multiple lists. Each list holds one
        kind of output returned by ``func``, ordered like the inputs.
    """
    if kwargs:
        bound_func = partial(func, **kwargs)
    else:
        bound_func = func
    per_call_outputs = list(map(bound_func, *args))
    # Transpose: one list per output slot instead of one tuple per call.
    regrouped = zip(*per_call_outputs)
    return tuple(list(slot) for slot in regrouped)
|
|
|
|
|
|
def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
    """Pick one image's tensors out of a list of per-level batch tensors.

    For every scale level, the entry at ``batch_id`` (first dimension) is
    taken, so the result has one tensor per level for a single image.

    Note:
        ``detach`` defaults to True because proposal gradients need to be
        detached while training two-stage models, e.g. Cascade Mask R-CNN.

    Args:
        mlvl_tensors (list[Tensor] | tuple[Tensor]): Batch tensors for all
            scale levels, each indexed by batch along its first dimension
            (documented upstream as 4D tensors).
        batch_id (int): Batch index.
        detach (bool): Whether to detach the selected tensors from the
            autograd graph. Default True.

    Returns:
        list[Tensor]: One tensor per scale level for the selected image.
    """
    assert isinstance(mlvl_tensors, (list, tuple))
    if not detach:
        return [level[batch_id] for level in mlvl_tensors]
    return [level[batch_id].detach() for level in mlvl_tensors]
|
|
|
|
|
|
def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Filter results using score threshold and topk candidates.

    Every (bbox, class) entry of ``scores`` that is strictly greater than
    ``score_thr`` is a candidate. At most ``topk`` of the highest-scoring
    candidates are kept, in descending score order. A bbox index can appear
    more than once in the output if several of its class scores pass the
    threshold.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): The score filter threshold. Only scores strictly
            greater than this value are kept.
        topk (int): The maximum number of candidates to keep.
        results (dict or list or tuple or Tensor, Optional): The results to
            which the filtering rule is to be applied. The shape of each item
            is (num_bboxes, N). Dict values and list/tuple items are filtered
            individually. A dict comes back as a dict, a list as a list and a
            tuple as a plain tuple.

    Returns:
        tuple: Filtered results

            - scores (Tensor): The scores after being filtered, sorted in
              descending order, shape (num_bboxes_filtered, ).
            - labels (Tensor): The class labels (column indices into
              ``scores``), shape (num_bboxes_filtered, ).
            - anchor_idxs (Tensor): The anchor indexes (row indices into
              ``scores``), shape (num_bboxes_filtered, ).
            - filtered_results (dict or list or tuple or Tensor, Optional):
              The filtered results, or None if ``results`` is None. The shape
              of each item is (num_bboxes_filtered, N).

    Raises:
        NotImplementedError: If ``results`` is not None and is not a dict,
            list, tuple or Tensor.
    """
    valid_mask = scores > score_thr
    scores = scores[valid_mask]
    # One (bbox index, class index) row per entry that passed the threshold.
    # Boolean-mask indexing above and nonzero() here both enumerate the mask
    # in the same row-major order, so row i matches scores[i].
    valid_idxs = torch.nonzero(valid_mask)

    num_topk = min(topk, valid_idxs.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    scores, idxs = scores.sort(descending=True)
    scores = scores[:num_topk]
    topk_idxs = valid_idxs[idxs[:num_topk]]
    keep_idxs, labels = topk_idxs.unbind(dim=1)

    filtered_results = None
    if results is not None:
        if isinstance(results, dict):
            filtered_results = {k: v[keep_idxs] for k, v in results.items()}
        elif isinstance(results, list):
            filtered_results = [result[keep_idxs] for result in results]
        elif isinstance(results, tuple):
            # Tuples are filtered like lists. Previously they were rejected.
            filtered_results = tuple(result[keep_idxs] for result in results)
        elif isinstance(results, torch.Tensor):
            filtered_results = results[keep_idxs]
        else:
            raise NotImplementedError(
                f'Only supports dict, list, tuple or Tensor, '
                f'but got {type(results)}.')
    return scores, labels, keep_idxs, filtered_results
|