[Refactor] Refactor ClsDataSample to a unified DataSample. (#1371)

* [Refactor] Refactor ClsDataSample to a unified DataSample.

* Add method

* Fix docstring

* Update docstring.
pull/1380/head
Ma Zerun 2023-02-23 10:07:53 +08:00 committed by GitHub
parent 4016f1348e
commit 36bea13fca
42 changed files with 659 additions and 755 deletions
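In short, the nested LabelData fields of the old ClsDataSample become flat tensor fields on the sample itself. A minimal before/after sketch of the access pattern, assuming only the new API shown in this commit:

    import torch
    from mmpretrain.structures import DataSample

    sample = DataSample().set_gt_label(3).set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
    sample.gt_label     # tensor([3]), was sample.gt_label.label
    sample.pred_score   # tensor([0.1, 0.8, 0.1]), was sample.pred_label.score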

View File

@ -1,13 +1,13 @@
.. role:: hidden
:class: hidden-section
.. module:: mmcls.structures
.. module:: mmpretrain.structures
mmcls.structures
mmpretrain.structures
===================================
This package includes basic data structures for classification tasks.
This package includes basic data structures.
ClsDataSample
DataSample
-------------
.. autoclass:: ClsDataSample
.. autoclass:: DataSample

View File

@ -12,7 +12,7 @@ from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .model import get_model, init_model, list_models
ModelType = Union[BaseModel, str, Config]
@ -176,7 +176,7 @@ class ImageClassificationInferencer(BaseInferencer):
def visualize(self,
ori_inputs: List[InputType],
preds: List[ClsDataSample],
preds: List[DataSample],
show: bool = False,
rescale_factor: Optional[float] = None,
draw_score=True,
@ -223,7 +223,7 @@ class ImageClassificationInferencer(BaseInferencer):
return visualization
def postprocess(self,
preds: List[ClsDataSample],
preds: List[DataSample],
visualization: List[np.ndarray],
return_datasamples=False) -> dict:
if return_datasamples:
@ -231,14 +231,13 @@ class ImageClassificationInferencer(BaseInferencer):
results = []
for data_sample in preds:
prediction = data_sample.pred_label
pred_scores = prediction.score.detach().cpu().numpy()
pred_score = torch.max(prediction.score).item()
pred_label = torch.argmax(prediction.score).item()
pred_scores = data_sample.pred_score
pred_score = float(torch.max(pred_scores).item())
pred_label = torch.argmax(pred_scores).item()
result = {
'pred_scores': pred_scores,
'pred_scores': pred_scores.detach().cpu().numpy(),
'pred_label': pred_label,
'pred_score': float(pred_score),
'pred_score': pred_score,
}
if self.classes is not None:
result['pred_class'] = self.classes[pred_label]
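The behavioural change in ``postprocess`` is that prediction fields now live directly on the data sample instead of inside a nested ``pred_label`` field. A small sketch of the new access pattern, with made-up scores:

    import torch
    from mmpretrain.structures import DataSample

    data_sample = DataSample().set_pred_score(torch.tensor([0.1, 0.7, 0.2]))
    pred_scores = data_sample.pred_score                # was data_sample.pred_label.score
    pred_score = float(torch.max(pred_scores).item())   # 0.7
    pred_label = torch.argmax(pred_scores).item()       # 1
    result = {
        'pred_scores': pred_scores.detach().cpu().numpy(),
        'pred_label': pred_label,
        'pred_score': pred_score,
    }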

View File

@ -10,7 +10,7 @@ from mmengine.utils import is_str
from PIL import Image
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
def to_tensor(data):
@ -53,7 +53,7 @@ class PackClsInputs(BaseTransform):
**Added Keys:**
- inputs (:obj:`torch.Tensor`): The forward data of models.
- data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
- data_samples (:obj:`~mmpretrain.structures.DataSample`): The
annotation info of the sample.
Args:
@ -87,10 +87,11 @@ class PackClsInputs(BaseTransform):
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
data_sample = ClsDataSample()
data_sample = DataSample()
if 'gt_label' in results:
gt_label = results['gt_label']
data_sample.set_gt_label(gt_label)
data_sample.set_gt_label(results['gt_label'])
if 'gt_score' in results:
data_sample.set_gt_score(results['gt_score'])
img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)

View File

@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
@HOOKS.register_module()
@ -51,14 +51,14 @@ class VisualizationHook(Hook):
def _draw_samples(self,
batch_idx: int,
data_batch: dict,
data_samples: Sequence[ClsDataSample],
data_samples: Sequence[DataSample],
step: int = 0) -> None:
"""Visualize every ``self.interval`` samples from a data batch.
Args:
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
outputs (Sequence[:obj:`DataSample`]): Outputs from model.
step (int): Global step value to record. Defaults to 0.
"""
if self.enable is False:
@ -97,14 +97,14 @@ class VisualizationHook(Hook):
)
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ClsDataSample]) -> None:
outputs: Sequence[DataSample]) -> None:
"""Visualize every ``self.interval`` samples during validation.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
outputs (Sequence[:obj:`DataSample`]): Outputs from model.
"""
if isinstance(runner.train_loop, EpochBasedTrainLoop):
step = runner.epoch
@ -114,7 +114,7 @@ class VisualizationHook(Hook):
self._draw_samples(batch_idx, data_batch, outputs, step=step)
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ClsDataSample]) -> None:
outputs: Sequence[DataSample]) -> None:
"""Visualize every ``self.interval`` samples during test.
Args:

View File

@ -5,9 +5,9 @@ import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmengine.structures import LabelData
from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .single_label import _precision_recall_f1_support, to_tensor
@ -114,10 +114,10 @@ class MultiLabelMetric(BaseMetric):
(tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
... DataSample().set_pred_score(pred).set_gt_score(gt)
... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
>>> evaluator.process(data_samples)
@ -181,17 +181,15 @@ class MultiLabelMetric(BaseMetric):
"""
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
result['pred_score'] = pred_label['score'].clone()
result['pred_score'] = data_sample['pred_score'].clone()
num_classes = result['pred_score'].size()[-1]
if 'score' in gt_label:
result['gt_score'] = gt_label['score'].clone()
if 'gt_score' in data_sample:
result['gt_score'] = data_sample['gt_score'].clone()
else:
result['gt_score'] = LabelData.label_to_onehot(
gt_label['label'], num_classes)
result['gt_score'] = label_to_onehot(data_sample['gt_label'],
num_classes)
# Save the result to `self.results`.
self.results.append(result)
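The metrics now read the flat ``pred_score``/``gt_score``/``gt_label`` fields and fall back to the ``label_to_onehot`` utility added later in this commit when only index labels are available. A rough sketch of that fallback with illustrative values:

    import torch
    from mmpretrain.structures import DataSample, label_to_onehot

    data_sample = DataSample().set_pred_score(torch.rand(5)).set_gt_label([0, 3])
    pred_score = data_sample.pred_score.clone()
    num_classes = pred_score.size(-1)
    if 'gt_score' in data_sample:
        gt_score = data_sample.gt_score.clone()
    else:
        # index labels [0, 3] become the one-hot vector [1, 0, 0, 1, 0]
        gt_score = label_to_onehot(data_sample.gt_label, num_classes)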
@ -331,8 +329,7 @@ class MultiLabelMetric(BaseMetric):
assert num_classes is not None, 'For index-type labels, ' \
'please specify `num_classes`.'
label = torch.stack([
LabelData.label_to_onehot(
to_tensor(indices), num_classes)
label_to_onehot(indices, num_classes)
for indices in label
])
else:
@ -479,10 +476,10 @@ class AveragePrecision(BaseMetric):
>>> AveragePrecision.calculate(y_pred, y_true)
tensor(70.833)
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_pred_score(i).set_gt_score(j)
... DataSample().set_pred_score(i).set_gt_score(j)
... for i, j in zip(y_pred, y_true)
... ]
>>> evaluator = Evaluator(metrics=AveragePrecision())
@ -517,17 +514,15 @@ class AveragePrecision(BaseMetric):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
result['pred_score'] = pred_label['score']
result['pred_score'] = data_sample['pred_score'].clone()
num_classes = result['pred_score'].size()[-1]
if 'score' in gt_label:
result['gt_score'] = gt_label['score']
if 'gt_score' in data_sample:
result['gt_score'] = data_sample['gt_score'].clone()
else:
result['gt_score'] = LabelData.label_to_onehot(
gt_label['label'], num_classes)
result['gt_score'] = label_to_onehot(data_sample['gt_label'],
num_classes)
# Save the result to `self.results`.
self.results.append(result)

View File

@ -5,10 +5,10 @@ import mmengine
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.structures import LabelData
from mmengine.utils import is_seq_of
from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .single_label import to_tensor
@ -48,10 +48,10 @@ class RetrievalRecall(BaseMetric):
[tensor(9.3000), tensor(48.4000)]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(
... DataSample().set_gt_label([0, 1]).set_pred_score(
... torch.rand(10))
... for i in range(1000)
... ]
@ -95,23 +95,21 @@ class RetrievalRecall(BaseMetric):
predictions (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_label = data_sample['pred_label']
pred_score = data_sample['pred_score'].clone()
gt_label = data_sample['gt_label']
pred = pred_label['score'].clone()
if 'score' in gt_label:
target = gt_label['score'].clone()
if 'gt_score' in data_sample:
target = data_sample.get('gt_score').clone()
else:
num_classes = pred_label['score'].size()[-1]
target = LabelData.label_to_onehot(gt_label['label'],
num_classes)
num_classes = pred_score.size()[-1]
target = label_to_onehot(gt_label, num_classes)
# Because the retrieval output logit vector will be much larger
# compared to the normal classification, to save resources, the
# evaluation results are computed each batch here and then reduce
# all results at the end.
result = RetrievalRecall.calculate(
pred.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
self.results.append(result)
def compute_metrics(self, results: List):
@ -230,5 +228,5 @@ def _format_target(label, is_indices=False):
raise TypeError(f'The pred must be a torch.Tensor, '
f'np.ndarray or Sequence, but got {type(label)}.')
indices = [LabelData.onehot_to_label(sample_gt) for sample_gt in label]
indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
return indices
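The removed ``LabelData.onehot_to_label`` helper is replaced by a plain ``nonzero`` call; a quick illustration of the equivalence:

    import torch

    onehot = torch.tensor([1, 0, 1, 1, 0])
    indices = onehot.nonzero().squeeze(-1)   # tensor([0, 2, 3])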

View File

@ -104,10 +104,10 @@ class Accuracy(BaseMetric):
[[tensor([9.9000])], [tensor([51.5000])]]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
... DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
@ -150,13 +150,11 @@ class Accuracy(BaseMetric):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
if 'pred_score' in data_sample:
result['pred_score'] = data_sample['pred_score'].cpu()
else:
result['pred_label'] = pred_label['label'].cpu()
result['gt_label'] = gt_label['label'].cpu()
result['pred_label'] = data_sample['pred_label'].cpu()
result['gt_label'] = data_sample['gt_label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
@ -358,10 +356,10 @@ class SingleLabelMetric(BaseMetric):
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
@ -418,19 +416,16 @@ class SingleLabelMetric(BaseMetric):
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
if 'pred_score' in data_sample:
result['pred_score'] = data_sample['pred_score'].cpu()
else:
num_classes = self.num_classes or data_sample.get(
'num_classes')
assert num_classes is not None, \
'The `num_classes` must be specified if `pred_label` has '\
'only `label`.'
result['pred_label'] = pred_label['label'].cpu()
'The `num_classes` must be specified if no `pred_score`.'
result['pred_label'] = data_sample['pred_label'].cpu()
result['num_classes'] = num_classes
result['gt_label'] = gt_label['label'].cpu()
result['gt_label'] = data_sample['gt_label'].cpu()
# Save the result to `self.results`.
self.results.append(result)
@ -641,17 +636,16 @@ class ConfusionMatrix(BaseMetric):
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
for data_sample in data_samples:
pred = data_sample['pred_label']
gt_label = data_sample['gt_label']['label']
if 'score' in pred:
pred_label = pred['score'].argmax(dim=0, keepdim=True)
self.num_classes = pred['score'].size(0)
if 'pred_score' in data_sample:
pred_score = data_sample['pred_score']
pred_label = pred_score.argmax(dim=0, keepdim=True)
self.num_classes = pred_score.size(0)
else:
pred_label = pred['label']
pred_label = data_sample['pred_label']
self.results.append({
'pred_label': pred_label,
'gt_label': gt_label
'gt_label': data_sample['gt_label'],
})
def compute_metrics(self, results: list) -> dict:

View File

@ -1,9 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.structures import LabelData
from mmpretrain.registry import METRICS
from mmpretrain.structures import label_to_onehot
from .multi_label import AveragePrecision, MultiLabelMetric
@ -39,18 +38,16 @@ class VOCMetricMixin:
"""
for data_sample in data_samples:
result = dict()
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
gt_label_difficult = data_sample['gt_label_difficult']
result['pred_score'] = pred_label['score'].clone()
result['pred_score'] = data_sample['pred_score'].clone()
num_classes = result['pred_score'].size()[-1]
if 'score' in gt_label:
result['gt_score'] = gt_label['score'].clone()
if 'gt_score' in data_sample:
result['gt_score'] = data_sample['gt_score'].clone()
else:
result['gt_score'] = LabelData.label_to_onehot(
gt_label['label'], num_classes)
result['gt_score'] = label_to_onehot(gt_label, num_classes)
# VOC annotation labels all the objects in a single image
# therefore, some categories appear both in
@ -58,7 +55,7 @@ class VOCMetricMixin:
# Here we reckon those labels which only exist in difficult
# objects as difficult labels.
difficult_label = set(gt_label_difficult) - (
set(gt_label_difficult) & set(gt_label['label'].tolist()))
set(gt_label_difficult) & set(gt_label.tolist()))
# set difficult label for better eval
if self.difficult_as_positive is None:

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .base import BaseClassifier
@ -122,14 +122,14 @@ class HuggingFaceClassifier(BaseClassifier):
raise NotImplementedError(
"The HuggingFaceClassifier doesn't support extract feature yet.")
def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments of the loss module.
@ -144,14 +144,14 @@ class HuggingFaceClassifier(BaseClassifier):
return losses
def _get_loss(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample], **kwargs):
data_samples: List[DataSample], **kwargs):
"""Unpack data samples and compute loss."""
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
if 'gt_score' in data_samples[0]:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
target = torch.stack([i.gt_score for i in data_samples])
else:
target = torch.cat([i.gt_label.label for i in data_samples])
target = torch.cat([i.gt_label for i in data_samples])
# compute loss
losses = dict()
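This target-packing pattern (stack soft ``gt_score`` targets when batch augmentation produced them, otherwise concatenate hard ``gt_label`` indices) recurs in the Timm classifier and the heads below. A small sketch of what it yields, assuming two samples and three classes:

    import torch
    from mmpretrain.structures import DataSample

    # hard labels only -> one concatenated index tensor
    samples = [DataSample().set_gt_label(0), DataSample().set_gt_label(2)]
    target = torch.cat([s.gt_label for s in samples])        # tensor([0, 2])

    # soft scores (e.g. after Mixup) -> a stacked (N, num_classes) matrix
    samples = [DataSample().set_gt_score(torch.tensor([0.7, 0.3, 0.0])),
               DataSample().set_gt_score(torch.tensor([0.0, 0.5, 0.5]))]
    if 'gt_score' in samples[0]:
        target = torch.stack([s.gt_score for s in samples])  # shape (2, 3)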
@ -163,17 +163,17 @@ class HuggingFaceClassifier(BaseClassifier):
def predict(self,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None):
data_samples: Optional[List[DataSample]] = None):
"""Predict results from a batch of inputs.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. Defaults to None.
Returns:
List[ClsDataSample]: The prediction results.
List[DataSample]: The prediction results.
"""
# The part can be traced by torch.fx
cls_score = self.model(inputs).logits
@ -197,8 +197,8 @@ class HuggingFaceClassifier(BaseClassifier):
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))
data_samples.append(
DataSample().set_pred_score(score).set_pred_label(label))
return data_samples

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .base import BaseClassifier
@ -76,7 +76,7 @@ class ImageClassifier(BaseClassifier):
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
@ -85,7 +85,7 @@ class ImageClassifier(BaseClassifier):
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`ClsDataSample`.
processed to a list of :obj:`DataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
@ -95,7 +95,7 @@ class ImageClassifier(BaseClassifier):
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -105,7 +105,7 @@ class ImageClassifier(BaseClassifier):
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of
:obj:`mmpretrain.structures.ClsDataSample`.
:obj:`mmpretrain.structures.DataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
@ -209,13 +209,13 @@ class ImageClassifier(BaseClassifier):
return self.head.pre_logits(x)
def loss(self, inputs: torch.Tensor,
data_samples: List[ClsDataSample]) -> dict:
data_samples: List[DataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
Returns:
@ -226,14 +226,14 @@ class ImageClassifier(BaseClassifier):
def predict(self,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None,
**kwargs) -> List[ClsDataSample]:
data_samples: Optional[List[DataSample]] = None,
**kwargs) -> List[DataSample]:
"""Predict results from a batch of inputs.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. Defaults to None.
**kwargs: Other keyword arguments accepted by the ``predict``
method of :attr:`head`.

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .base import BaseClassifier
@ -110,14 +110,14 @@ class TimmClassifier(BaseClassifier):
f"The model {type(self.model)} doesn't support extract "
"feature because it don't have `forward_features` method.")
def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs):
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments of the loss module.
@ -132,14 +132,14 @@ class TimmClassifier(BaseClassifier):
return losses
def _get_loss(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample], **kwargs):
data_samples: List[DataSample], **kwargs):
"""Unpack data samples and compute loss."""
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
if 'gt_score' in data_samples[0]:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
target = torch.stack([i.gt_score for i in data_samples])
else:
target = torch.cat([i.gt_label.label for i in data_samples])
target = torch.cat([i.gt_label for i in data_samples])
# compute loss
losses = dict()
@ -150,17 +150,17 @@ class TimmClassifier(BaseClassifier):
def predict(self,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None):
data_samples: Optional[List[DataSample]] = None):
"""Predict results from a batch of inputs.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. Defaults to None.
Returns:
List[ClsDataSample]: The prediction results.
List[DataSample]: The prediction results.
"""
# The part can be traced by torch.fx
cls_score = self(inputs)
@ -184,8 +184,8 @@ class TimmClassifier(BaseClassifier):
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))
data_samples.append(
DataSample().set_pred_score(score).set_pred_label(label))
return data_samples

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from mmpretrain.evaluation.metrics import Accuracy
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .base_head import BaseHead
@ -57,8 +57,8 @@ class ClsHead(BaseHead):
# just return the unpacked inputs.
return pre_logits
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
**kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
@ -66,7 +66,7 @@ class ClsHead(BaseHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
@ -81,14 +81,14 @@ class ClsHead(BaseHead):
return losses
def _get_loss(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample], **kwargs):
data_samples: List[DataSample], **kwargs):
"""Unpack data samples and compute loss."""
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
if 'gt_score' in data_samples[0]:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
target = torch.stack([i.gt_score for i in data_samples])
else:
target = torch.cat([i.gt_label.label for i in data_samples])
target = torch.cat([i.gt_label for i in data_samples])
# compute loss
losses = dict()
@ -110,8 +110,8 @@ class ClsHead(BaseHead):
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[Union[ClsDataSample, None]] = None
) -> List[ClsDataSample]:
data_samples: Optional[List[Optional[DataSample]]] = None
) -> List[DataSample]:
"""Inference without augmentation.
Args:
@ -119,12 +119,12 @@ class ClsHead(BaseHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample | None], optional): The annotation
data_samples (List[DataSample | None], optional): The annotation
data of every sample. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[ClsDataSample]: A list of data samples which contains the
List[DataSample]: A list of data samples which contains the
predicted results.
"""
# The part can be traced by torch.fx
@ -149,7 +149,7 @@ class ClsHead(BaseHead):
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = ClsDataSample()
data_sample = DataSample()
data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmpretrain.evaluation.metrics import Accuracy
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .cls_head import ClsHead
@ -64,10 +64,9 @@ class ConformerHead(ClsHead):
return conv_cls_score, tran_cls_score
def predict(
self,
feats: Tuple[List[torch.Tensor]],
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
def predict(self,
feats: Tuple[List[torch.Tensor]],
data_samples: List[DataSample] = None) -> List[DataSample]:
"""Inference without augmentation.
Args:
@ -75,12 +74,12 @@ class ConformerHead(ClsHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[ClsDataSample]: A list of data samples which contains the
List[DataSample]: A list of data samples which contains the
predicted results.
"""
# The part can be traced by torch.fx
@ -92,14 +91,14 @@ class ConformerHead(ClsHead):
return predictions
def _get_loss(self, cls_score: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
data_samples: List[DataSample], **kwargs) -> dict:
"""Unpack data samples and compute loss."""
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
if 'gt_score' in data_samples[0]:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
target = torch.stack([i.gt_score for i in data_samples])
else:
target = torch.cat([i.gt_label.label for i in data_samples])
target = torch.cat([i.gt_label for i in data_samples])
# compute loss
losses = dict()

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .cls_head import ClsHead
@ -65,8 +65,8 @@ class EfficientFormerClsHead(ClsHead):
# after unpacking.
return feats[-1]
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
**kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
@ -74,7 +74,7 @@ class EfficientFormerClsHead(ClsHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.

View File

@ -11,7 +11,7 @@ from mmengine.utils import is_seq_of
from mmpretrain.models.losses import convert_to_one_hot
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from .cls_head import ClsHead
@ -264,8 +264,8 @@ class ArcFaceClsHead(ClsHead):
return self.scale * logit
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
**kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
@ -273,7 +273,7 @@ class ArcFaceClsHead(ClsHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
@ -281,12 +281,11 @@ class ArcFaceClsHead(ClsHead):
dict[str, Tensor]: a dictionary of loss components
"""
# Unpack data samples and pack targets
label_target = torch.cat([i.gt_label.label for i in data_samples])
if 'score' in data_samples[0].gt_label:
label_target = torch.cat([i.gt_label for i in data_samples])
if 'gt_score' in data_samples[0]:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
target = torch.stack([i.gt_score for i in data_samples])
else:
# change the labels to one-hot format scores.
target = label_target
# the index format target would be used

View File

@ -3,10 +3,9 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from mmengine.structures import LabelData
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample, label_to_onehot
from .base_head import BaseHead
@ -65,8 +64,8 @@ class MultiLabelClsHead(BaseHead):
# just return the unpacked inputs.
return pre_logits
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
**kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
@ -74,7 +73,7 @@ class MultiLabelClsHead(BaseHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
@ -89,19 +88,16 @@ class MultiLabelClsHead(BaseHead):
return losses
def _get_loss(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample], **kwargs):
data_samples: List[DataSample], **kwargs):
"""Unpack data samples and compute loss."""
num_classes = cls_score.size()[-1]
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
target = torch.stack(
[i.gt_label.score.float() for i in data_samples])
if 'gt_score' in data_samples[0]:
target = torch.stack([i.gt_score for i in data_samples])
else:
target = torch.stack([
LabelData.label_to_onehot(i.gt_label.label,
num_classes).float()
for i in data_samples
])
label_to_onehot(i.gt_label, num_classes) for i in data_samples
]).float()
# compute loss
losses = dict()
@ -111,10 +107,9 @@ class MultiLabelClsHead(BaseHead):
return losses
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
def predict(self,
feats: Tuple[torch.Tensor],
data_samples: List[DataSample] = None) -> List[DataSample]:
"""Inference without augmentation.
Args:
@ -122,12 +117,12 @@ class MultiLabelClsHead(BaseHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[ClsDataSample]: A list of data samples which contains the
List[DataSample]: A list of data samples which contains the
predicted results.
"""
# The part can be traced by torch.fx
@ -138,7 +133,7 @@ class MultiLabelClsHead(BaseHead):
return predictions
def _get_predictions(self, cls_score: torch.Tensor,
data_samples: List[ClsDataSample]):
data_samples: List[DataSample]):
"""Post-process the output of head.
Including applying sigmoid and setting ``pred_label`` of data samples.
@ -146,7 +141,7 @@ class MultiLabelClsHead(BaseHead):
pred_scores = torch.sigmoid(cls_score)
if data_samples is None:
data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
data_samples = [DataSample() for _ in range(cls_score.size(0))]
for data_sample, score in zip(data_samples, pred_scores):
if self.thr is not None:

View File

@ -63,7 +63,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
- "tensor": Forward the whole network and return tensor without any
post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`ClsDataSample`.
processed to a list of :obj:`DataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
@ -73,7 +73,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
Args:
inputs (torch.Tensor, tuple): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -83,7 +83,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
- If ``mode="tensor"``, return a tensor.
- If ``mode="predict"``, return a list of
:obj:`mmpretrain.structures.ClsDataSample`.
:obj:`mmpretrain.structures.DataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
pass
@ -107,7 +107,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
Returns:

View File

@ -8,7 +8,7 @@ from mmengine.runner import Runner
from torch.utils.data import DataLoader
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .base import BaseRetriever
@ -114,7 +114,7 @@ class ImageToImageRetriever(BaseRetriever):
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
@ -123,7 +123,7 @@ class ImageToImageRetriever(BaseRetriever):
- "tensor": Forward the whole network and return tensor without any
post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`ClsDataSample`.
processed to a list of :obj:`DataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
@ -133,7 +133,7 @@ class ImageToImageRetriever(BaseRetriever):
Args:
inputs (torch.Tensor, tuple): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@ -143,7 +143,7 @@ class ImageToImageRetriever(BaseRetriever):
- If ``mode="tensor"``, return a tensor.
- If ``mode="predict"``, return a list of
:obj:`mmpretrain.structures.ClsDataSample`.
:obj:`mmpretrain.structures.DataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
@ -169,13 +169,13 @@ class ImageToImageRetriever(BaseRetriever):
return feat
def loss(self, inputs: torch.Tensor,
data_samples: List[ClsDataSample]) -> dict:
data_samples: List[DataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
data_samples (List[DataSample]): The annotation data of
every sample.
Returns:
@ -200,18 +200,18 @@ class ImageToImageRetriever(BaseRetriever):
def predict(self,
inputs: tuple,
data_samples: Optional[List[ClsDataSample]] = None,
**kwargs) -> List[ClsDataSample]:
data_samples: Optional[List[DataSample]] = None,
**kwargs) -> List[DataSample]:
"""Predict results from the extracted features.
Args:
inputs (tuple): The features extracted from the backbone.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[DataSample], optional): The annotation
data of every sample. Defaults to None.
**kwargs: Other keyword arguments accepted by the ``predict``
method of :attr:`head`.
Returns:
List[ClsDataSample]: the raw data_samples with
List[DataSample]: the raw data_samples with
the predicted results
"""
if not self.prototype_inited:
@ -240,8 +240,8 @@ class ImageToImageRetriever(BaseRetriever):
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))
data_samples.append(
DataSample().set_pred_score(score).set_pred_label(label))
return data_samples
def _get_prototype_vecs_from_dataloader(self, data_loader):

View File

@ -4,7 +4,7 @@ from typing import List
from mmengine.model import BaseTTAModel
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
@MODELS.register_module()
@ -12,16 +12,16 @@ class AverageClsScoreTTA(BaseTTAModel):
def merge_preds(
self,
data_samples_list: List[List[ClsDataSample]],
) -> List[ClsDataSample]:
data_samples_list: List[List[DataSample]],
) -> List[DataSample]:
"""Merge predictions of enhanced data to one prediction.
Args:
data_samples_list (List[List[ClsDataSample]]): List of predictions
data_samples_list (List[List[DataSample]]): List of predictions
of all enhanced data.
Returns:
List[ClsDataSample]: Merged prediction.
List[DataSample]: Merged prediction.
"""
merged_data_samples = []
for data_samples in data_samples_list:
@ -29,8 +29,8 @@ class AverageClsScoreTTA(BaseTTAModel):
return merged_data_samples
def _merge_single_sample(self, data_samples):
merged_data_sample: ClsDataSample = data_samples[0].new()
merged_score = sum(data_sample.pred_label.score
merged_data_sample: DataSample = data_samples[0].new()
merged_score = sum(data_sample.pred_score
for data_sample in data_samples) / len(data_samples)
merged_data_sample.set_pred_score(merged_score)
return merged_data_sample
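Test-time augmentation merging now averages the flat ``pred_score`` tensors directly; a toy example with two augmented views and invented scores:

    import torch
    from mmpretrain.structures import DataSample

    views = [DataSample().set_pred_score(torch.tensor([0.9, 0.1])),
             DataSample().set_pred_score(torch.tensor([0.7, 0.3]))]
    merged_score = sum(v.pred_score for v in views) / len(views)   # tensor([0.8, 0.2])
    merged = views[0].new()    # start from a copy of the first view, as above
    merged.set_pred_score(merged_score)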

View File

@ -8,9 +8,9 @@ import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmpretrain.registry import MODELS
from mmpretrain.structures import (ClsDataSample, MultiTaskDataSample,
from mmpretrain.structures import (DataSample, MultiTaskDataSample,
batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
tensor_split)
from .batch_augments import RandomBatchAugment
@ -153,23 +153,28 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
data_samples = data.get('data_samples', None)
sample_item = data_samples[0] if data_samples is not None else None
if isinstance(sample_item,
ClsDataSample) and 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(
gt_labels, device=self.device)
batch_score = stack_batch_scores(gt_labels, device=self.device)
if batch_score is None and self.to_onehot:
if isinstance(sample_item, DataSample):
batch_label = None
batch_score = None
if 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(gt_labels)
batch_label = batch_label.to(self.device)
if 'gt_score' in sample_item:
gt_scores = [sample.gt_score for sample in data_samples]
batch_score = torch.stack(gt_scores).to(self.device)
elif self.to_onehot:
assert batch_label is not None, \
'Cannot generate onehot format labels because no labels are provided.'
num_classes = self.num_classes or data_samples[0].get(
num_classes = self.num_classes or sample_item.get(
'num_classes')
assert num_classes is not None, \
'Cannot generate one-hot format labels because not set ' \
'`num_classes` in `data_preprocessor`.'
batch_score = batch_label_to_onehot(batch_label, label_indices,
num_classes)
batch_score = batch_label_to_onehot(
batch_label, label_indices, num_classes).to(self.device)
# ----- Batch Augmentations ----
if training and self.batch_augments is not None:
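With the nested LabelData gone, the preprocessor batches plain tensors. A sketch of the label-batching path using the utilities exported from mmpretrain.structures, with made-up labels and class count:

    import torch
    from mmpretrain.structures import (DataSample, batch_label_to_onehot,
                                       cat_batch_labels)

    samples = [DataSample().set_gt_label([0, 2]), DataSample().set_gt_label(1)]
    gt_labels = [s.gt_label for s in samples]
    batch_label, label_indices = cat_batch_labels(gt_labels)
    # batch_label: tensor([0, 2, 1]); label_indices: [2], the split point between samples
    batch_score = batch_label_to_onehot(batch_label, label_indices, num_classes=3)
    # one row per sample: tensor([[1, 0, 1], [0, 1, 0]])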

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_data_sample import ClsDataSample
from .data_sample import DataSample
from .multi_task_data_sample import MultiTaskDataSample
from .utils import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
from .utils import (batch_label_to_onehot, cat_batch_labels, label_to_onehot,
tensor_split)
__all__ = [
'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample'
'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split',
'MultiTaskDataSample', 'label_to_onehot'
]

View File

@ -1,235 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from numbers import Number
from typing import Sequence, Union
import numpy as np
import torch
from mmengine.structures import BaseDataElement, LabelData
from mmengine.utils import is_str
def format_label(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
Returns:
:obj:`torch.Tensor`: The formatted label tensor.
"""
# Handle single number
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
value = int(value.item())
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
def format_score(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The formatted score tensor.
"""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
class ClsDataSample(BaseDataElement):
"""A data structure interface of classification task.
It's used as interfaces between different components.
Meta fields:
img_shape (Tuple): The shape of the corresponding input image.
Used for visualization.
ori_shape (Tuple): The original shape of the corresponding image.
Used for visualization.
num_classes (int): The number of all categories.
Used for label format conversion.
Data fields:
gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth
label.
pred_label (:obj:`~mmengine.structures.LabelData`): The predicted
label.
scores (torch.Tensor): The outputs of model.
logits (torch.Tensor): The outputs of model without softmax nor
sigmoid.
Examples:
>>> import torch
>>> from mmpretrain.structures import ClsDataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = ClsDataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<ClsDataSample(
META INFORMATION
num_classes = 5
img_shape = (960, 720)
DATA FIELDS
gt_label: <LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
label: tensor([3])
) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # For multi-label data
>>> data_sample.set_gt_label([0, 1, 4])
>>> print(data_sample.gt_label)
<LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
label: tensor([0, 1, 4])
) at 0x7fd7d1b41970>
>>> # Set one-hot format score
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970>
"""
def set_gt_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``gt_label``."""
label_data = getattr(self, '_gt_label', LabelData())
label_data.label = format_label(value)
self.gt_label = label_data
return self
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``gt_label``."""
label_data = getattr(self, '_gt_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.gt_label = label_data
return self
def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.label = format_label(value)
self.pred_label = label_data
return self
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.pred_label = label_data
return self
@property
def gt_label(self):
return self._gt_label
@gt_label.setter
def gt_label(self, value: LabelData):
self.set_field(value, '_gt_label', dtype=LabelData)
@gt_label.deleter
def gt_label(self):
del self._gt_label
@property
def pred_label(self):
return self._pred_label
@pred_label.setter
def pred_label(self, value: LabelData):
self.set_field(value, '_pred_label', dtype=LabelData)
@pred_label.deleter
def pred_label(self):
del self._pred_label
def _reduce_cls_datasample(data_sample):
"""reduce ClsDataSample."""
attr_dict = data_sample.__dict__
convert_keys = []
for k, v in attr_dict.items():
if isinstance(v, LabelData):
attr_dict[k] = v.numpy()
convert_keys.append(k)
return _rebuild_cls_datasample, (attr_dict, convert_keys)
def _rebuild_cls_datasample(attr_dict, convert_keys):
"""rebuild ClsDataSample."""
data_sample = ClsDataSample()
for k in convert_keys:
attr_dict[k] = attr_dict[k].to_tensor()
data_sample.__dict__ = attr_dict
return data_sample
# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
# many file descriptors because it contains multiple LabelData with tensors.
# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
# convert these tensors to np.ndarray during pickling. It may influence the
# performance of dataloader, but slightly because these tensors in LabelData
# are very small.
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)

View File

@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from typing import Union
import numpy as np
import torch
from mmengine.structures import BaseDataElement
from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score
class DataSample(BaseDataElement):
"""A general data structure interface.
It's used as the interface between different components.
The following fields are convention names in MMPretrain, and we will set or
get these fields in data transforms, models, and metrics if needed. You can
also set any new fields as needed.
Meta fields:
img_shape (Tuple): The shape of the corresponding input image.
ori_shape (Tuple): The original shape of the corresponding image.
sample_idx (int): The index of the sample in the dataset.
num_classes (int): The number of all categories.
Data fields:
gt_label (tensor): The ground truth label.
gt_score (tensor): The ground truth score.
pred_label (tensor): The predicted label.
pred_score (tensor): The predicted score.
mask (tensor): The mask used in masked image modeling.
Examples:
>>> import torch
>>> from mmpretrain.structures import DataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = DataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<DataSample(
META INFORMATION
num_classes: 5
img_shape: (960, 720)
DATA FIELDS
gt_label: tensor([3])
) at 0x7ff64c1c1d30>
>>>
>>> # For multi-label data
>>> data_sample = DataSample().set_gt_label([0, 1, 4])
>>> print(data_sample)
<DataSample(
DATA FIELDS
gt_label: tensor([0, 1, 4])
) at 0x7ff5b490e100>
>>>
>>> # Set one-hot format score
>>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
>>> print(data_sample)
<DataSample(
META INFORMATION
num_classes: 4
DATA FIELDS
pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
) at 0x7ff5b48ef6a0>
>>>
>>> # Set custom field
>>> data_sample = DataSample()
>>> data_sample.my_field = [1, 2, 3]
>>> print(data_sample)
<DataSample(
DATA FIELDS
my_field: [1, 2, 3]
) at 0x7f8e9603d3a0>
>>> print(data_sample.my_field)
[1, 2, 3]
"""
def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
"""Set ``gt_label``."""
self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
return self
def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
"""Set ``gt_score``."""
score = format_score(value)
self.set_field(score, 'gt_score', dtype=torch.Tensor)
if hasattr(self, 'num_classes'):
assert len(score) == self.num_classes, \
f'The length of score {len(score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes', value=len(score), field_type='metainfo')
return self
def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
"""Set ``pred_label``."""
self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
return self
def set_pred_score(self, value: SCORE_TYPE):
"""Set ``pred_label``."""
score = format_score(value)
self.set_field(score, 'pred_score', dtype=torch.Tensor)
if hasattr(self, 'num_classes'):
assert len(score) == self.num_classes, \
f'The length of score {len(score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes', value=len(score), field_type='metainfo')
return self
def set_mask(self, value: Union[torch.Tensor, np.ndarray]):
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Invalid mask type {type(value)}')
self.set_field(value, 'mask', dtype=torch.Tensor)
return self
def __repr__(self) -> str:
"""Represent the object."""
def dump_items(items, prefix=''):
return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)
repr_ = ''
if len(self._metainfo_fields) > 0:
repr_ += '\n\nMETA INFORMATION\n'
repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
if len(self._data_fields) > 0:
repr_ += '\n\nDATA FIELDS\n'
repr_ += dump_items(self.items(), prefix=' ' * 4)
repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
return repr_
def _reduce_datasample(data_sample):
"""reduce DataSample."""
attr_dict = data_sample.__dict__
convert_keys = []
for k, v in attr_dict.items():
if isinstance(v, torch.Tensor):
attr_dict[k] = v.numpy()
convert_keys.append(k)
return _rebuild_datasample, (attr_dict, convert_keys)
def _rebuild_datasample(attr_dict, convert_keys):
"""rebuild DataSample."""
data_sample = DataSample()
for k in convert_keys:
attr_dict[k] = torch.from_numpy(attr_dict[k])
data_sample.__dict__ = attr_dict
return data_sample
# Due to the multi-processing strategy of PyTorch, DataSample may consume many
# file descriptors because it contains multiple tensors. Here we overwrite the
# reduce function of DataSample in ForkingPickler and convert these tensors to
# np.ndarray during pickling. It may slightly influence the performance of
# dataloader.
ForkingPickler.register(DataSample, _reduce_datasample)

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from typing import List, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.structures import LabelData
from mmengine.utils import is_str
if hasattr(torch, 'tensor_split'):
tensor_split = torch.tensor_split
@ -16,30 +17,82 @@ else:
return outs
def cat_batch_labels(elements: List[LabelData], device=None):
"""Concat the ``label`` of a batch of :obj:`LabelData` to a tensor.
LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int]
SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
def format_label(value: LABEL_TYPE) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args:
elements (List[LabelData]): A batch of :obj`LabelData`.
device (torch.device, optional): The output device of the batch label.
Defaults to None.
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
Returns:
:obj:`torch.Tensor`: The formatted label tensor.
"""
# Handle single number
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
value = int(value.item())
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
def format_score(value: SCORE_TYPE) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The formatted score tensor.
"""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
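# Illustrative examples for format_score (a sketch, not part of the commit):
# >>> format_score([0.1, 0.7, 0.2])
# tensor([0.1000, 0.7000, 0.2000])
# >>> format_score(np.array([0.3, 0.7]))
# tensor([0.3000, 0.7000])
# >>> format_score(torch.ones(2, 3))  # raises AssertionError: scores must be a 1-D tensor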
def cat_batch_labels(elements: List[torch.Tensor]):
"""Concat a batch of label tensor to one tensor.
Args:
elements (List[tensor]): A batch of labels.
Returns:
Tuple[torch.Tensor, List[int]]: The first item is the concatenated label
tensor, and the second item is the split indices of every sample.
"""
item = elements[0]
if 'label' not in item._data_fields:
return None, None
labels = []
splits = [0]
for element in elements:
labels.append(element.label)
splits.append(splits[-1] + element.label.size(0))
labels.append(element)
splits.append(splits[-1] + element.size(0))
batch_label = torch.cat(labels)
if device is not None:
batch_label = batch_label.to(device=device)
return batch_label, splits[1:-1]
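# Illustrative example for the new cat_batch_labels (a sketch, not part of the commit);
# the returned split indices can be fed to tensor_split to recover per-sample labels.
# >>> labels = [torch.tensor([1]), torch.tensor([3, 2]), torch.tensor([0, 1, 4])]
# >>> batch_label, splits = cat_batch_labels(labels)
# >>> batch_label
# tensor([1, 3, 2, 0, 1, 4])
# >>> splits
# [1, 3]
# >>> [t.tolist() for t in tensor_split(batch_label, splits)]
# [[1], [3, 2], [0, 1, 4]]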
@ -75,22 +128,26 @@ def batch_label_to_onehot(batch_label, split_indices, num_classes):
return torch.stack(onehot_list)
def stack_batch_scores(elements, device=None):
"""Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.
def label_to_onehot(label: LABEL_TYPE, num_classes: int):
"""Convert a label to onehot format tensor.
Args:
elements (List[LabelData]): A batch of :obj`LabelData`.
device (torch.device, optional): The output device of the batch label.
Defaults to None.
label (LABEL_TYPE): Label value.
num_classes (int): The number of classes.
Returns:
torch.Tensor: The stacked score tensor.
"""
item = elements[0]
if 'score' not in item._data_fields:
return None
torch.Tensor: The onehot format label tensor.
batch_score = torch.stack([element.score for element in elements])
if device is not None:
batch_score = batch_score.to(device)
return batch_score
Examples:
>>> import torch
>>> from mmpretrain.structures import label_to_onehot
>>> # Single-label
>>> label_to_onehot(1, num_classes=5)
tensor([0, 1, 0, 0, 0])
>>> # Multi-label
>>> label_to_onehot([0, 2, 3], num_classes=5)
tensor([1, 0, 1, 1, 0])
"""
label = format_label(label)
sparse_onehot = F.one_hot(label, num_classes)
return sparse_onehot.sum(0)

View File

@ -7,7 +7,7 @@ from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
def _get_adaptive_scale(img_shape: Tuple[int, int],
@ -57,11 +57,11 @@ class ClsVisualizer(Visualizer):
>>> import mmcv
>>> from pathlib import Path
>>> from mmpretrain.visualization import ClsVisualizer
>>> from mmpretrain.structures import ClsDataSample
>>> from mmpretrain.structures import DataSample
>>> # Example image
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
>>> # Example annotation
>>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
>>> data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
>>> # Setup the visualizer
>>> vis = ClsVisualizer(
@ -84,7 +84,7 @@ class ClsVisualizer(Visualizer):
def add_datasample(self,
name: str,
image: np.ndarray,
data_sample: Optional[ClsDataSample] = None,
data_sample: Optional[DataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
draw_score: bool = True,
@ -104,7 +104,7 @@ class ClsVisualizer(Visualizer):
Args:
name (str): The image identifier.
image (np.ndarray): The image to draw.
data_sample (:obj:`ClsDataSample`, optional): The annotation of the
data_sample (:obj:`DataSample`, optional): The annotation of the
image. Defaults to None.
draw_gt (bool): Whether to draw ground truth labels.
Defaults to True.
@ -137,8 +137,7 @@ class ClsVisualizer(Visualizer):
self.set_image(image)
if draw_gt and 'gt_label' in data_sample:
gt_label = data_sample.gt_label
idx = gt_label.label.tolist()
idx = data_sample.gt_label.tolist()
class_labels = [''] * len(idx)
if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx]
@ -147,13 +146,12 @@ class ClsVisualizer(Visualizer):
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
if draw_pred and 'pred_label' in data_sample:
pred_label = data_sample.pred_label
idx = pred_label.label.tolist()
idx = data_sample.pred_label.tolist()
score_labels = [''] * len(idx)
class_labels = [''] * len(idx)
if draw_score and 'score' in pred_label:
if draw_score and 'pred_score' in data_sample:
score_labels = [
f', {pred_label.score[i].item():.2f}' for i in idx
f', {data_sample.pred_score[i].item():.2f}' for i in idx
]
if classes is not None:

View File

@ -9,7 +9,7 @@ from mmcv.image import imread
from mmpretrain.apis import (ImageClassificationInferencer, ModelHub,
get_model, inference_model)
from mmpretrain.models import MobileNetV3
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
MODEL = 'mobilenet-v3-small-050_3rdparty_in1k'
@ -58,7 +58,7 @@ class TestImageClassificationInferencer(TestCase):
# test return_datasample=True
results = inferencer(img, return_datasamples=True)[0]
self.assertIsInstance(results, ClsDataSample)
self.assertIsInstance(results, DataSample)
def test_visualize(self):
img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')

View File

@ -6,11 +6,10 @@ import unittest
import mmcv
import numpy as np
import torch
from mmengine.structures import LabelData
from PIL import Image
from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
class TestPackClsInputs(unittest.TestCase):
@ -34,9 +33,9 @@ class TestPackClsInputs(unittest.TestCase):
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertIn('data_samples', results)
self.assertIsInstance(results['data_samples'], ClsDataSample)
self.assertIsInstance(results['data_samples'], DataSample)
self.assertIn('flip', results['data_samples'].metainfo_keys())
self.assertIsInstance(results['data_samples'].gt_label, LabelData)
self.assertIsInstance(results['data_samples'].gt_label, torch.Tensor)
# Test grayscale image
data['img'] = data['img'].mean(-1)
@ -155,7 +154,7 @@ class TestPackMultiTaskInputs(unittest.TestCase):
self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
self.assertIsInstance(results['data_samples'].task1.gt_label,
LabelData)
torch.Tensor)
# Test grayscale image
data['img'] = data['img'].mean(-1)

View File

@ -14,7 +14,7 @@ from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset
from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class ExampleDataset(Dataset):
@ -36,7 +36,7 @@ class MockDataPreprocessor(BaseDataPreprocessor):
def forward(self, data, training):
return data['imgs'], ClsDataSample()
return data['imgs'], DataSample()
class ExampleModel(BaseModel):

View File

@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
from mmpretrain.engine import VisualizationHook
from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
@ -18,7 +18,7 @@ class TestVisualizationHook(TestCase):
def setUp(self) -> None:
ClsVisualizer.get_instance('visualizer')
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
data_sample = DataSample().set_gt_label(1).set_pred_label(2)
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
self.data_batch = {
'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
@ -53,7 +53,7 @@ class TestVisualizationHook(TestCase):
cfg = dict(type='VisualizationHook', enable=True)
hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock:
outputs = [ClsDataSample()] * 10
outputs = [DataSample()] * 10
hook._draw_samples(0, self.data_batch, outputs, step=0)
mock.assert_called_once_with(
'0', image=ANY, data_sample=outputs[0], step=0, show=False)

View File

@ -8,7 +8,7 @@ from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope
from mmpretrain.evaluation.metrics import AveragePrecision, MultiLabelMetric
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
init_default_scope('mmpretrain')
@ -152,7 +152,7 @@ class TestMultiLabel(TestCase):
])
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred_score, y_true)
]
@ -261,7 +261,7 @@ class TestMultiLabel(TestCase):
# Test with gt_score
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred_score, y_true_binary)
]
@ -304,7 +304,7 @@ class TestAveragePrecision(TestCase):
])
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred, y_true)
]
@ -328,7 +328,7 @@ class TestAveragePrecision(TestCase):
# Test with gt_label without score
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
]
evaluator = Evaluator(dict(type='AveragePrecision'))

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import torch
from mmpretrain.evaluation.metrics import MultiTasksMetric
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class MultiTaskMetric(TestCase):
@ -24,7 +24,7 @@ class MultiTaskMetric(TestCase):
for i, pred in enumerate(data_pred):
sample = {}
for task_name in pred:
task_sample = ClsDataSample().set_pred_score(pred[task_name])
task_sample = DataSample().set_pred_score(pred[task_name])
if task_name in data_gt[i]:
task_sample.set_gt_label(data_gt[i][task_name])
task_sample.set_field(True, 'eval_mask', field_type='metainfo')
@ -68,7 +68,7 @@ class MultiTaskMetric(TestCase):
sample = {}
for task_name in score:
if type(score[task_name]) != dict:
task_sample = ClsDataSample().set_pred_score(score[task_name])
task_sample = DataSample().set_pred_score(score[task_name])
task_sample.set_gt_label(label[task_name])
sample[task_name] = task_sample.to_dict()
sample[task_name]['eval_mask'] = True
@ -76,7 +76,7 @@ class MultiTaskMetric(TestCase):
sample[task_name] = {}
sample[task_name]['eval_mask'] = True
for task_name2 in score[task_name]:
task_sample = ClsDataSample().set_pred_score(
task_sample = DataSample().set_pred_score(
score[task_name][task_name2])
task_sample.set_gt_label(label[task_name][task_name2])
sample[task_name][task_name2] = task_sample.to_dict()

View File

@ -6,7 +6,7 @@ import torch
from mmpretrain.evaluation.metrics import RetrievalRecall
from mmpretrain.registry import METRICS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class TestRetrievalRecall(TestCase):
@ -14,7 +14,7 @@ class TestRetrievalRecall(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict()
DataSample().set_pred_score(i).set_gt_label(k).to_dict()
for i, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),

View File

@ -8,7 +8,7 @@ import torch
from mmpretrain.evaluation.metrics import (Accuracy, ConfusionMatrix,
SingleLabelMetric)
from mmpretrain.registry import METRICS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class TestAccuracy(TestCase):
@ -16,7 +16,7 @@ class TestAccuracy(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
@ -52,7 +52,7 @@ class TestAccuracy(TestCase):
# Test with label
for sample in pred:
del sample['pred_label']['score']
del sample['pred_score']
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
metric.process(None, pred)
acc = metric.evaluate(6)
@ -123,7 +123,7 @@ class TestSingleLabel(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
@ -212,7 +212,7 @@ class TestSingleLabel(TestCase):
# Test with label, the thrs will be ignored
pred_no_score = copy.deepcopy(pred)
for sample in pred_no_score:
del sample['pred_label']['score']
del sample['pred_score']
del sample['num_classes']
metric = METRICS.build(
dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3))
@ -304,7 +304,7 @@ class TestConfusionMatrix(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
@ -330,7 +330,7 @@ class TestConfusionMatrix(TestCase):
# Test with label
for sample in pred:
del sample['pred_label']['score']
del sample['pred_score']
metric = METRICS.build(dict(type='ConfusionMatrix'))
metric.process(None, pred)
with self.assertRaisesRegex(AssertionError,

View File

@ -7,7 +7,7 @@ import torch
from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
init_default_scope('mmpretrain')
@ -27,7 +27,7 @@ class TestVOCMultiLabel(TestCase):
# generate data samples
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred_score, y_true_label)
]
for sample, difficult_label in zip(pred, y_true_difficult):
@ -155,7 +155,7 @@ class TestVOCAveragePrecision(TestCase):
# generate data samples
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(
DataSample(num_classes=4).set_pred_score(i).set_gt_score(
j).set_gt_label(k)
for i, j, k in zip(y_pred_score, y_true, y_true_label)
]

View File

@ -9,7 +9,7 @@ from mmengine import ConfigDict
from mmpretrain.models import ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
def has_timm() -> bool:
@ -134,7 +134,7 @@ class TestImageClassifier(TestCase):
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
@ -142,21 +142,21 @@ class TestImageClassifier(TestCase):
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
@ -169,13 +169,13 @@ class TestImageClassifier(TestCase):
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -190,7 +190,7 @@ class TestImageClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
@ -207,11 +207,11 @@ class TestImageClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self):
cfg = {
@ -222,11 +222,11 @@ class TestImageClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
@unittest.skipIf(not has_timm(), 'timm is not installed.')
@ -255,7 +255,7 @@ class TestTimmClassifier(TestCase):
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
@ -263,21 +263,21 @@ class TestTimmClassifier(TestCase):
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
@ -290,13 +290,13 @@ class TestTimmClassifier(TestCase):
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -311,7 +311,7 @@ class TestTimmClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
@ -328,11 +328,11 @@ class TestTimmClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self):
cfg = {
@ -343,11 +343,11 @@ class TestTimmClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
@unittest.skipIf(not has_huggingface(), 'huggingface is not installed.')
@ -376,7 +376,7 @@ class TestHuggingFaceClassifier(TestCase):
def test_loss(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
@ -384,21 +384,21 @@ class TestHuggingFaceClassifier(TestCase):
def test_predict(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
def test_forward(self):
inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
@ -411,13 +411,13 @@ class TestHuggingFaceClassifier(TestCase):
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -432,7 +432,7 @@ class TestHuggingFaceClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
@ -449,11 +449,11 @@ class TestHuggingFaceClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self):
cfg = {
@ -464,8 +464,8 @@ class TestHuggingFaceClassifier(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
self.assertEqual(predictions[0].pred_score.shape, (1000, ))

View File

@ -10,7 +10,7 @@ import torch
from mmengine import is_seq_of
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
def setup_seed(seed):
@ -43,7 +43,7 @@ class TestClsHead(TestCase):
def test_loss(self):
feats = self.FAKE_FEATS
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
@ -75,23 +75,23 @@ class TestClsHead(TestCase):
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
# with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class TestLinearClsHead(TestCase):
@ -224,7 +224,7 @@ class TestConformerHead(TestCase):
self.assertEqual(outs[1].shape, (4, 5))
def test_loss(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
@ -255,23 +255,23 @@ class TestConformerHead(TestCase):
head.loss(self.fake_feats, data_samples)
def test_predict(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(self.fake_feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
# with data_samples
predictions = head.predict(self.fake_feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class TestStackedLinearClsHead(TestCase):
@ -338,7 +338,7 @@ class TestMultiLabelClsHead(TestCase):
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
data_samples = [DataSample().set_gt_label([0, 3]) for _ in range(4)]
# Test with thr and topk are all None
head = MODELS.build(self.DEFAULT_ARGS)
@ -383,7 +383,7 @@ class TestMultiLabelClsHead(TestCase):
# Test with gt_label with score
data_samples = [
ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
DataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
]
head = MODELS.build(self.DEFAULT_ARGS)
@ -395,23 +395,23 @@ class TestMultiLabelClsHead(TestCase):
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
data_samples = [DataSample().set_gt_label([1, 2]) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions:
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
# with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
# Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS)
@ -419,11 +419,11 @@ class TestMultiLabelClsHead(TestCase):
head = MODELS.build(cfg)
predictions = head.predict(feats, data_samples)
self.assertEqual(head.thr, None)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
self.assertIn('pred_label', pred)
self.assertIn('pred_score', pred)
class EfficientFormerClsHead(TestClsHead):
@ -454,7 +454,7 @@ class EfficientFormerClsHead(TestClsHead):
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS)
@ -525,7 +525,7 @@ class TestMultiTaskHead(TestCase):
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1)
task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
# with cal_acc = False
@ -545,7 +545,7 @@ class TestMultiTaskHead(TestCase):
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1)
task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
@ -555,7 +555,7 @@ class TestMultiTaskHead(TestCase):
for pred in predictions:
self.assertIn('task0', pred)
task0_sample = predictions[0].task0
self.assertTrue(type(task0_sample.pred_label.score), 'torch.tensor')
self.assertTrue(type(task0_sample.pred_score), 'torch.tensor')
# with with data_samples
predictions = head.predict(feats, data_samples)
@ -596,7 +596,7 @@ class TestMultiTaskHead(TestCase):
head = MODELS.build(self.DEFAULT_ARGS2)
data_sample = MultiTaskDataSample()
for task_name in gt_label:
task_sample = ClsDataSample().set_gt_label(gt_label[task_name])
task_sample = DataSample().set_gt_label(gt_label[task_name])
data_sample.set_field(task_sample, task_name)
with self.assertRaises(Exception):
head.loss(feats, data_sample)
@ -606,11 +606,11 @@ class TestMultiTaskHead(TestCase):
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS)
data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
task_sample = DataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label(
task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name)
with self.assertRaises(Exception):
@ -694,7 +694,7 @@ class TestArcFaceClsHead(TestCase):
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# test loss with used='before'
head = MODELS.build(self.DEFAULT_ARGS)

View File

@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset
from mmpretrain.datasets.transforms import PackClsInputs
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class ExampleDataset(Dataset):
@ -125,7 +125,7 @@ class TestImageToImageRetriever(TestCase):
def test_loss(self):
inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples)
@ -172,32 +172,32 @@ class TestImageToImageRetriever(TestCase):
def test_predict(self):
inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label([1, 2, 6])]
data_samples = [DataSample().set_gt_label([1, 2, 6])]
# default
model = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# k is not -1
cfg = {**self.DEFAULT_ARGS, 'topk': 2}
model = MODELS.build(cfg)
predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples)
assert predictions is data_samples
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
def test_forward(self):
inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label(1)]
data_samples = [DataSample().set_gt_label(1)]
model = MODELS.build(self.DEFAULT_ARGS)
# test pure forward
@ -213,13 +213,13 @@ class TestImageToImageRetriever(TestCase):
# test forward test
predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score,
predictions[0].pred_label.score)
self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_score)
# test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -234,7 +234,7 @@ class TestImageToImageRetriever(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
@ -251,11 +251,11 @@ class TestImageToImageRetriever(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self):
cfg = {
@ -266,8 +266,8 @@ class TestImageToImageRetriever(TestCase):
data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
self.assertEqual(predictions[0].pred_score.shape, (10, ))

View File

@ -8,7 +8,7 @@ from mmengine.registry import init_default_scope
from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
init_default_scope('mmpretrain')
@ -48,20 +48,20 @@ class TestAverageClsScoreTTA(TestCase):
img2 = torch.randint(0, 256, (1, 3, 224, 224))
data1 = {
'inputs': img1,
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
data2 = {
'inputs': img2,
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
data_tta = {
'inputs': [img1, img2],
'data_samples': [[ClsDataSample().set_gt_label(1)],
[ClsDataSample().set_gt_label(1)]]
'data_samples': [[DataSample().set_gt_label(1)],
[DataSample().set_gt_label(1)]]
}
score1 = model.module.test_step(data1)[0].pred_label.score
score2 = model.module.test_step(data2)[0].pred_label.score
score_tta = model.test_step(data_tta)[0].pred_label.score
score1 = model.module.test_step(data1)[0].pred_score
score2 = model.module.test_step(data2)[0].pred_score
score_tta = model.test_step(data_tta)[0].pred_score
torch.testing.assert_allclose(score_tta, (score1 + score2) / 2)

View File

@ -5,7 +5,7 @@ import torch
from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment
from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
class TestClsDataPreprocessor(TestCase):
@ -16,15 +16,14 @@ class TestClsDataPreprocessor(TestCase):
data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
processed_data = processor(data)
inputs = processed_data['inputs']
data_samples = processed_data['data_samples']
self.assertEqual(inputs.shape, (1, 3, 224, 224))
self.assertEqual(len(data_samples), 1)
self.assertTrue(
(data_samples[0].gt_label.label == torch.tensor([1])).all())
self.assertTrue((data_samples[0].gt_label == torch.tensor([1])).all())
def test_padding(self):
cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16)
@ -87,7 +86,7 @@ class TestClsDataPreprocessor(TestCase):
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)]
'data_samples': [DataSample().set_gt_label(1)]
}
processed_data = processor(data, training=True)
self.assertIn('inputs', processed_data)

View File

@ -3,58 +3,51 @@ from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import LabelData
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
from mmpretrain.structures import DataSample, MultiTaskDataSample
class TestClsDataSample(TestCase):
class TestDataSample(TestCase):
def _test_set_label(self, key):
data_sample = ClsDataSample()
data_sample = DataSample()
method = getattr(data_sample, 'set_' + key)
# Test number
method(1)
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.label, torch.LongTensor)
self.assertIsInstance(label, torch.LongTensor)
# Test tensor with single number
method(torch.tensor(2))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.label, torch.LongTensor)
self.assertIsInstance(label, torch.LongTensor)
# Test array with single number
method(np.array(3))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.label, torch.LongTensor)
self.assertIsInstance(label, torch.LongTensor)
# Test tensor
method(torch.tensor([1, 2, 3]))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertIsInstance(label.label, torch.Tensor)
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
self.assertIsInstance(label, torch.Tensor)
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
# Test array
method(np.array([1, 2, 3]))
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
# Test Sequence
method([1, 2, 3])
self.assertIn(key, data_sample)
label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData)
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
# Test unavailable type
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
@ -66,34 +59,13 @@ class TestClsDataSample(TestCase):
def test_set_pred_label(self):
self._test_set_label('pred_label')
def test_del_gt_label(self):
data_sample = ClsDataSample()
self.assertNotIn('gt_label', data_sample)
data_sample.set_gt_label(1)
self.assertIn('gt_label', data_sample)
del data_sample.gt_label
self.assertNotIn('gt_label', data_sample)
def test_del_pred_label(self):
data_sample = ClsDataSample()
self.assertNotIn('pred_label', data_sample)
data_sample.set_pred_label(1)
self.assertIn('pred_label', data_sample)
del data_sample.pred_label
self.assertNotIn('pred_label', data_sample)
def test_set_gt_score(self):
data_sample = ClsDataSample()
data_sample = DataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.gt_label)
torch.testing.assert_allclose(data_sample.gt_label.score,
self.assertIn('gt_score', data_sample)
torch.testing.assert_allclose(data_sample.gt_score,
[0.1, 0.1, 0.6, 0.1, 0.1])
# Test set again
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2])
@ -103,17 +75,12 @@ class TestClsDataSample(TestCase):
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
def test_set_pred_score(self):
data_sample = ClsDataSample()
data_sample = DataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.pred_label)
torch.testing.assert_allclose(data_sample.pred_label.score,
self.assertIn('pred_score', data_sample)
torch.testing.assert_allclose(data_sample.pred_score,
[0.1, 0.1, 0.6, 0.1, 0.1])
# Test set again
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2])
@ -129,13 +96,13 @@ class TestMultiTaskDataSample(TestCase):
def test_multi_task_data_sample(self):
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
task_sample = DataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label(
task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name)
self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
self.assertIsInstance(data_sample.task1, ClsDataSample)
self.assertIsInstance(data_sample.task0.task00, ClsDataSample)
self.assertIsInstance(data_sample.task1, DataSample)
self.assertIsInstance(data_sample.task0.task00, DataSample)

View File

@ -2,10 +2,9 @@
from unittest import TestCase
import torch
from mmengine.structures import LabelData
from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
tensor_split)
class TestStructureUtils(TestCase):
@ -28,11 +27,11 @@ class TestStructureUtils(TestCase):
def test_cat_batch_labels(self):
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
torch.tensor([1]),
torch.tensor([3, 2]),
torch.tensor([0, 1, 4]),
torch.tensor([], dtype=torch.int64),
torch.tensor([], dtype=torch.int64),
]
batch_label, split_indices = cat_batch_labels(labels)
@ -45,42 +44,13 @@ class TestStructureUtils(TestCase):
self.assertEqual(labels[3].tolist(), [])
self.assertEqual(labels[4].tolist(), [])
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_label, split_indices = cat_batch_labels(labels)
self.assertIsNone(batch_label)
self.assertIsNone(split_indices)
def test_stack_batch_scores(self):
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_score = stack_batch_scores(labels)
self.assertEqual(batch_score.shape, (3, 5))
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
]
batch_score = stack_batch_scores(labels)
self.assertIsNone(batch_score)
def test_batch_label_to_onehot(self):
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
torch.tensor([1]),
torch.tensor([3, 2]),
torch.tensor([0, 1, 4]),
torch.tensor([], dtype=torch.int64),
torch.tensor([], dtype=torch.int64),
]
batch_label, split_indices = cat_batch_labels(labels)

View File

@ -7,7 +7,7 @@ from unittest.mock import patch
import numpy as np
import torch
from mmpretrain.structures import ClsDataSample
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer
@ -24,7 +24,7 @@ class TestClsVisualizer(TestCase):
def test_add_datasample(self):
image = np.ones((10, 10, 3), np.uint8)
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# Test show
@ -82,7 +82,7 @@ class TestClsVisualizer(TestCase):
'test', image=image, data_sample=data_sample, draw_gt=False)
# Test without score
del data_sample.pred_label.score
del data_sample.pred_score
def test_texts(text, *_, **__):
self.assertEqual(