154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List, Sequence, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from mmengine.utils import is_str
|
|
|
|
if not hasattr(torch, 'tensor_split'):

    def tensor_split(input: torch.Tensor, indices: list):
        """Minimal stand-in for ``torch.tensor_split``.

        Only used when the installed torch does not expose
        ``tensor_split``. It slices ``input`` along its first dimension at
        the given boundaries.

        Args:
            input (torch.Tensor): The tensor to split.
            indices (list): Boundaries along dim 0, given as a ``list``.

        Returns:
            list: ``len(indices) + 1`` slices of ``input``.
        """
        starts = [0] + indices
        stops = indices + [input.size(0)]
        return [input[lo:hi] for lo, hi in zip(starts, stops)]
else:
    # Prefer the native implementation whenever torch provides one.
    tensor_split = torch.tensor_split
|
|
|
|
|
|
# Inputs that can be converted into a score tensor.
SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
# Labels accept the same containers as scores plus a bare ``int``; the nested
# ``Union`` is flattened by ``typing``, so the members and their order are
# (torch.Tensor, np.ndarray, Sequence, int).
LABEL_TYPE = Union[SCORE_TYPE, int]
|
|
|
|
|
|
def format_label(value: LABEL_TYPE) -> torch.Tensor:
    """Convert various python types to label-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int`. NumPy integer scalars (e.g.
    ``numpy.int64``) are accepted as well and treated like :class:`int`.

    Arrays, sequences and integers are converted to ``torch.long``. A
    :class:`torch.Tensor` with one or more dimensions is returned as-is
    (its dtype is not changed).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.

    Returns:
        :obj:`torch.Tensor`: The formatted label tensor.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the converted tensor is not 1-dimensional.
    """
    # Collapse every kind of single number into a plain ``int`` first, so
    # that it is wrapped into a 1-element tensor below.
    if isinstance(value, np.integer):
        # NumPy integer scalars are neither ``int`` nor ``np.ndarray``.
        value = int(value)
    elif isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
        value = int(value.item())

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).to(torch.long)
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).to(torch.long)
    elif isinstance(value, int):
        value = torch.LongTensor([value])
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available label type.')

    # Raised explicitly (instead of using ``assert``) so that the check is
    # still performed when Python runs with ``-O``.
    if value.ndim != 1:
        raise AssertionError(
            f'The dims of value should be 1, but got {value.ndim}.')

    return value
|
|
|
|
|
|
def format_score(value: SCORE_TYPE) -> torch.Tensor:
    """Convert various python types to score-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`.

    Arrays and sequences are converted to a float tensor. A
    :class:`torch.Tensor` is returned as-is (its dtype is not changed).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence): Score values.

    Returns:
        :obj:`torch.Tensor`: The formatted score tensor.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the converted tensor is not 1-dimensional.
    """
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).float()
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).float()
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available score type.')

    # Raised explicitly (instead of using ``assert``) so that the check is
    # still performed when Python runs with ``-O``.
    if value.ndim != 1:
        raise AssertionError(
            f'The dims of value should be 1, but got {value.ndim}.')

    return value
|
|
|
|
|
|
def cat_batch_labels(elements: List[torch.Tensor]):
    """Join the label tensors of a batch into a single tensor.

    Args:
        elements (List[tensor]): One label tensor per sample.

    Returns:
        Tuple[torch.Tensor, List[int]]: The concatenated label tensor, and
        the positions in it at which each sample after the first begins
        (i.e. the split indices needed to recover the per-sample labels).
    """
    tensors = list(elements)

    # Running total of label counts: entry ``i`` is where sample ``i``
    # ends inside the concatenated tensor.
    boundaries = []
    total = 0
    for tensor in tensors:
        total += tensor.size(0)
        boundaries.append(total)

    merged = torch.cat(tensors)
    # The final boundary equals the full length, so it is not a split point.
    return merged, boundaries[:-1]
|
|
|
|
|
|
def batch_label_to_onehot(batch_label, split_indices, num_classes):
    """Turn a concatenated label tensor into per-sample onehot rows.

    Every label is expanded to a onehot vector, the vectors are grouped by
    sample using ``split_indices``, and each group is summed into one row.

    Args:
        batch_label (torch.Tensor): A concatenated label tensor from
            multiple samples.
        split_indices (List[int]): The split indices of every sample.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor, one row per sample.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import batch_label_to_onehot
        >>> # Assume a concatenated label from 3 samples.
        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
        >>> split_indices = [2, 5]
        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
        tensor([[1, 1, 0, 0, 0],
                [1, 0, 1, 0, 1],
                [0, 1, 0, 1, 0]])
    """
    # One onehot row per individual label in the batch.
    per_label_onehot = F.one_hot(batch_label, num_classes)
    # Regroup the rows so that each chunk holds the labels of one sample.
    per_sample_chunks = tensor_split(per_label_onehot, split_indices)
    rows = [chunk.sum(0) for chunk in per_sample_chunks]
    return torch.stack(rows)
|
|
|
|
|
|
def label_to_onehot(label: LABEL_TYPE, num_classes: int):
    """Build the onehot tensor of a single sample's label(s).

    The label is first normalised with :func:`format_label`; every class
    index is then expanded to a onehot vector and the vectors are summed.

    Args:
        label (LABEL_TYPE): Label value.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor.

    Examples:
        >>> import torch
        >>> from mmpretrain.structures import label_to_onehot
        >>> # Single-label
        >>> label_to_onehot(1, num_classes=5)
        tensor([0, 1, 0, 0, 0])
        >>> # Multi-label
        >>> label_to_onehot([0, 2, 3], num_classes=5)
        tensor([1, 0, 1, 1, 0])
    """
    class_indices = format_label(label)
    # Shape (num_labels, num_classes); collapsing dim 0 merges all labels.
    return F.one_hot(class_indices, num_classes).sum(0)
|