[Refactor] Refactor ClsDataSample to a unified DataSample. (#1371)

* [Refactor] Refactor ClsDataSample to a unified DataSample.

* Add  method

* Fix docstring

* Update docstring.
Ma Zerun 2023-02-23 10:07:53 +08:00 committed by GitHub
parent 4016f1348e
commit 36bea13fca
42 changed files with 659 additions and 755 deletions

View File

@@ -1,13 +1,13 @@
 .. role:: hidden
    :class: hidden-section
-.. module:: mmcls.structures
+.. module:: mmpretrain.structures
-mmcls.structures
+mmpretrain.structures
 ===================================
-This package includes basic data structures for classification tasks.
+This package includes basic data structures.
-ClsDataSample
+DataSample
 -------------
-.. autoclass:: ClsDataSample
+.. autoclass:: DataSample
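As a quick orientation for the hunks below, here is a minimal usage sketch of the unified class. The setter names and the flat-field access are taken from this diff; the concrete values are made up for illustration and the sketch is not part of the PR:

    import torch
    from mmpretrain.structures import DataSample

    # Labels and scores now live directly on the sample as flat fields,
    # instead of being nested inside a `gt_label`/`pred_label` LabelData.
    sample = DataSample()
    sample.set_gt_label(3)                                   # index-style ground truth
    sample.set_gt_score(torch.tensor([0.1, 0.2, 0.1, 0.5, 0.1]))  # soft ground truth
    sample.set_pred_score(torch.rand(5)).set_pred_label(1)   # setters chain

    # Membership tests check the flat field names.
    if 'gt_score' in sample:
        print(sample.gt_score)
    print(sample.pred_score.argmax().item())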

View File

@@ -12,7 +12,7 @@ from mmengine.model import BaseModel
 from mmengine.runner import load_checkpoint
 from mmpretrain.registry import TRANSFORMS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .model import get_model, init_model, list_models
 ModelType = Union[BaseModel, str, Config]
@@ -176,7 +176,7 @@ class ImageClassificationInferencer(BaseInferencer):
     def visualize(self,
                   ori_inputs: List[InputType],
-                  preds: List[ClsDataSample],
+                  preds: List[DataSample],
                   show: bool = False,
                   rescale_factor: Optional[float] = None,
                   draw_score=True,
@@ -223,7 +223,7 @@ class ImageClassificationInferencer(BaseInferencer):
         return visualization
     def postprocess(self,
-                    preds: List[ClsDataSample],
+                    preds: List[DataSample],
                     visualization: List[np.ndarray],
                     return_datasamples=False) -> dict:
         if return_datasamples:
@@ -231,14 +231,13 @@ class ImageClassificationInferencer(BaseInferencer):
         results = []
         for data_sample in preds:
-            prediction = data_sample.pred_label
-            pred_scores = prediction.score.detach().cpu().numpy()
-            pred_score = torch.max(prediction.score).item()
-            pred_label = torch.argmax(prediction.score).item()
+            pred_scores = data_sample.pred_score
+            pred_score = float(torch.max(pred_scores).item())
+            pred_label = torch.argmax(pred_scores).item()
             result = {
-                'pred_scores': pred_scores,
+                'pred_scores': pred_scores.detach().cpu().numpy(),
                 'pred_label': pred_label,
-                'pred_score': float(pred_score),
+                'pred_score': pred_score,
             }
             if self.classes is not None:
                 result['pred_class'] = self.classes[pred_label]
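The essential change in this hunk is the read-out path for predictions. A condensed sketch of the new access pattern, written as a standalone helper (not part of the PR, shapes assumed):

    import torch
    from mmpretrain.structures import DataSample

    def readout(data_sample: DataSample) -> dict:
        # Old access: data_sample.pred_label.score
        # New access: data_sample.pred_score (a flat tensor field)
        scores = data_sample.pred_score
        label = int(torch.argmax(scores))
        return {'pred_label': label, 'pred_score': float(scores[label])}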

View File

@@ -10,7 +10,7 @@ from mmengine.utils import is_str
 from PIL import Image
 from mmpretrain.registry import TRANSFORMS
-from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
 def to_tensor(data):
@@ -53,7 +53,7 @@ class PackClsInputs(BaseTransform):
     **Added Keys:**
     - inputs (:obj:`torch.Tensor`): The forward data of models.
-    - data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
+    - data_samples (:obj:`~mmpretrain.structures.DataSample`): The
       annotation info of the sample.
     Args:
@@ -87,10 +87,11 @@ class PackClsInputs(BaseTransform):
             img = np.ascontiguousarray(img.transpose(2, 0, 1))
             packed_results['inputs'] = to_tensor(img)
-        data_sample = ClsDataSample()
+        data_sample = DataSample()
         if 'gt_label' in results:
-            gt_label = results['gt_label']
-            data_sample.set_gt_label(gt_label)
+            data_sample.set_gt_label(results['gt_label'])
+        if 'gt_score' in results:
+            data_sample.set_gt_score(results['gt_score'])
         img_meta = {k: results[k] for k in self.meta_keys if k in results}
         data_sample.set_metainfo(img_meta)
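For reference, the packing branch above now handles soft labels directly. A hedged, self-contained illustration with made-up `results` content (the keys mirror the transform, the values are invented):

    import torch
    from mmpretrain.structures import DataSample

    results = {'gt_label': 2, 'gt_score': torch.tensor([0.1, 0.2, 0.7])}
    data_sample = DataSample()
    if 'gt_label' in results:
        data_sample.set_gt_label(results['gt_label'])    # hard label
    if 'gt_score' in results:
        data_sample.set_gt_score(results['gt_score'])    # soft label, new branch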

View File

@@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, Runner
 from mmengine.visualization import Visualizer
 from mmpretrain.registry import HOOKS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 @HOOKS.register_module()
@@ -51,14 +51,14 @@ class VisualizationHook(Hook):
     def _draw_samples(self,
                       batch_idx: int,
                       data_batch: dict,
-                      data_samples: Sequence[ClsDataSample],
+                      data_samples: Sequence[DataSample],
                       step: int = 0) -> None:
         """Visualize every ``self.interval`` samples from a data batch.
         Args:
             batch_idx (int): The index of the current batch in the val loop.
             data_batch (dict): Data from dataloader.
-            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
+            outputs (Sequence[:obj:`DataSample`]): Outputs from model.
             step (int): Global step value to record. Defaults to 0.
         """
         if self.enable is False:
@@ -97,14 +97,14 @@ class VisualizationHook(Hook):
         )
     def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
-                       outputs: Sequence[ClsDataSample]) -> None:
+                       outputs: Sequence[DataSample]) -> None:
         """Visualize every ``self.interval`` samples during validation.
         Args:
             runner (:obj:`Runner`): The runner of the validation process.
             batch_idx (int): The index of the current batch in the val loop.
             data_batch (dict): Data from dataloader.
-            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
+            outputs (Sequence[:obj:`DataSample`]): Outputs from model.
         """
         if isinstance(runner.train_loop, EpochBasedTrainLoop):
             step = runner.epoch
@@ -114,7 +114,7 @@ class VisualizationHook(Hook):
         self._draw_samples(batch_idx, data_batch, outputs, step=step)
     def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
-                        outputs: Sequence[ClsDataSample]) -> None:
+                        outputs: Sequence[DataSample]) -> None:
         """Visualize every ``self.interval`` samples during test.
         Args:

View File

@@ -5,9 +5,9 @@ import numpy as np
 import torch
 from mmengine.evaluator import BaseMetric
 from mmengine.logging import MMLogger
-from mmengine.structures import LabelData
 from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
 from .single_label import _precision_recall_f1_support, to_tensor
@@ -114,10 +114,10 @@ class MultiLabelMetric(BaseMetric):
         (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
         >>>
         >>> # ------------------- Use with Evalutor -------------------
-        >>> from mmpretrain.structures import ClsDataSample
+        >>> from mmpretrain.structures import DataSample
         >>> from mmengine.evaluator import Evaluator
         >>> data_sampels = [
-        ...     ClsDataSample().set_pred_score(pred).set_gt_score(gt)
+        ...     DataSample().set_pred_score(pred).set_gt_score(gt)
         ...     for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
         >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
         >>> evaluator.process(data_sampels)
@@ -181,17 +181,15 @@ class MultiLabelMetric(BaseMetric):
         """
         for data_sample in data_samples:
             result = dict()
-            pred_label = data_sample['pred_label']
-            gt_label = data_sample['gt_label']
-            result['pred_score'] = pred_label['score'].clone()
+            result['pred_score'] = data_sample['pred_score'].clone()
             num_classes = result['pred_score'].size()[-1]
-            if 'score' in gt_label:
-                result['gt_score'] = gt_label['score'].clone()
+            if 'gt_score' in data_sample:
+                result['gt_score'] = data_sample['gt_score'].clone()
             else:
-                result['gt_score'] = LabelData.label_to_onehot(
-                    gt_label['label'], num_classes)
+                result['gt_score'] = label_to_onehot(data_sample['gt_label'],
+                                                     num_classes)
             # Save the result to `self.results`.
             self.results.append(result)
@@ -331,8 +329,7 @@ class MultiLabelMetric(BaseMetric):
             assert num_classes is not None, 'For index-type labels, ' \
                 'please specify `num_classes`.'
             label = torch.stack([
-                LabelData.label_to_onehot(
-                    to_tensor(indices), num_classes)
+                label_to_onehot(indices, num_classes)
                 for indices in label
             ])
         else:
@@ -479,10 +476,10 @@ class AveragePrecision(BaseMetric):
         >>> AveragePrecision.calculate(y_pred, y_true)
         tensor(70.833)
         >>> # ------------------- Use with Evalutor -------------------
-        >>> from mmpretrain.structures import ClsDataSample
+        >>> from mmpretrain.structures import DataSample
        >>> from mmengine.evaluator import Evaluator
        >>> data_samples = [
-        ...     ClsDataSample().set_pred_score(i).set_gt_score(j)
+        ...     DataSample().set_pred_score(i).set_gt_score(j)
        ...     for i, j in zip(y_pred, y_true)
        ... ]
        >>> evaluator = Evaluator(metrics=AveragePrecision())
@@ -517,17 +514,15 @@ class AveragePrecision(BaseMetric):
         for data_sample in data_samples:
             result = dict()
-            pred_label = data_sample['pred_label']
-            gt_label = data_sample['gt_label']
-            result['pred_score'] = pred_label['score']
+            result['pred_score'] = data_sample['pred_score'].clone()
             num_classes = result['pred_score'].size()[-1]
-            if 'score' in gt_label:
-                result['gt_score'] = gt_label['score']
+            if 'gt_score' in data_sample:
+                result['gt_score'] = data_sample['gt_score'].clone()
             else:
-                result['gt_score'] = LabelData.label_to_onehot(
-                    gt_label['label'], num_classes)
+                result['gt_score'] = label_to_onehot(data_sample['gt_label'],
+                                                     num_classes)
             # Save the result to `self.results`.
             self.results.append(result)
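The metrics now call a module-level label_to_onehot helper instead of LabelData.label_to_onehot. Its implementation is not shown in this diff; a rough functional equivalent, assuming index-style input and an integer num_classes, might look like the sketch below (illustrative only, not the library code):

    import torch

    def label_to_onehot_sketch(label, num_classes: int) -> torch.Tensor:
        # Accept a tensor/list of class indices and scatter them into a
        # one-hot vector of length `num_classes`.
        label = torch.as_tensor(label).long().reshape(-1)
        onehot = torch.zeros(num_classes, dtype=torch.long)
        onehot[label] = 1
        return onehot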

View File

@@ -5,10 +5,10 @@ import mmengine
 import numpy as np
 import torch
 from mmengine.evaluator import BaseMetric
-from mmengine.structures import LabelData
 from mmengine.utils import is_seq_of
 from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
 from .single_label import to_tensor
@@ -48,10 +48,10 @@ class RetrievalRecall(BaseMetric):
         [tensor(9.3000), tensor(48.4000)]
         >>>
         >>> # ------------------- Use with Evalutor -------------------
-        >>> from mmpretrain.structures import ClsDataSample
+        >>> from mmpretrain.structures import DataSample
         >>> from mmengine.evaluator import Evaluator
         >>> data_samples = [
-        ...     ClsDataSample().set_gt_label([0, 1]).set_pred_score(
+        ...     DataSample().set_gt_label([0, 1]).set_pred_score(
         ...         torch.rand(10))
         ...     for i in range(1000)
         ... ]
@@ -95,23 +95,21 @@ class RetrievalRecall(BaseMetric):
             predictions (Sequence[dict]): A batch of outputs from the model.
         """
         for data_sample in data_samples:
-            pred_label = data_sample['pred_label']
+            pred_score = data_sample['pred_score'].clone()
             gt_label = data_sample['gt_label']
-            pred = pred_label['score'].clone()
-            if 'score' in gt_label:
-                target = gt_label['score'].clone()
+            if 'gt_score' in data_sample:
+                target = data_sample.get('gt_score').clone()
             else:
-                num_classes = pred_label['score'].size()[-1]
-                target = LabelData.label_to_onehot(gt_label['label'],
-                                                   num_classes)
+                num_classes = pred_score.size()[-1]
+                target = label_to_onehot(gt_label, num_classes)
             # Because the retrieval output logit vector will be much larger
             # compared to the normal classification, to save resources, the
             # evaluation results are computed each batch here and then reduce
             # all results at the end.
             result = RetrievalRecall.calculate(
-                pred.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
+                pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
             self.results.append(result)
     def compute_metrics(self, results: List):
@@ -230,5 +228,5 @@ def _format_target(label, is_indices=False):
         raise TypeError(f'The pred must be type of torch.tensor, '
                         f'np.ndarray or Sequence but get {type(label)}.')
-    indices = [LabelData.onehot_to_label(sample_gt) for sample_gt in label]
+    indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
     return indices
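The last hunk replaces LabelData.onehot_to_label with a plain tensor expression. For a 1-D one-hot (or multi-hot) vector the two forms should be interchangeable; a quick check:

    import torch

    onehot = torch.tensor([0, 1, 0, 1])
    indices = onehot.nonzero().squeeze(-1)   # tensor([1, 3]): positions of the set classes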

View File

@@ -104,10 +104,10 @@ class Accuracy(BaseMetric):
         [[tensor([9.9000])], [tensor([51.5000])]]
         >>>
         >>> # ------------------- Use with Evalutor -------------------
-        >>> from mmpretrain.structures import ClsDataSample
+        >>> from mmpretrain.structures import DataSample
         >>> from mmengine.evaluator import Evaluator
         >>> data_samples = [
-        ...     ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
+        ...     DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
         ...     for i in range(1000)
         ... ]
         >>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
@@ -150,13 +150,11 @@ class Accuracy(BaseMetric):
         for data_sample in data_samples:
             result = dict()
-            pred_label = data_sample['pred_label']
-            gt_label = data_sample['gt_label']
-            if 'score' in pred_label:
-                result['pred_score'] = pred_label['score'].cpu()
+            if 'pred_score' in data_sample:
+                result['pred_score'] = data_sample['pred_score'].cpu()
             else:
-                result['pred_label'] = pred_label['label'].cpu()
-            result['gt_label'] = gt_label['label'].cpu()
+                result['pred_label'] = data_sample['pred_label'].cpu()
+            result['gt_label'] = data_sample['gt_label'].cpu()
             # Save the result to `self.results`.
             self.results.append(result)
@@ -358,10 +356,10 @@ class SingleLabelMetric(BaseMetric):
         (tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
         >>>
         >>> # ------------------- Use with Evalutor -------------------
-        >>> from mmpretrain.structures import ClsDataSample
+        >>> from mmpretrain.structures import DataSample
         >>> from mmengine.evaluator import Evaluator
         >>> data_samples = [
-        ...     ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
+        ...     DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
         ...     for i in range(1000)
         ... ]
         >>> evaluator = Evaluator(metrics=SingleLabelMetric())
@@ -418,19 +416,16 @@ class SingleLabelMetric(BaseMetric):
         for data_sample in data_samples:
             result = dict()
-            pred_label = data_sample['pred_label']
-            gt_label = data_sample['gt_label']
-            if 'score' in pred_label:
-                result['pred_score'] = pred_label['score'].cpu()
+            if 'pred_score' in data_sample:
+                result['pred_score'] = data_sample['pred_score'].cpu()
             else:
                 num_classes = self.num_classes or data_sample.get(
                     'num_classes')
                 assert num_classes is not None, \
-                    'The `num_classes` must be specified if `pred_label` has '\
-                    'only `label`.'
-                result['pred_label'] = pred_label['label'].cpu()
+                    'The `num_classes` must be specified if no `pred_score`.'
+                result['pred_label'] = data_sample['pred_label'].cpu()
                 result['num_classes'] = num_classes
-            result['gt_label'] = gt_label['label'].cpu()
+            result['gt_label'] = data_sample['gt_label'].cpu()
             # Save the result to `self.results`.
             self.results.append(result)
@@ -641,17 +636,16 @@ class ConfusionMatrix(BaseMetric):
     def process(self, data_batch, data_samples: Sequence[dict]) -> None:
         for data_sample in data_samples:
-            pred = data_sample['pred_label']
-            gt_label = data_sample['gt_label']['label']
-            if 'score' in pred:
-                pred_label = pred['score'].argmax(dim=0, keepdim=True)
-                self.num_classes = pred['score'].size(0)
+            if 'pred_score' in data_sample:
+                pred_score = data_sample['pred_score']
+                pred_label = pred_score.argmax(dim=0, keepdim=True)
+                self.num_classes = pred_score.size(0)
             else:
-                pred_label = pred['label']
+                pred_label = data_sample['pred_label']
             self.results.append({
                 'pred_label': pred_label,
-                'gt_label': gt_label
+                'gt_label': data_sample['gt_label'],
             })
     def compute_metrics(self, results: list) -> dict:

View File

@@ -1,9 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Optional, Sequence
-from mmengine.structures import LabelData
 from mmpretrain.registry import METRICS
+from mmpretrain.structures import label_to_onehot
 from .multi_label import AveragePrecision, MultiLabelMetric
@@ -39,18 +38,16 @@ class VOCMetricMixin:
         """
         for data_sample in data_samples:
             result = dict()
-            pred_label = data_sample['pred_label']
             gt_label = data_sample['gt_label']
             gt_label_difficult = data_sample['gt_label_difficult']
-            result['pred_score'] = pred_label['score'].clone()
+            result['pred_score'] = data_sample['pred_score'].clone()
             num_classes = result['pred_score'].size()[-1]
-            if 'score' in gt_label:
-                result['gt_score'] = gt_label['score'].clone()
+            if 'gt_score' in data_sample:
+                result['gt_score'] = data_sample['gt_score'].clone()
             else:
-                result['gt_score'] = LabelData.label_to_onehot(
-                    gt_label['label'], num_classes)
+                result['gt_score'] = label_to_onehot(gt_label, num_classes)
             # VOC annotation labels all the objects in a single image
             # therefore, some categories are appeared both in
@@ -58,7 +55,7 @@ class VOCMetricMixin:
             # Here we reckon those labels which are only exists in difficult
             # objects as difficult labels.
             difficult_label = set(gt_label_difficult) - (
-                set(gt_label_difficult) & set(gt_label['label'].tolist()))
+                set(gt_label_difficult) & set(gt_label.tolist()))
             # set difficult label for better eval
             if self.difficult_as_positive is None:

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .base import BaseClassifier
@@ -122,14 +122,14 @@ class HuggingFaceClassifier(BaseClassifier):
         raise NotImplementedError(
             "The HuggingFaceClassifier doesn't support extract feature yet.")
-    def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
+    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
              **kwargs):
         """Calculate losses from a batch of inputs and data samples.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments of the loss module.
@@ -144,14 +144,14 @@ class HuggingFaceClassifier(BaseClassifier):
         return losses
     def _get_loss(self, cls_score: torch.Tensor,
-                  data_samples: List[ClsDataSample], **kwargs):
+                  data_samples: List[DataSample], **kwargs):
         """Unpack data samples and compute loss."""
         # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
+        if 'gt_score' in data_samples[0]:
             # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
-            target = torch.cat([i.gt_label.label for i in data_samples])
+            target = torch.cat([i.gt_label for i in data_samples])
         # compute loss
         losses = dict()
@@ -163,17 +163,17 @@ class HuggingFaceClassifier(BaseClassifier):
     def predict(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[ClsDataSample]] = None):
+                data_samples: Optional[List[DataSample]] = None):
         """Predict results from a batch of inputs.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. Defaults to None.
         Returns:
-            List[ClsDataSample]: The prediction results.
+            List[DataSample]: The prediction results.
         """
         # The part can be traced by torch.fx
         cls_score = self.model(inputs).logits
@@ -197,8 +197,8 @@ class HuggingFaceClassifier(BaseClassifier):
         else:
             data_samples = []
             for score, label in zip(pred_scores, pred_labels):
-                data_samples.append(ClsDataSample().set_pred_score(
-                    score).set_pred_label(label))
+                data_samples.append(
+                    DataSample().set_pred_score(score).set_pred_label(label))
         return data_samples
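The same target-packing pattern recurs in the Timm classifier and the heads further down: soft scores are stacked when present, otherwise hard labels are concatenated. A standalone sketch of that pattern (field names taken from this diff, everything else assumed):

    import torch
    from mmpretrain.structures import DataSample

    def pack_targets(data_samples):
        if 'gt_score' in data_samples[0]:
            # Batch augmentation (e.g. Mixup) may have produced soft scores.
            return torch.stack([s.gt_score for s in data_samples])
        # Otherwise concatenate the index-style labels into one 1-D tensor.
        return torch.cat([s.gt_label for s in data_samples])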

View File

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .base import BaseClassifier
@@ -76,7 +76,7 @@ class ImageClassifier(BaseClassifier):
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[ClsDataSample]] = None,
+                data_samples: Optional[List[DataSample]] = None,
                 mode: str = 'tensor'):
         """The unified entry for a forward process in both training and test.
@@ -85,7 +85,7 @@ class ImageClassifier(BaseClassifier):
         - "tensor": Forward the whole network and return tensor or tuple of
           tensor without any post-processing, same as a common nn.Module.
         - "predict": Forward and return the predictions, which are fully
-          processed to a list of :obj:`ClsDataSample`.
+          processed to a list of :obj:`DataSample`.
         - "loss": Forward and return a dict of losses according to the given
           inputs and data samples.
@@ -95,7 +95,7 @@ class ImageClassifier(BaseClassifier):
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. It's required if ``mode="loss"``.
                 Defaults to None.
             mode (str): Return what kind of value. Defaults to 'tensor'.
@@ -105,7 +105,7 @@ class ImageClassifier(BaseClassifier):
         - If ``mode="tensor"``, return a tensor or a tuple of tensor.
         - If ``mode="predict"``, return a list of
-          :obj:`mmpretrain.structures.ClsDataSample`.
+          :obj:`mmpretrain.structures.DataSample`.
         - If ``mode="loss"``, return a dict of tensor.
         """
         if mode == 'tensor':
@@ -209,13 +209,13 @@ class ImageClassifier(BaseClassifier):
         return self.head.pre_logits(x)
     def loss(self, inputs: torch.Tensor,
-             data_samples: List[ClsDataSample]) -> dict:
+             data_samples: List[DataSample]) -> dict:
         """Calculate losses from a batch of inputs and data samples.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
         Returns:
@@ -226,14 +226,14 @@ class ImageClassifier(BaseClassifier):
     def predict(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[ClsDataSample]] = None,
-                **kwargs) -> List[ClsDataSample]:
+                data_samples: Optional[List[DataSample]] = None,
+                **kwargs) -> List[DataSample]:
         """Predict results from a batch of inputs.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. Defaults to None.
             **kwargs: Other keyword arguments accepted by the ``predict``
                 method of :attr:`head`.

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .base import BaseClassifier
@@ -110,14 +110,14 @@ class TimmClassifier(BaseClassifier):
             f"The model {type(self.model)} doesn't support extract "
             "feature because it don't have `forward_features` method.")
-    def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
+    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs):
         """Calculate losses from a batch of inputs and data samples.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments of the loss module.
@@ -132,14 +132,14 @@ class TimmClassifier(BaseClassifier):
         return losses
     def _get_loss(self, cls_score: torch.Tensor,
-                  data_samples: List[ClsDataSample], **kwargs):
+                  data_samples: List[DataSample], **kwargs):
         """Unpack data samples and compute loss."""
         # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
+        if 'gt_score' in data_samples[0]:
             # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
-            target = torch.cat([i.gt_label.label for i in data_samples])
+            target = torch.cat([i.gt_label for i in data_samples])
         # compute loss
         losses = dict()
@@ -150,17 +150,17 @@ class TimmClassifier(BaseClassifier):
     def predict(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[ClsDataSample]] = None):
+                data_samples: Optional[List[DataSample]] = None):
         """Predict results from a batch of inputs.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. Defaults to None.
         Returns:
-            List[ClsDataSample]: The prediction results.
+            List[DataSample]: The prediction results.
         """
         # The part can be traced by torch.fx
         cls_score = self(inputs)
@@ -184,8 +184,8 @@ class TimmClassifier(BaseClassifier):
         else:
             data_samples = []
             for score, label in zip(pred_scores, pred_labels):
-                data_samples.append(ClsDataSample().set_pred_score(
-                    score).set_pred_label(label))
+                data_samples.append(
+                    DataSample().set_pred_score(score).set_pred_label(label))
         return data_samples

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from mmpretrain.evaluation.metrics import Accuracy
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .base_head import BaseHead
@@ -57,8 +57,8 @@ class ClsHead(BaseHead):
         # just return the unpacked inputs.
         return pre_logits
-    def loss(self, feats: Tuple[torch.Tensor],
-             data_samples: List[ClsDataSample], **kwargs) -> dict:
+    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> dict:
         """Calculate losses from the classification score.
         Args:
@@ -66,7 +66,7 @@ class ClsHead(BaseHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments to forward the loss module.
@@ -81,14 +81,14 @@ class ClsHead(BaseHead):
         return losses
     def _get_loss(self, cls_score: torch.Tensor,
-                  data_samples: List[ClsDataSample], **kwargs):
+                  data_samples: List[DataSample], **kwargs):
         """Unpack data samples and compute loss."""
         # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
+        if 'gt_score' in data_samples[0]:
             # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
-            target = torch.cat([i.gt_label.label for i in data_samples])
+            target = torch.cat([i.gt_label for i in data_samples])
         # compute loss
         losses = dict()
@@ -110,8 +110,8 @@ class ClsHead(BaseHead):
     def predict(
         self,
         feats: Tuple[torch.Tensor],
-        data_samples: List[Union[ClsDataSample, None]] = None
-    ) -> List[ClsDataSample]:
+        data_samples: Optional[List[Optional[DataSample]]] = None
+    ) -> List[DataSample]:
         """Inference without augmentation.
         Args:
@@ -119,12 +119,12 @@ class ClsHead(BaseHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample | None], optional): The annotation
+            data_samples (List[DataSample | None], optional): The annotation
                 data of every samples. If not None, set ``pred_label`` of
                 the input data samples. Defaults to None.
         Returns:
-            List[ClsDataSample]: A list of data samples which contains the
+            List[DataSample]: A list of data samples which contains the
                 predicted results.
         """
         # The part can be traced by torch.fx
@@ -149,7 +149,7 @@ class ClsHead(BaseHead):
         for data_sample, score, label in zip(data_samples, pred_scores,
                                              pred_labels):
             if data_sample is None:
-                data_sample = ClsDataSample()
+                data_sample = DataSample()
             data_sample.set_pred_score(score).set_pred_label(label)
             out_data_samples.append(data_sample)
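For completeness, the prediction side of the new API simply chains the two setters per sample, as in the hunk above. A condensed sketch outside the head class (the softmax/argmax post-processing is assumed, not quoted from this diff):

    import torch
    from mmpretrain.structures import DataSample

    cls_score = torch.randn(4, 10)                    # assumed (batch, num_classes) logits
    pred_scores = torch.softmax(cls_score, dim=1)
    pred_labels = pred_scores.argmax(dim=1)
    data_samples = [
        DataSample().set_pred_score(score).set_pred_label(label)
        for score, label in zip(pred_scores, pred_labels)
    ]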

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 from mmpretrain.evaluation.metrics import Accuracy
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .cls_head import ClsHead
@@ -64,10 +64,9 @@ class ConformerHead(ClsHead):
         return conv_cls_score, tran_cls_score
-    def predict(
-        self,
-        feats: Tuple[List[torch.Tensor]],
-        data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
+    def predict(self,
+                feats: Tuple[List[torch.Tensor]],
+                data_samples: List[DataSample] = None) -> List[DataSample]:
         """Inference without augmentation.
         Args:
@@ -75,12 +74,12 @@ class ConformerHead(ClsHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. If not None, set ``pred_label`` of
                 the input data samples. Defaults to None.
         Returns:
-            List[ClsDataSample]: A list of data samples which contains the
+            List[DataSample]: A list of data samples which contains the
                 predicted results.
         """
         # The part can be traced by torch.fx
@@ -92,14 +91,14 @@ class ConformerHead(ClsHead):
         return predictions
     def _get_loss(self, cls_score: Tuple[torch.Tensor],
-                  data_samples: List[ClsDataSample], **kwargs) -> dict:
+                  data_samples: List[DataSample], **kwargs) -> dict:
         """Unpack data samples and compute loss."""
         # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
+        if 'gt_score' in data_samples[0]:
             # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
-            target = torch.cat([i.gt_label.label for i in data_samples])
+            target = torch.cat([i.gt_label for i in data_samples])
         # compute loss
         losses = dict()

View File

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .cls_head import ClsHead
@@ -65,8 +65,8 @@ class EfficientFormerClsHead(ClsHead):
         # after unpacking.
         return feats[-1]
-    def loss(self, feats: Tuple[torch.Tensor],
-             data_samples: List[ClsDataSample], **kwargs) -> dict:
+    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> dict:
         """Calculate losses from the classification score.
         Args:
@@ -74,7 +74,7 @@ class EfficientFormerClsHead(ClsHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments to forward the loss module.

View File

@@ -11,7 +11,7 @@ from mmengine.utils import is_seq_of
 from mmpretrain.models.losses import convert_to_one_hot
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from .cls_head import ClsHead
@@ -264,8 +264,8 @@ class ArcFaceClsHead(ClsHead):
         return self.scale * logit
-    def loss(self, feats: Tuple[torch.Tensor],
-             data_samples: List[ClsDataSample], **kwargs) -> dict:
+    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> dict:
         """Calculate losses from the classification score.
         Args:
@@ -273,7 +273,7 @@ class ArcFaceClsHead(ClsHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments to forward the loss module.
@@ -281,12 +281,11 @@ class ArcFaceClsHead(ClsHead):
            dict[str, Tensor]: a dictionary of loss components
         """
         # Unpack data samples and pack targets
-        label_target = torch.cat([i.gt_label.label for i in data_samples])
-        if 'score' in data_samples[0].gt_label:
+        label_target = torch.cat([i.gt_label for i in data_samples])
+        if 'gt_score' in data_samples[0]:
             # Batch augmentation may convert labels to one-hot format scores.
-            target = torch.stack([i.gt_label.score for i in data_samples])
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
-            # change the labels to to one-hot format scores.
             target = label_target
         # the index format target would be used

View File

@@ -3,10 +3,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
-from mmengine.structures import LabelData
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample, label_to_onehot
 from .base_head import BaseHead
@@ -65,8 +64,8 @@ class MultiLabelClsHead(BaseHead):
         # just return the unpacked inputs.
         return pre_logits
-    def loss(self, feats: Tuple[torch.Tensor],
-             data_samples: List[ClsDataSample], **kwargs) -> dict:
+    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
+             **kwargs) -> dict:
         """Calculate losses from the classification score.
         Args:
@@ -74,7 +73,7 @@ class MultiLabelClsHead(BaseHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
             **kwargs: Other keyword arguments to forward the loss module.
@@ -89,19 +88,16 @@ class MultiLabelClsHead(BaseHead):
         return losses
     def _get_loss(self, cls_score: torch.Tensor,
-                  data_samples: List[ClsDataSample], **kwargs):
+                  data_samples: List[DataSample], **kwargs):
         """Unpack data samples and compute loss."""
         num_classes = cls_score.size()[-1]
         # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
-            target = torch.stack(
-                [i.gt_label.score.float() for i in data_samples])
+        if 'gt_score' in data_samples[0]:
+            target = torch.stack([i.gt_score for i in data_samples])
         else:
             target = torch.stack([
-                LabelData.label_to_onehot(i.gt_label.label,
-                                          num_classes).float()
-                for i in data_samples
-            ])
+                label_to_onehot(i.gt_label, num_classes) for i in data_samples
+            ]).float()
         # compute loss
         losses = dict()
@@ -111,10 +107,9 @@ class MultiLabelClsHead(BaseHead):
         return losses
-    def predict(
-        self,
-        feats: Tuple[torch.Tensor],
-        data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
+    def predict(self,
+                feats: Tuple[torch.Tensor],
+                data_samples: List[DataSample] = None) -> List[DataSample]:
         """Inference without augmentation.
         Args:
@@ -122,12 +117,12 @@ class MultiLabelClsHead(BaseHead):
                 Multiple stage inputs are acceptable but only the last stage
                 will be used to classify. The shape of every item should be
                 ``(num_samples, num_classes)``.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. If not None, set ``pred_label`` of
                 the input data samples. Defaults to None.
         Returns:
-            List[ClsDataSample]: A list of data samples which contains the
+            List[DataSample]: A list of data samples which contains the
                 predicted results.
         """
         # The part can be traced by torch.fx
@@ -138,7 +133,7 @@ class MultiLabelClsHead(BaseHead):
         return predictions
     def _get_predictions(self, cls_score: torch.Tensor,
-                         data_samples: List[ClsDataSample]):
+                         data_samples: List[DataSample]):
         """Post-process the output of head.
         Including softmax and set ``pred_label`` of data samples.
@@ -146,7 +141,7 @@ class MultiLabelClsHead(BaseHead):
         pred_scores = torch.sigmoid(cls_score)
         if data_samples is None:
-            data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
+            data_samples = [DataSample() for _ in range(cls_score.size(0))]
         for data_sample, score in zip(data_samples, pred_scores):
             if self.thr is not None:

View File

@@ -63,7 +63,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
         - "tensor": Forward the whole network and return tensor without any
           post-processing, same as a common nn.Module.
         - "predict": Forward and return the predictions, which are fully
-          processed to a list of :obj:`ClsDataSample`.
+          processed to a list of :obj:`DataSample`.
         - "loss": Forward and return a dict of losses according to the given
           inputs and data samples.
@@ -73,7 +73,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
         Args:
             inputs (torch.Tensor, tuple): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. It's required if ``mode="loss"``.
                 Defaults to None.
             mode (str): Return what kind of value. Defaults to 'tensor'.
@@ -83,7 +83,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
         - If ``mode="tensor"``, return a tensor.
         - If ``mode="predict"``, return a list of
-          :obj:`mmpretrain.structures.ClsDataSample`.
+          :obj:`mmpretrain.structures.DataSample`.
         - If ``mode="loss"``, return a dict of tensor.
         """
         pass
@@ -107,7 +107,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
         Returns:

View File

@@ -8,7 +8,7 @@ from mmengine.runner import Runner
 from torch.utils.data import DataLoader
 from mmpretrain.registry import MODELS
-from mmpretrain.structures import ClsDataSample
+from mmpretrain.structures import DataSample
 from mmpretrain.utils import track_on_main_process
 from .base import BaseRetriever
@@ -114,7 +114,7 @@ class ImageToImageRetriever(BaseRetriever):
     def forward(self,
                 inputs: torch.Tensor,
-                data_samples: Optional[List[ClsDataSample]] = None,
+                data_samples: Optional[List[DataSample]] = None,
                 mode: str = 'tensor'):
         """The unified entry for a forward process in both training and test.
@@ -123,7 +123,7 @@ class ImageToImageRetriever(BaseRetriever):
         - "tensor": Forward the whole network and return tensor without any
           post-processing, same as a common nn.Module.
         - "predict": Forward and return the predictions, which are fully
-          processed to a list of :obj:`ClsDataSample`.
+          processed to a list of :obj:`DataSample`.
         - "loss": Forward and return a dict of losses according to the given
           inputs and data samples.
@@ -133,7 +133,7 @@ class ImageToImageRetriever(BaseRetriever):
         Args:
             inputs (torch.Tensor, tuple): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. It's required if ``mode="loss"``.
                 Defaults to None.
             mode (str): Return what kind of value. Defaults to 'tensor'.
@@ -143,7 +143,7 @@ class ImageToImageRetriever(BaseRetriever):
         - If ``mode="tensor"``, return a tensor.
         - If ``mode="predict"``, return a list of
-          :obj:`mmpretrain.structures.ClsDataSample`.
+          :obj:`mmpretrain.structures.DataSample`.
         - If ``mode="loss"``, return a dict of tensor.
         """
         if mode == 'tensor':
@@ -169,13 +169,13 @@ class ImageToImageRetriever(BaseRetriever):
         return feat
     def loss(self, inputs: torch.Tensor,
-             data_samples: List[ClsDataSample]) -> dict:
+             data_samples: List[DataSample]) -> dict:
         """Calculate losses from a batch of inputs and data samples.
         Args:
             inputs (torch.Tensor): The input tensor with shape
                 (N, C, ...) in general.
-            data_samples (List[ClsDataSample]): The annotation data of
+            data_samples (List[DataSample]): The annotation data of
                 every samples.
         Returns:
@@ -200,18 +200,18 @@ class ImageToImageRetriever(BaseRetriever):
     def predict(self,
                 inputs: tuple,
-                data_samples: Optional[List[ClsDataSample]] = None,
-                **kwargs) -> List[ClsDataSample]:
+                data_samples: Optional[List[DataSample]] = None,
+                **kwargs) -> List[DataSample]:
         """Predict results from the extracted features.
         Args:
             inputs (tuple): The features extracted from the backbone.
-            data_samples (List[ClsDataSample], optional): The annotation
+            data_samples (List[DataSample], optional): The annotation
                 data of every samples. Defaults to None.
             **kwargs: Other keyword arguments accepted by the ``predict``
                 method of :attr:`head`.
         Returns:
-            List[ClsDataSample]: the raw data_samples with
+            List[DataSample]: the raw data_samples with
                 the predicted results
         """
         if not self.prototype_inited:
@@ -240,8 +240,8 @@ class ImageToImageRetriever(BaseRetriever):
         else:
             data_samples = []
             for score, label in zip(pred_scores, pred_labels):
-                data_samples.append(ClsDataSample().set_pred_score(
-                    score).set_pred_label(label))
+                data_samples.append(
+                    DataSample().set_pred_score(score).set_pred_label(label))
         return data_samples
     def _get_prototype_vecs_from_dataloader(self, data_loader):

View File

@ -4,7 +4,7 @@ from typing import List
from mmengine.model import BaseTTAModel from mmengine.model import BaseTTAModel
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
@MODELS.register_module() @MODELS.register_module()
@ -12,16 +12,16 @@ class AverageClsScoreTTA(BaseTTAModel):
def merge_preds( def merge_preds(
self, self,
data_samples_list: List[List[ClsDataSample]], data_samples_list: List[List[DataSample]],
) -> List[ClsDataSample]: ) -> List[DataSample]:
"""Merge predictions of enhanced data to one prediction. """Merge predictions of enhanced data to one prediction.
Args: Args:
data_samples_list (List[List[ClsDataSample]]): List of predictions data_samples_list (List[List[DataSample]]): List of predictions
of all enhanced data. of all enhanced data.
Returns: Returns:
List[ClsDataSample]: Merged prediction. List[DataSample]: Merged prediction.
""" """
merged_data_samples = [] merged_data_samples = []
for data_samples in data_samples_list: for data_samples in data_samples_list:
@ -29,8 +29,8 @@ class AverageClsScoreTTA(BaseTTAModel):
return merged_data_samples return merged_data_samples
def _merge_single_sample(self, data_samples): def _merge_single_sample(self, data_samples):
merged_data_sample: ClsDataSample = data_samples[0].new() merged_data_sample: DataSample = data_samples[0].new()
merged_score = sum(data_sample.pred_label.score merged_score = sum(data_sample.pred_score
for data_sample in data_samples) / len(data_samples) for data_sample in data_samples) / len(data_samples)
merged_data_sample.set_pred_score(merged_score) merged_data_sample.set_pred_score(merged_score)
return merged_data_sample return merged_data_sample
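As an illustration of the merge above (scores are made up), averaging enhanced views now operates directly on the flat ``pred_score`` field:

import torch
from mmpretrain.structures import DataSample

views = [
    DataSample().set_pred_score(torch.tensor([0.6, 0.4])),
    DataSample().set_pred_score(torch.tensor([0.2, 0.8])),
]
merged = views[0].new()  # new() comes from mmengine's BaseDataElement
merged.set_pred_score(sum(v.pred_score for v in views) / len(views))
# merged.pred_score -> tensor([0.4000, 0.6000])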

View File

@ -8,9 +8,9 @@ import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch from mmengine.model import BaseDataPreprocessor, stack_batch
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import (ClsDataSample, MultiTaskDataSample, from mmpretrain.structures import (DataSample, MultiTaskDataSample,
batch_label_to_onehot, cat_batch_labels, batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split) tensor_split)
from .batch_augments import RandomBatchAugment from .batch_augments import RandomBatchAugment
@ -153,23 +153,28 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
data_samples = data.get('data_samples', None) data_samples = data.get('data_samples', None)
sample_item = data_samples[0] if data_samples is not None else None sample_item = data_samples[0] if data_samples is not None else None
if isinstance(sample_item,
ClsDataSample) and 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(
gt_labels, device=self.device)
batch_score = stack_batch_scores(gt_labels, device=self.device) if isinstance(sample_item, DataSample):
if batch_score is None and self.to_onehot: batch_label = None
batch_score = None
if 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(gt_labels)
batch_label = batch_label.to(self.device)
if 'gt_score' in sample_item:
gt_scores = [sample.gt_score for sample in data_samples]
batch_score = torch.stack(gt_scores).to(self.device)
elif self.to_onehot:
assert batch_label is not None, \ assert batch_label is not None, \
'Cannot generate onehot format labels because no labels.' 'Cannot generate onehot format labels because no labels.'
num_classes = self.num_classes or data_samples[0].get( num_classes = self.num_classes or sample_item.get(
'num_classes') 'num_classes')
assert num_classes is not None, \ assert num_classes is not None, \
'Cannot generate one-hot format labels because not set ' \ 'Cannot generate one-hot format labels because not set ' \
'`num_classes` in `data_preprocessor`.' '`num_classes` in `data_preprocessor`.'
batch_score = batch_label_to_onehot(batch_label, label_indices, batch_score = batch_label_to_onehot(
num_classes) batch_label, label_indices, num_classes).to(self.device)
# ----- Batch Augmentations ---- # ----- Batch Augmentations ----
if training and self.batch_augments is not None: if training and self.batch_augments is not None:
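A rough sketch of the new batching path for hard labels, assuming a batch of DataSample objects with ``gt_label`` set (the helpers are the ones re-exported from ``mmpretrain.structures``):

from mmpretrain.structures import (DataSample, batch_label_to_onehot,
                                   cat_batch_labels)

samples = [DataSample().set_gt_label([0, 2]), DataSample().set_gt_label(1)]
gt_labels = [s.gt_label for s in samples]
batch_label, label_indices = cat_batch_labels(gt_labels)
# batch_label: tensor([0, 2, 1]), label_indices: [2]
batch_score = batch_label_to_onehot(batch_label, label_indices, num_classes=3)
# batch_score: tensor([[1, 0, 1],
#                      [0, 1, 0]])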

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .cls_data_sample import ClsDataSample from .data_sample import DataSample
from .multi_task_data_sample import MultiTaskDataSample from .multi_task_data_sample import MultiTaskDataSample
from .utils import (batch_label_to_onehot, cat_batch_labels, from .utils import (batch_label_to_onehot, cat_batch_labels, label_to_onehot,
stack_batch_scores, tensor_split) tensor_split)
__all__ = [ __all__ = [
'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split',
'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample' 'MultiTaskDataSample', 'label_to_onehot'
] ]

View File

@ -1,235 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from numbers import Number
from typing import Sequence, Union
import numpy as np
import torch
from mmengine.structures import BaseDataElement, LabelData
from mmengine.utils import is_str
def format_label(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
Returns:
:obj:`torch.Tensor`: The formatted label tensor.
"""
# Handle single number
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
value = int(value.item())
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
def format_score(
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The formatted score tensor.
"""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
class ClsDataSample(BaseDataElement):
"""A data structure interface of classification task.
It's used as interfaces between different components.
Meta fields:
img_shape (Tuple): The shape of the corresponding input image.
Used for visualization.
ori_shape (Tuple): The original shape of the corresponding image.
Used for visualization.
num_classes (int): The number of all categories.
Used for label format conversion.
Data fields:
gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth
label.
pred_label (:obj:`~mmengine.structures.LabelData`): The predicted
label.
scores (torch.Tensor): The outputs of model.
logits (torch.Tensor): The outputs of model without softmax nor
sigmoid.
Examples:
>>> import torch
>>> from mmpretrain.structures import ClsDataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = ClsDataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<ClsDataSample(
META INFORMATION
num_classes = 5
img_shape = (960, 720)
DATA FIELDS
gt_label: <LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
label: tensor([3])
) at 0x7f21fb1b9190>
) at 0x7f21fb1b9880>
>>> # For multi-label data
>>> data_sample.set_gt_label([0, 1, 4])
>>> print(data_sample.gt_label)
<LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
label: tensor([0, 1, 4])
) at 0x7fd7d1b41970>
>>> # Set one-hot format score
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
>>> data_sample.set_pred_score(score)
>>> print(data_sample.pred_label)
<LabelData(
META INFORMATION
num_classes: 5
DATA FIELDS
score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
) at 0x7fd7d1b41970>
"""
def set_gt_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``gt_label``."""
label_data = getattr(self, '_gt_label', LabelData())
label_data.label = format_label(value)
self.gt_label = label_data
return self
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``gt_label``."""
label_data = getattr(self, '_gt_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.gt_label = label_data
return self
def set_pred_label(
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
) -> 'ClsDataSample':
"""Set label of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.label = format_label(value)
self.pred_label = label_data
return self
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
"""Set score of ``pred_label``."""
label_data = getattr(self, '_pred_label', LabelData())
label_data.score = format_score(value)
if hasattr(self, 'num_classes'):
assert len(label_data.score) == self.num_classes, \
f'The length of score {len(label_data.score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes',
value=len(label_data.score),
field_type='metainfo')
self.pred_label = label_data
return self
@property
def gt_label(self):
return self._gt_label
@gt_label.setter
def gt_label(self, value: LabelData):
self.set_field(value, '_gt_label', dtype=LabelData)
@gt_label.deleter
def gt_label(self):
del self._gt_label
@property
def pred_label(self):
return self._pred_label
@pred_label.setter
def pred_label(self, value: LabelData):
self.set_field(value, '_pred_label', dtype=LabelData)
@pred_label.deleter
def pred_label(self):
del self._pred_label
def _reduce_cls_datasample(data_sample):
"""reduce ClsDataSample."""
attr_dict = data_sample.__dict__
convert_keys = []
for k, v in attr_dict.items():
if isinstance(v, LabelData):
attr_dict[k] = v.numpy()
convert_keys.append(k)
return _rebuild_cls_datasample, (attr_dict, convert_keys)
def _rebuild_cls_datasample(attr_dict, convert_keys):
"""rebuild ClsDataSample."""
data_sample = ClsDataSample()
for k in convert_keys:
attr_dict[k] = attr_dict[k].to_tensor()
data_sample.__dict__ = attr_dict
return data_sample
# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
# many file descriptors because it contains multiple LabelData with tensors.
# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
# convert these tensors to np.ndarray during pickling. It may influence the
# performance of dataloader, but slightly because these tensors in LabelData
# are very small.
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)

View File

@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing.reduction import ForkingPickler
from typing import Union
import numpy as np
import torch
from mmengine.structures import BaseDataElement
from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score
class DataSample(BaseDataElement):
"""A general data structure interface.
It's used as the interface between different components.
The following fields are conventional names in MMPretrain, and we will set or
get these fields in data transforms, models, and metrics if needed. You can
also set any new fields for your own needs.
Meta fields:
img_shape (Tuple): The shape of the corresponding input image.
ori_shape (Tuple): The original shape of the corresponding image.
sample_idx (int): The index of the sample in the dataset.
num_classes (int): The number of all categories.
Data fields:
gt_label (tensor): The ground truth label.
gt_score (tensor): The ground truth score.
pred_label (tensor): The predicted label.
pred_score (tensor): The predicted score.
mask (tensor): The mask used in masked image modeling.
Examples:
>>> import torch
>>> from mmpretrain.structures import DataSample
>>>
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
>>> data_sample = DataSample(metainfo=img_meta)
>>> data_sample.set_gt_label(3)
>>> print(data_sample)
<DataSample(
META INFORMATION
num_classes: 5
img_shape: (960, 720)
DATA FIELDS
gt_label: tensor([3])
) at 0x7ff64c1c1d30>
>>>
>>> # For multi-label data
>>> data_sample = DataSample().set_gt_label([0, 1, 4])
>>> print(data_sample)
<DataSample(
DATA FIELDS
gt_label: tensor([0, 1, 4])
) at 0x7ff5b490e100>
>>>
>>> # Set one-hot format score
>>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
>>> print(data_sample)
<DataSample(
META INFORMATION
num_classes: 4
DATA FIELDS
pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
) at 0x7ff5b48ef6a0>
>>>
>>> # Set custom field
>>> data_sample = DataSample()
>>> data_sample.my_field = [1, 2, 3]
>>> print(data_sample)
<DataSample(
DATA FIELDS
my_field: [1, 2, 3]
) at 0x7f8e9603d3a0>
>>> print(data_sample.my_field)
[1, 2, 3]
"""
def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
"""Set ``gt_label``."""
self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
return self
def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
"""Set ``gt_score``."""
score = format_score(value)
self.set_field(score, 'gt_score', dtype=torch.Tensor)
if hasattr(self, 'num_classes'):
assert len(score) == self.num_classes, \
f'The length of score {len(score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes', value=len(score), field_type='metainfo')
return self
def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
"""Set ``pred_label``."""
self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
return self
def set_pred_score(self, value: SCORE_TYPE):
"""Set ``pred_label``."""
score = format_score(value)
self.set_field(score, 'pred_score', dtype=torch.Tensor)
if hasattr(self, 'num_classes'):
assert len(score) == self.num_classes, \
f'The length of score {len(score)} should be '\
f'equal to the num_classes {self.num_classes}.'
else:
self.set_field(
name='num_classes', value=len(score), field_type='metainfo')
return self
def set_mask(self, value: Union[torch.Tensor, np.ndarray]):
"""Set ``mask``."""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value)
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Invalid mask type {type(value)}')
self.set_field(value, 'mask', dtype=torch.Tensor)
return self
def __repr__(self) -> str:
"""Represent the object."""
def dump_items(items, prefix=''):
return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)
repr_ = ''
if len(self._metainfo_fields) > 0:
repr_ += '\n\nMETA INFORMATION\n'
repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
if len(self._data_fields) > 0:
repr_ += '\n\nDATA FIELDS\n'
repr_ += dump_items(self.items(), prefix=' ' * 4)
repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
return repr_
def _reduce_datasample(data_sample):
"""reduce DataSample."""
attr_dict = data_sample.__dict__
convert_keys = []
for k, v in attr_dict.items():
if isinstance(v, torch.Tensor):
attr_dict[k] = v.numpy()
convert_keys.append(k)
return _rebuild_datasample, (attr_dict, convert_keys)
def _rebuild_datasample(attr_dict, convert_keys):
"""rebuild DataSample."""
data_sample = DataSample()
for k in convert_keys:
attr_dict[k] = torch.from_numpy(attr_dict[k])
data_sample.__dict__ = attr_dict
return data_sample
# Due to the multi-processing strategy of PyTorch, DataSample may consume many
# file descriptors because it contains multiple tensors. Here we overwrite the
# reduce function of DataSample in ForkingPickler and convert these tensors to
# np.ndarray during pickling. It may slightly influence the performance of
# dataloader.
ForkingPickler.register(DataSample, _reduce_datasample)
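A quick usage sketch of the union structure (values are illustrative). Compared with the removed ClsDataSample, scores and labels are flat tensor fields, so nested access such as ``pred_label.score`` no longer applies:

import torch
from mmpretrain.structures import DataSample

sample = DataSample(metainfo=dict(num_classes=3))
sample.set_gt_label(2).set_pred_label(2)
sample.set_pred_score(torch.tensor([0.1, 0.2, 0.7]))
print(sample.gt_label)    # tensor([2])
print(sample.pred_label)  # tensor([2])
print(sample.pred_score)  # tensor([0.1000, 0.2000, 0.7000])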

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List from typing import List, Sequence, Union
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from mmengine.structures import LabelData from mmengine.utils import is_str
if hasattr(torch, 'tensor_split'): if hasattr(torch, 'tensor_split'):
tensor_split = torch.tensor_split tensor_split = torch.tensor_split
@ -16,30 +17,82 @@ else:
return outs return outs
def cat_batch_labels(elements: List[LabelData], device=None): LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int]
"""Concat the ``label`` of a batch of :obj:`LabelData` to a tensor. SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
def format_label(value: LABEL_TYPE) -> torch.Tensor:
"""Convert various python types to label-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int`.
Args: Args:
elements (List[LabelData]): A batch of :obj`LabelData`. value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
device (torch.device, optional): The output device of the batch label.
Defaults to None. Returns:
:obj:`torch.Tensor`: The formatted label tensor.
"""
# Handle single number
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
value = int(value.item())
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
def format_score(value: SCORE_TYPE) -> torch.Tensor:
"""Convert various python types to score-format tensor.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`.
Args:
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
Returns:
:obj:`torch.Tensor`: The formatted score tensor.
"""
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).float()
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).float()
elif not isinstance(value, torch.Tensor):
raise TypeError(f'Type {type(value)} is not an available label type.')
assert value.ndim == 1, \
f'The dims of value should be 1, but got {value.ndim}.'
return value
def cat_batch_labels(elements: List[torch.Tensor]):
"""Concat a batch of label tensor to one tensor.
Args:
elements (List[tensor]): A batch of labels.
Returns: Returns:
Tuple[torch.Tensor, List[int]]: The first item is the concatenated label Tuple[torch.Tensor, List[int]]: The first item is the concatenated label
tensor, and the second item is the split indices of every sample. tensor, and the second item is the split indices of every sample.
""" """
item = elements[0]
if 'label' not in item._data_fields:
return None, None
labels = [] labels = []
splits = [0] splits = [0]
for element in elements: for element in elements:
labels.append(element.label) labels.append(element)
splits.append(splits[-1] + element.label.size(0)) splits.append(splits[-1] + element.size(0))
batch_label = torch.cat(labels) batch_label = torch.cat(labels)
if device is not None:
batch_label = batch_label.to(device=device)
return batch_label, splits[1:-1] return batch_label, splits[1:-1]
@ -75,22 +128,26 @@ def batch_label_to_onehot(batch_label, split_indices, num_classes):
return torch.stack(onehot_list) return torch.stack(onehot_list)
def stack_batch_scores(elements, device=None): def label_to_onehot(label: LABEL_TYPE, num_classes: int):
"""Stack the ``score`` of a batch of :obj:`LabelData` to a tensor. """Convert a label to onehot format tensor.
Args: Args:
elements (List[LabelData]): A batch of :obj`LabelData`. label (LABEL_TYPE): Label value.
device (torch.device, optional): The output device of the batch label. num_classes (int): The number of classes.
Defaults to None.
Returns: Returns:
torch.Tensor: The stacked score tensor. torch.Tensor: The onehot format label tensor.
"""
item = elements[0]
if 'score' not in item._data_fields:
return None
batch_score = torch.stack([element.score for element in elements]) Examples:
if device is not None: >>> import torch
batch_score = batch_score.to(device) >>> from mmpretrain.structures import label_to_onehot
return batch_score >>> # Single-label
>>> label_to_onehot(1, num_classes=5)
tensor([0, 1, 0, 0, 0])
>>> # Multi-label
>>> label_to_onehot([0, 2, 3], num_classes=5)
tensor([1, 0, 1, 1, 0])
"""
label = format_label(label)
sparse_onehot = F.one_hot(label, num_classes)
return sparse_onehot.sum(0)
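For reference, a small sketch of the moved formatting helpers (the ``mmpretrain.structures.utils`` import path is assumed from the package layout shown above); both normalize various input types to 1-D tensors:

import numpy as np
import torch
from mmpretrain.structures.utils import format_label, format_score

format_label(3)                         # tensor([3])
format_label(np.array([0, 2]))          # tensor([0, 2])
format_score([0.2, 0.8])                # tensor([0.2000, 0.8000])
format_score(torch.tensor([0.2, 0.8]))  # tensor([0.2000, 0.8000])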

View File

@ -7,7 +7,7 @@ from mmengine.dist import master_only
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
from mmpretrain.registry import VISUALIZERS from mmpretrain.registry import VISUALIZERS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
def _get_adaptive_scale(img_shape: Tuple[int, int], def _get_adaptive_scale(img_shape: Tuple[int, int],
@ -57,11 +57,11 @@ class ClsVisualizer(Visualizer):
>>> import mmcv >>> import mmcv
>>> from pathlib import Path >>> from pathlib import Path
>>> from mmpretrain.visualization import ClsVisualizer >>> from mmpretrain.visualization import ClsVisualizer
>>> from mmpretrain.structures import ClsDataSample >>> from mmpretrain.structures import DataSample
>>> # Example image >>> # Example image
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb') >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
>>> # Example annotation >>> # Example annotation
>>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\ >>> data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
... set_pred_score(torch.tensor([0.1, 0.8, 0.1])) ... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
>>> # Setup the visualizer >>> # Setup the visualizer
>>> vis = ClsVisualizer( >>> vis = ClsVisualizer(
@ -84,7 +84,7 @@ class ClsVisualizer(Visualizer):
def add_datasample(self, def add_datasample(self,
name: str, name: str,
image: np.ndarray, image: np.ndarray,
data_sample: Optional[ClsDataSample] = None, data_sample: Optional[DataSample] = None,
draw_gt: bool = True, draw_gt: bool = True,
draw_pred: bool = True, draw_pred: bool = True,
draw_score: bool = True, draw_score: bool = True,
@ -104,7 +104,7 @@ class ClsVisualizer(Visualizer):
Args: Args:
name (str): The image identifier. name (str): The image identifier.
image (np.ndarray): The image to draw. image (np.ndarray): The image to draw.
data_sample (:obj:`ClsDataSample`, optional): The annotation of the data_sample (:obj:`DataSample`, optional): The annotation of the
image. Defaults to None. image. Defaults to None.
draw_gt (bool): Whether to draw ground truth labels. draw_gt (bool): Whether to draw ground truth labels.
Defaults to True. Defaults to True.
@ -137,8 +137,7 @@ class ClsVisualizer(Visualizer):
self.set_image(image) self.set_image(image)
if draw_gt and 'gt_label' in data_sample: if draw_gt and 'gt_label' in data_sample:
gt_label = data_sample.gt_label idx = data_sample.gt_label.tolist()
idx = gt_label.label.tolist()
class_labels = [''] * len(idx) class_labels = [''] * len(idx)
if classes is not None: if classes is not None:
class_labels = [f' ({classes[i]})' for i in idx] class_labels = [f' ({classes[i]})' for i in idx]
@ -147,13 +146,12 @@ class ClsVisualizer(Visualizer):
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels)) texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
if draw_pred and 'pred_label' in data_sample: if draw_pred and 'pred_label' in data_sample:
pred_label = data_sample.pred_label idx = data_sample.pred_label.tolist()
idx = pred_label.label.tolist()
score_labels = [''] * len(idx) score_labels = [''] * len(idx)
class_labels = [''] * len(idx) class_labels = [''] * len(idx)
if draw_score and 'score' in pred_label: if draw_score and 'pred_score' in data_sample:
score_labels = [ score_labels = [
f', {pred_label.score[i].item():.2f}' for i in idx f', {data_sample.pred_score[i].item():.2f}' for i in idx
] ]
if classes is not None: if classes is not None:
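As a small example of the flat field access used in the drawing code above (values are hypothetical):

import torch
from mmpretrain.structures import DataSample

data_sample = DataSample().set_gt_label(1).set_pred_label(1)
data_sample.set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
idx = data_sample.pred_label.tolist()                     # [1]
scores = [data_sample.pred_score[i].item() for i in idx]  # ~[0.8]
'pred_score' in data_sample                               # True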

View File

@ -9,7 +9,7 @@ from mmcv.image import imread
from mmpretrain.apis import (ImageClassificationInferencer, ModelHub, from mmpretrain.apis import (ImageClassificationInferencer, ModelHub,
get_model, inference_model) get_model, inference_model)
from mmpretrain.models import MobileNetV3 from mmpretrain.models import MobileNetV3
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer from mmpretrain.visualization import ClsVisualizer
MODEL = 'mobilenet-v3-small-050_3rdparty_in1k' MODEL = 'mobilenet-v3-small-050_3rdparty_in1k'
@ -58,7 +58,7 @@ class TestImageClassificationInferencer(TestCase):
# test return_datasample=True # test return_datasample=True
results = inferencer(img, return_datasamples=True)[0] results = inferencer(img, return_datasamples=True)[0]
self.assertIsInstance(results, ClsDataSample) self.assertIsInstance(results, DataSample)
def test_visualize(self): def test_visualize(self):
img_path = osp.join(osp.dirname(__file__), '../data/color.jpg') img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')

View File

@ -6,11 +6,10 @@ import unittest
import mmcv import mmcv
import numpy as np import numpy as np
import torch import torch
from mmengine.structures import LabelData
from PIL import Image from PIL import Image
from mmpretrain.registry import TRANSFORMS from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample from mmpretrain.structures import DataSample, MultiTaskDataSample
class TestPackClsInputs(unittest.TestCase): class TestPackClsInputs(unittest.TestCase):
@ -34,9 +33,9 @@ class TestPackClsInputs(unittest.TestCase):
self.assertIn('inputs', results) self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor) self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertIn('data_samples', results) self.assertIn('data_samples', results)
self.assertIsInstance(results['data_samples'], ClsDataSample) self.assertIsInstance(results['data_samples'], DataSample)
self.assertIn('flip', results['data_samples'].metainfo_keys()) self.assertIn('flip', results['data_samples'].metainfo_keys())
self.assertIsInstance(results['data_samples'].gt_label, LabelData) self.assertIsInstance(results['data_samples'].gt_label, torch.Tensor)
# Test grayscale image # Test grayscale image
data['img'] = data['img'].mean(-1) data['img'] = data['img'].mean(-1)
@ -155,7 +154,7 @@ class TestPackMultiTaskInputs(unittest.TestCase):
self.assertIsInstance(results['data_samples'], MultiTaskDataSample) self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
self.assertIn('flip', results['data_samples'].task1.metainfo_keys()) self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
self.assertIsInstance(results['data_samples'].task1.gt_label, self.assertIsInstance(results['data_samples'].task1.gt_label,
LabelData) torch.Tensor)
# Test grayscale image # Test grayscale image
data['img'] = data['img'].mean(-1) data['img'] = data['img'].mean(-1)

View File

@ -14,7 +14,7 @@ from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from mmpretrain.registry import HOOKS from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class ExampleDataset(Dataset): class ExampleDataset(Dataset):
@ -36,7 +36,7 @@ class MockDataPreprocessor(BaseDataPreprocessor):
def forward(self, data, training): def forward(self, data, training):
return data['imgs'], ClsDataSample() return data['imgs'], DataSample()
class ExampleModel(BaseModel): class ExampleModel(BaseModel):

View File

@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
from mmpretrain.engine import VisualizationHook from mmpretrain.engine import VisualizationHook
from mmpretrain.registry import HOOKS from mmpretrain.registry import HOOKS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer from mmpretrain.visualization import ClsVisualizer
@ -18,7 +18,7 @@ class TestVisualizationHook(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
ClsVisualizer.get_instance('visualizer') ClsVisualizer.get_instance('visualizer')
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2) data_sample = DataSample().set_gt_label(1).set_pred_label(2)
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'}) data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
self.data_batch = { self.data_batch = {
'inputs': torch.randint(0, 256, (10, 3, 224, 224)), 'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
@ -53,7 +53,7 @@ class TestVisualizationHook(TestCase):
cfg = dict(type='VisualizationHook', enable=True) cfg = dict(type='VisualizationHook', enable=True)
hook: VisualizationHook = HOOKS.build(cfg) hook: VisualizationHook = HOOKS.build(cfg)
with patch.object(hook._visualizer, 'add_datasample') as mock: with patch.object(hook._visualizer, 'add_datasample') as mock:
outputs = [ClsDataSample()] * 10 outputs = [DataSample()] * 10
hook._draw_samples(0, self.data_batch, outputs, step=0) hook._draw_samples(0, self.data_batch, outputs, step=0)
mock.assert_called_once_with( mock.assert_called_once_with(
'0', image=ANY, data_sample=outputs[0], step=0, show=False) '0', image=ANY, data_sample=outputs[0], step=0, show=False)

View File

@ -8,7 +8,7 @@ from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope from mmengine.registry import init_default_scope
from mmpretrain.evaluation.metrics import AveragePrecision, MultiLabelMetric from mmpretrain.evaluation.metrics import AveragePrecision, MultiLabelMetric
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
init_default_scope('mmpretrain') init_default_scope('mmpretrain')
@ -152,7 +152,7 @@ class TestMultiLabel(TestCase):
]) ])
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j) DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred_score, y_true) for i, j in zip(y_pred_score, y_true)
] ]
@ -261,7 +261,7 @@ class TestMultiLabel(TestCase):
# Test with gt_score # Test with gt_score
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j) DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred_score, y_true_binary) for i, j in zip(y_pred_score, y_true_binary)
] ]
@ -304,7 +304,7 @@ class TestAveragePrecision(TestCase):
]) ])
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j) DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred, y_true) for i, j in zip(y_pred, y_true)
] ]
@ -328,7 +328,7 @@ class TestAveragePrecision(TestCase):
# Test with gt_label without score # Test with gt_label without score
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j) DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred, [[0, 1], [1], [2], [0]]) for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
] ]
evaluator = Evaluator(dict(type='AveragePrecision')) evaluator = Evaluator(dict(type='AveragePrecision'))

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import torch import torch
from mmpretrain.evaluation.metrics import MultiTasksMetric from mmpretrain.evaluation.metrics import MultiTasksMetric
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class MultiTaskMetric(TestCase): class MultiTaskMetric(TestCase):
@ -24,7 +24,7 @@ class MultiTaskMetric(TestCase):
for i, pred in enumerate(data_pred): for i, pred in enumerate(data_pred):
sample = {} sample = {}
for task_name in pred: for task_name in pred:
task_sample = ClsDataSample().set_pred_score(pred[task_name]) task_sample = DataSample().set_pred_score(pred[task_name])
if task_name in data_gt[i]: if task_name in data_gt[i]:
task_sample.set_gt_label(data_gt[i][task_name]) task_sample.set_gt_label(data_gt[i][task_name])
task_sample.set_field(True, 'eval_mask', field_type='metainfo') task_sample.set_field(True, 'eval_mask', field_type='metainfo')
@ -68,7 +68,7 @@ class MultiTaskMetric(TestCase):
sample = {} sample = {}
for task_name in score: for task_name in score:
if type(score[task_name]) != dict: if type(score[task_name]) != dict:
task_sample = ClsDataSample().set_pred_score(score[task_name]) task_sample = DataSample().set_pred_score(score[task_name])
task_sample.set_gt_label(label[task_name]) task_sample.set_gt_label(label[task_name])
sample[task_name] = task_sample.to_dict() sample[task_name] = task_sample.to_dict()
sample[task_name]['eval_mask'] = True sample[task_name]['eval_mask'] = True
@ -76,7 +76,7 @@ class MultiTaskMetric(TestCase):
sample[task_name] = {} sample[task_name] = {}
sample[task_name]['eval_mask'] = True sample[task_name]['eval_mask'] = True
for task_name2 in score[task_name]: for task_name2 in score[task_name]:
task_sample = ClsDataSample().set_pred_score( task_sample = DataSample().set_pred_score(
score[task_name][task_name2]) score[task_name][task_name2])
task_sample.set_gt_label(label[task_name][task_name2]) task_sample.set_gt_label(label[task_name][task_name2])
sample[task_name][task_name2] = task_sample.to_dict() sample[task_name][task_name2] = task_sample.to_dict()

View File

@ -6,7 +6,7 @@ import torch
from mmpretrain.evaluation.metrics import RetrievalRecall from mmpretrain.evaluation.metrics import RetrievalRecall
from mmpretrain.registry import METRICS from mmpretrain.registry import METRICS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class TestRetrievalRecall(TestCase): class TestRetrievalRecall(TestCase):
@ -14,7 +14,7 @@ class TestRetrievalRecall(TestCase):
def test_evaluate(self): def test_evaluate(self):
"""Test using the metric in the same way as Evalutor.""" """Test using the metric in the same way as Evalutor."""
pred = [ pred = [
ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict() DataSample().set_pred_score(i).set_gt_label(k).to_dict()
for i, k in zip([ for i, k in zip([
torch.tensor([0.7, 0.0, 0.3]), torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]), torch.tensor([0.5, 0.2, 0.3]),

View File

@ -8,7 +8,7 @@ import torch
from mmpretrain.evaluation.metrics import (Accuracy, ConfusionMatrix, from mmpretrain.evaluation.metrics import (Accuracy, ConfusionMatrix,
SingleLabelMetric) SingleLabelMetric)
from mmpretrain.registry import METRICS from mmpretrain.registry import METRICS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class TestAccuracy(TestCase): class TestAccuracy(TestCase):
@ -16,7 +16,7 @@ class TestAccuracy(TestCase):
def test_evaluate(self): def test_evaluate(self):
"""Test using the metric in the same way as Evalutor.""" """Test using the metric in the same way as Evalutor."""
pred = [ pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label( DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([ k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]), torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]), torch.tensor([0.5, 0.2, 0.3]),
@ -52,7 +52,7 @@ class TestAccuracy(TestCase):
# Test with label # Test with label
for sample in pred: for sample in pred:
del sample['pred_label']['score'] del sample['pred_score']
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None))) metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
metric.process(None, pred) metric.process(None, pred)
acc = metric.evaluate(6) acc = metric.evaluate(6)
@ -123,7 +123,7 @@ class TestSingleLabel(TestCase):
def test_evaluate(self): def test_evaluate(self):
"""Test using the metric in the same way as Evalutor.""" """Test using the metric in the same way as Evalutor."""
pred = [ pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label( DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([ k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]), torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]), torch.tensor([0.5, 0.2, 0.3]),
@ -212,7 +212,7 @@ class TestSingleLabel(TestCase):
# Test with label, the thrs will be ignored # Test with label, the thrs will be ignored
pred_no_score = copy.deepcopy(pred) pred_no_score = copy.deepcopy(pred)
for sample in pred_no_score: for sample in pred_no_score:
del sample['pred_label']['score'] del sample['pred_score']
del sample['num_classes'] del sample['num_classes']
metric = METRICS.build( metric = METRICS.build(
dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3)) dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3))
@ -304,7 +304,7 @@ class TestConfusionMatrix(TestCase):
def test_evaluate(self): def test_evaluate(self):
"""Test using the metric in the same way as Evalutor.""" """Test using the metric in the same way as Evalutor."""
pred = [ pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label( DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([ k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]), torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]), torch.tensor([0.5, 0.2, 0.3]),
@ -330,7 +330,7 @@ class TestConfusionMatrix(TestCase):
# Test with label # Test with label
for sample in pred: for sample in pred:
del sample['pred_label']['score'] del sample['pred_score']
metric = METRICS.build(dict(type='ConfusionMatrix')) metric = METRICS.build(dict(type='ConfusionMatrix'))
metric.process(None, pred) metric.process(None, pred)
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(AssertionError,

View File

@ -7,7 +7,7 @@ import torch
from mmengine.evaluator import Evaluator from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope from mmengine.registry import init_default_scope
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
init_default_scope('mmpretrain') init_default_scope('mmpretrain')
@ -27,7 +27,7 @@ class TestVOCMultiLabel(TestCase):
# generate data samples # generate data samples
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j) DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
for i, j in zip(y_pred_score, y_true_label) for i, j in zip(y_pred_score, y_true_label)
] ]
for sample, difficult_label in zip(pred, y_true_difficult): for sample, difficult_label in zip(pred, y_true_difficult):
@ -155,7 +155,7 @@ class TestVOCAveragePrecision(TestCase):
# generate data samples # generate data samples
pred = [ pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score( DataSample(num_classes=4).set_pred_score(i).set_gt_score(
j).set_gt_label(k) j).set_gt_label(k)
for i, j, k in zip(y_pred_score, y_true, y_true_label) for i, j, k in zip(y_pred_score, y_true, y_true_label)
] ]

View File

@ -9,7 +9,7 @@ from mmengine import ConfigDict
from mmpretrain.models import ImageClassifier from mmpretrain.models import ImageClassifier
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
def has_timm() -> bool: def has_timm() -> bool:
@ -134,7 +134,7 @@ class TestImageClassifier(TestCase):
def test_loss(self): def test_loss(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples) losses = model.loss(inputs, data_samples)
@ -142,21 +142,21 @@ class TestImageClassifier(TestCase):
def test_predict(self): def test_predict(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs) predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples) predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, )) self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
def test_forward(self): def test_forward(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward # test pure forward
@ -169,13 +169,13 @@ class TestImageClassifier(TestCase):
# test forward test # test forward test
predictions = model(inputs, mode='predict') predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict') predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, )) self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
# test forward with invalid mode # test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'): with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -190,7 +190,7 @@ class TestImageClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
optim_wrapper = MagicMock() optim_wrapper = MagicMock()
@ -207,11 +207,11 @@ class TestImageClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.val_step(data) predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self): def test_test_step(self):
cfg = { cfg = {
@ -222,11 +222,11 @@ class TestImageClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.test_step(data) predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
@unittest.skipIf(not has_timm(), 'timm is not installed.') @unittest.skipIf(not has_timm(), 'timm is not installed.')
@ -255,7 +255,7 @@ class TestTimmClassifier(TestCase):
def test_loss(self): def test_loss(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples) losses = model.loss(inputs, data_samples)
@ -263,21 +263,21 @@ class TestTimmClassifier(TestCase):
def test_predict(self): def test_predict(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs) predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples) predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, )) self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
def test_forward(self): def test_forward(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward # test pure forward
@ -290,13 +290,13 @@ class TestTimmClassifier(TestCase):
# test forward test # test forward test
predictions = model(inputs, mode='predict') predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict') predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, )) self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
# test forward with invalid mode # test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'): with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -311,7 +311,7 @@ class TestTimmClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
optim_wrapper = MagicMock() optim_wrapper = MagicMock()
@ -328,11 +328,11 @@ class TestTimmClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.val_step(data) predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self): def test_test_step(self):
cfg = { cfg = {
@ -343,11 +343,11 @@ class TestTimmClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.test_step(data) predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
@unittest.skipIf(not has_huggingface(), 'huggingface is not installed.') @unittest.skipIf(not has_huggingface(), 'huggingface is not installed.')
@ -376,7 +376,7 @@ class TestHuggingFaceClassifier(TestCase):
def test_loss(self): def test_loss(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples) losses = model.loss(inputs, data_samples)
@ -384,21 +384,21 @@ class TestHuggingFaceClassifier(TestCase):
def test_predict(self): def test_predict(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs) predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model.predict(inputs, data_samples) predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, )) self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
def test_forward(self): def test_forward(self):
inputs = torch.rand(1, 3, 224, 224) inputs = torch.rand(1, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS) model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
# test pure forward # test pure forward
@ -411,13 +411,13 @@ class TestHuggingFaceClassifier(TestCase):
# test forward test # test forward test
predictions = model(inputs, mode='predict') predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
predictions = model(inputs, data_samples, mode='predict') predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, )) self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
# test forward with invalid mode # test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'): with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -432,7 +432,7 @@ class TestHuggingFaceClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
optim_wrapper = MagicMock() optim_wrapper = MagicMock()
@ -449,11 +449,11 @@ class TestHuggingFaceClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.val_step(data) predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))
def test_test_step(self): def test_test_step(self):
cfg = { cfg = {
@ -464,8 +464,8 @@ class TestHuggingFaceClassifier(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)), 'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.test_step(data) predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (1000, )) self.assertEqual(predictions[0].pred_score.shape, (1000, ))

View File

@ -10,7 +10,7 @@ import torch
from mmengine import is_seq_of from mmengine import is_seq_of
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample from mmpretrain.structures import DataSample, MultiTaskDataSample
def setup_seed(seed): def setup_seed(seed):
@ -43,7 +43,7 @@ class TestClsHead(TestCase):
def test_loss(self): def test_loss(self):
feats = self.FAKE_FEATS feats = self.FAKE_FEATS
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False # with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
@ -75,23 +75,23 @@ class TestClsHead(TestCase):
def test_predict(self): def test_predict(self):
feats = (torch.rand(4, 10), ) feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples # without data_samples
predictions = head.predict(feats) predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions: for pred in predictions:
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
# with data_samples # with data_samples
predictions = head.predict(feats, data_samples) predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions): for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred) self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
class TestLinearClsHead(TestCase): class TestLinearClsHead(TestCase):
@ -224,7 +224,7 @@ class TestConformerHead(TestCase):
self.assertEqual(outs[1].shape, (4, 5)) self.assertEqual(outs[1].shape, (4, 5))
def test_loss(self): def test_loss(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False # with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
@ -255,23 +255,23 @@ class TestConformerHead(TestCase):
head.loss(self.fake_feats, data_samples) head.loss(self.fake_feats, data_samples)
def test_predict(self): def test_predict(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples # without data_samples
predictions = head.predict(self.fake_feats) predictions = head.predict(self.fake_feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions: for pred in predictions:
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
# with data_samples # with data_samples
predictions = head.predict(self.fake_feats, data_samples) predictions = head.predict(self.fake_feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions): for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred) self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
class TestStackedLinearClsHead(TestCase): class TestStackedLinearClsHead(TestCase):
@ -338,7 +338,7 @@ class TestMultiLabelClsHead(TestCase):
def test_loss(self): def test_loss(self):
feats = (torch.rand(4, 10), ) feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)] data_samples = [DataSample().set_gt_label([0, 3]) for _ in range(4)]
# Test with thr and topk are all None # Test with thr and topk are all None
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
@ -383,7 +383,7 @@ class TestMultiLabelClsHead(TestCase):
# Test with gt_label with score # Test with gt_label with score
data_samples = [ data_samples = [
ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4) DataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
] ]
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
@ -395,23 +395,23 @@ class TestMultiLabelClsHead(TestCase):
def test_predict(self): def test_predict(self):
feats = (torch.rand(4, 10), ) feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)] data_samples = [DataSample().set_gt_label([1, 2]) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples # without data_samples
predictions = head.predict(feats) predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for pred in predictions: for pred in predictions:
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
# with data_samples # with data_samples
predictions = head.predict(feats, data_samples) predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions): for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred) self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
# Test with topk # Test with topk
cfg = copy.deepcopy(self.DEFAULT_ARGS) cfg = copy.deepcopy(self.DEFAULT_ARGS)
@ -419,11 +419,11 @@ class TestMultiLabelClsHead(TestCase):
head = MODELS.build(cfg) head = MODELS.build(cfg)
predictions = head.predict(feats, data_samples) predictions = head.predict(feats, data_samples)
self.assertEqual(head.thr, None) self.assertEqual(head.thr, None)
self.assertTrue(is_seq_of(predictions, ClsDataSample)) self.assertTrue(is_seq_of(predictions, DataSample))
for sample, pred in zip(data_samples, predictions): for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred) self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label) self.assertIn('pred_label', pred)
self.assertIn('score', pred.pred_label) self.assertIn('pred_score', pred)
class EfficientFormerClsHead(TestClsHead): class EfficientFormerClsHead(TestClsHead):
@ -454,7 +454,7 @@ class EfficientFormerClsHead(TestClsHead):
def test_loss(self): def test_loss(self):
feats = (torch.rand(4, 10), ) feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# test with distillation head # test with distillation head
cfg = copy.deepcopy(self.DEFAULT_ARGS) cfg = copy.deepcopy(self.DEFAULT_ARGS)
@ -525,7 +525,7 @@ class TestMultiTaskHead(TestCase):
for _ in range(4): for _ in range(4):
data_sample = MultiTaskDataSample() data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']: for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1) task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name) data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample) data_samples.append(data_sample)
# with cal_acc = False # with cal_acc = False
@ -545,7 +545,7 @@ class TestMultiTaskHead(TestCase):
for _ in range(4): for _ in range(4):
data_sample = MultiTaskDataSample() data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']: for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1) task_sample = DataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name) data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample) data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
@ -555,7 +555,7 @@ class TestMultiTaskHead(TestCase):
for pred in predictions: for pred in predictions:
self.assertIn('task0', pred) self.assertIn('task0', pred)
task0_sample = predictions[0].task0 task0_sample = predictions[0].task0
self.assertIsInstance(task0_sample.pred_label.score, torch.Tensor) self.assertIsInstance(task0_sample.pred_score, torch.Tensor)
# with data_samples # with data_samples
predictions = head.predict(feats, data_samples) predictions = head.predict(feats, data_samples)
@ -596,7 +596,7 @@ class TestMultiTaskHead(TestCase):
head = MODELS.build(self.DEFAULT_ARGS2) head = MODELS.build(self.DEFAULT_ARGS2)
data_sample = MultiTaskDataSample() data_sample = MultiTaskDataSample()
for task_name in gt_label: for task_name in gt_label:
task_sample = ClsDataSample().set_gt_label(gt_label[task_name]) task_sample = DataSample().set_gt_label(gt_label[task_name])
data_sample.set_field(task_sample, task_name) data_sample.set_field(task_sample, task_name)
with self.assertRaises(Exception): with self.assertRaises(Exception):
head.loss(feats, data_sample) head.loss(feats, data_sample)
@ -606,11 +606,11 @@ class TestMultiTaskHead(TestCase):
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1} gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
data_sample = MultiTaskDataSample() data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1']) task_sample = DataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1') data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0') data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']: for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label( task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name]) gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name) data_sample.task0.set_field(task_sample, task_name)
with self.assertRaises(Exception): with self.assertRaises(Exception):
@ -694,7 +694,7 @@ class TestArcFaceClsHead(TestCase):
def test_loss(self): def test_loss(self):
feats = (torch.rand(4, 10), ) feats = (torch.rand(4, 10), )
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)] data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
# test loss with used='before' # test loss with used='before'
head = MODELS.build(self.DEFAULT_ARGS) head = MODELS.build(self.DEFAULT_ARGS)
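A minimal usage sketch of the flattened prediction fields these tests assert on, assuming a LinearClsHead imported directly from mmpretrain.models (the class and channel counts below are illustrative, not taken from the test configs):

import torch
from mmpretrain.models import LinearClsHead
from mmpretrain.structures import DataSample

head = LinearClsHead(num_classes=10, in_channels=16)
feats = (torch.rand(4, 16), )
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]

preds = head.predict(feats, data_samples)
# Scores and labels now sit directly on the sample, replacing the old
# nested pred_label.score / pred_label.label access.
assert 'pred_score' in preds[0] and 'pred_label' in preds[0]
print(preds[0].pred_score.shape)  # torch.Size([10])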

View File

@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset
from mmpretrain.datasets.transforms import PackClsInputs from mmpretrain.datasets.transforms import PackClsInputs
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class ExampleDataset(Dataset): class ExampleDataset(Dataset):
@ -125,7 +125,7 @@ class TestImageToImageRetriever(TestCase):
def test_loss(self): def test_loss(self):
inputs = torch.rand(1, 3, 64, 64) inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model = MODELS.build(self.DEFAULT_ARGS) model = MODELS.build(self.DEFAULT_ARGS)
losses = model.loss(inputs, data_samples) losses = model.loss(inputs, data_samples)
@ -172,32 +172,32 @@ class TestImageToImageRetriever(TestCase):
def test_predict(self): def test_predict(self):
inputs = torch.rand(1, 3, 64, 64) inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label([1, 2, 6])] data_samples = [DataSample().set_gt_label([1, 2, 6])]
# default # default
model = MODELS.build(self.DEFAULT_ARGS) model = MODELS.build(self.DEFAULT_ARGS)
predictions = model.predict(inputs) predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples) predictions = model.predict(inputs, data_samples)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, )) self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
# k is not -1 # k is not -1
cfg = {**self.DEFAULT_ARGS, 'topk': 2} cfg = {**self.DEFAULT_ARGS, 'topk': 2}
model = MODELS.build(cfg) model = MODELS.build(cfg)
predictions = model.predict(inputs) predictions = model.predict(inputs)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model.predict(inputs, data_samples) predictions = model.predict(inputs, data_samples)
assert predictions is data_samples assert predictions is data_samples
self.assertEqual(data_samples[0].pred_label.score.shape, (10, )) self.assertEqual(data_samples[0].pred_score.shape, (10, ))
def test_forward(self): def test_forward(self):
inputs = torch.rand(1, 3, 64, 64) inputs = torch.rand(1, 3, 64, 64)
data_samples = [ClsDataSample().set_gt_label(1)] data_samples = [DataSample().set_gt_label(1)]
model = MODELS.build(self.DEFAULT_ARGS) model = MODELS.build(self.DEFAULT_ARGS)
# test pure forward # test pure forward
@ -213,13 +213,13 @@ class TestImageToImageRetriever(TestCase):
# test forward test # test forward test
predictions = model(inputs, mode='predict') predictions = model(inputs, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
predictions = model(inputs, data_samples, mode='predict') predictions = model(inputs, data_samples, mode='predict')
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
self.assertEqual(data_samples[0].pred_label.score.shape, (10, )) self.assertEqual(data_samples[0].pred_score.shape, (10, ))
torch.testing.assert_allclose(data_samples[0].pred_label.score, torch.testing.assert_allclose(data_samples[0].pred_score,
predictions[0].pred_label.score) predictions[0].pred_score)
# test forward with invalid mode # test forward with invalid mode
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'): with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
@ -234,7 +234,7 @@ class TestImageToImageRetriever(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)), 'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
optim_wrapper = MagicMock() optim_wrapper = MagicMock()
@ -251,11 +251,11 @@ class TestImageToImageRetriever(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)), 'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.val_step(data) predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
def test_test_step(self): def test_test_step(self):
cfg = { cfg = {
@ -266,8 +266,8 @@ class TestImageToImageRetriever(TestCase):
data = { data = {
'inputs': torch.randint(0, 256, (1, 3, 64, 64)), 'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
predictions = model.test_step(data) predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, )) self.assertEqual(predictions[0].pred_score.shape, (10, ))
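Downstream code that previously unpacked pred_label.score can read the tensor directly. A small sketch with a hypothetical helper (topk_from_sample is not an mmpretrain API, only an illustration):

import torch
from mmpretrain.structures import DataSample

def topk_from_sample(sample: DataSample, k: int = 5):
    # pred_score is now a plain 1-D tensor stored on the sample itself.
    scores = sample.pred_score
    return torch.topk(scores, k=min(k, scores.numel()))

sample = DataSample().set_pred_score(torch.tensor([0.1, 0.7, 0.2]))
values, indices = topk_from_sample(sample, k=2)
print(indices.tolist())  # [1, 2]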

View File

@ -8,7 +8,7 @@ from mmengine.registry import init_default_scope
from mmpretrain.models import AverageClsScoreTTA, ImageClassifier from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
init_default_scope('mmpretrain') init_default_scope('mmpretrain')
@ -48,20 +48,20 @@ class TestAverageClsScoreTTA(TestCase):
img2 = torch.randint(0, 256, (1, 3, 224, 224)) img2 = torch.randint(0, 256, (1, 3, 224, 224))
data1 = { data1 = {
'inputs': img1, 'inputs': img1,
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
data2 = { data2 = {
'inputs': img2, 'inputs': img2,
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
data_tta = { data_tta = {
'inputs': [img1, img2], 'inputs': [img1, img2],
'data_samples': [[ClsDataSample().set_gt_label(1)], 'data_samples': [[DataSample().set_gt_label(1)],
[ClsDataSample().set_gt_label(1)]] [DataSample().set_gt_label(1)]]
} }
score1 = model.module.test_step(data1)[0].pred_label.score score1 = model.module.test_step(data1)[0].pred_score
score2 = model.module.test_step(data2)[0].pred_label.score score2 = model.module.test_step(data2)[0].pred_score
score_tta = model.test_step(data_tta)[0].pred_label.score score_tta = model.test_step(data_tta)[0].pred_score
torch.testing.assert_allclose(score_tta, (score1 + score2) / 2) torch.testing.assert_allclose(score_tta, (score1 + score2) / 2)

View File

@ -5,7 +5,7 @@ import torch
from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment
from mmpretrain.registry import MODELS from mmpretrain.registry import MODELS
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
class TestClsDataPreprocessor(TestCase): class TestClsDataPreprocessor(TestCase):
@ -16,15 +16,14 @@ class TestClsDataPreprocessor(TestCase):
data = { data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))], 'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
processed_data = processor(data) processed_data = processor(data)
inputs = processed_data['inputs'] inputs = processed_data['inputs']
data_samples = processed_data['data_samples'] data_samples = processed_data['data_samples']
self.assertEqual(inputs.shape, (1, 3, 224, 224)) self.assertEqual(inputs.shape, (1, 3, 224, 224))
self.assertEqual(len(data_samples), 1) self.assertEqual(len(data_samples), 1)
self.assertTrue( self.assertTrue((data_samples[0].gt_label == torch.tensor([1])).all())
(data_samples[0].gt_label.label == torch.tensor([1])).all())
def test_padding(self): def test_padding(self):
cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16) cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16)
@ -87,7 +86,7 @@ class TestClsDataPreprocessor(TestCase):
self.assertIsInstance(processor.batch_augments, RandomBatchAugment) self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
data = { data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))], 'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)] 'data_samples': [DataSample().set_gt_label(1)]
} }
processed_data = processor(data, training=True) processed_data = processor(data, training=True)
self.assertIn('inputs', processed_data) self.assertIn('inputs', processed_data)
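For orientation, a minimal sketch of the preprocessor call exercised above, using the default constructor arguments (no normalization configured); gt_label now comes back as a plain tensor rather than a nested LabelData:

import torch
from mmpretrain.models import ClsDataPreprocessor
from mmpretrain.structures import DataSample

processor = ClsDataPreprocessor()
data = {
    'inputs': [torch.randint(0, 256, (3, 224, 224))],
    'data_samples': [DataSample().set_gt_label(1)],
}
out = processor(data)
print(out['inputs'].shape)              # torch.Size([1, 3, 224, 224])
print(out['data_samples'][0].gt_label)  # tensor([1])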

View File

@ -3,58 +3,51 @@ from unittest import TestCase
import numpy as np import numpy as np
import torch import torch
from mmengine.structures import LabelData
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample from mmpretrain.structures import DataSample, MultiTaskDataSample
class TestClsDataSample(TestCase): class TestDataSample(TestCase):
def _test_set_label(self, key): def _test_set_label(self, key):
data_sample = ClsDataSample() data_sample = DataSample()
method = getattr(data_sample, 'set_' + key) method = getattr(data_sample, 'set_' + key)
# Test number # Test number
method(1) method(1)
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor with single number # Test tensor with single number
method(torch.tensor(2)) method(torch.tensor(2))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test array with single number # Test array with single number
method(np.array(3)) method(np.array(3))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, torch.LongTensor)
self.assertIsInstance(label.label, torch.LongTensor)
# Test tensor # Test tensor
method(torch.tensor([1, 2, 3])) method(torch.tensor([1, 2, 3]))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertIsInstance(label, torch.Tensor)
self.assertIsInstance(label.label, torch.Tensor) self.assertTrue((label == torch.tensor([1, 2, 3])).all())
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test array # Test array
method(np.array([1, 2, 3])) method(np.array([1, 2, 3]))
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertTrue((label == torch.tensor([1, 2, 3])).all())
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test Sequence # Test Sequence
method([1, 2, 3]) method([1, 2, 3])
self.assertIn(key, data_sample) self.assertIn(key, data_sample)
label = getattr(data_sample, key) label = getattr(data_sample, key)
self.assertIsInstance(label, LabelData) self.assertTrue((label == torch.tensor([1, 2, 3])).all())
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
# Test unavailable type # Test unavailable type
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"): with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
@ -66,34 +59,13 @@ class TestClsDataSample(TestCase):
def test_set_pred_label(self): def test_set_pred_label(self):
self._test_set_label('pred_label') self._test_set_label('pred_label')
def test_del_gt_label(self):
data_sample = ClsDataSample()
self.assertNotIn('gt_label', data_sample)
data_sample.set_gt_label(1)
self.assertIn('gt_label', data_sample)
del data_sample.gt_label
self.assertNotIn('gt_label', data_sample)
def test_del_pred_label(self):
data_sample = ClsDataSample()
self.assertNotIn('pred_label', data_sample)
data_sample.set_pred_label(1)
self.assertIn('pred_label', data_sample)
del data_sample.pred_label
self.assertNotIn('pred_label', data_sample)
def test_set_gt_score(self): def test_set_gt_score(self):
data_sample = ClsDataSample() data_sample = DataSample()
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.gt_label) self.assertIn('gt_score', data_sample)
torch.testing.assert_allclose(data_sample.gt_label.score, torch.testing.assert_allclose(data_sample.gt_score,
[0.1, 0.1, 0.6, 0.1, 0.1]) [0.1, 0.1, 0.6, 0.1, 0.1])
# Test set again
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.gt_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid length # Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'): with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2]) data_sample.set_gt_score([1, 2])
@ -103,17 +75,12 @@ class TestClsDataSample(TestCase):
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]])) data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
def test_set_pred_score(self): def test_set_pred_score(self):
data_sample = ClsDataSample() data_sample = DataSample()
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])) data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
self.assertIn('score', data_sample.pred_label) self.assertIn('pred_score', data_sample)
torch.testing.assert_allclose(data_sample.pred_label.score, torch.testing.assert_allclose(data_sample.pred_score,
[0.1, 0.1, 0.6, 0.1, 0.1]) [0.1, 0.1, 0.6, 0.1, 0.1])
# Test set again
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
torch.testing.assert_allclose(data_sample.pred_label.score,
[0.2, 0.1, 0.5, 0.1, 0.1])
# Test invalid length # Test invalid length
with self.assertRaisesRegex(AssertionError, 'should be equal to'): with self.assertRaisesRegex(AssertionError, 'should be equal to'):
data_sample.set_gt_score([1, 2]) data_sample.set_gt_score([1, 2])
@ -129,13 +96,13 @@ class TestMultiTaskDataSample(TestCase):
def test_multi_task_data_sample(self): def test_multi_task_data_sample(self):
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1} gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
data_sample = MultiTaskDataSample() data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1']) task_sample = DataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1') data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0') data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']: for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label( task_sample = DataSample().set_gt_label(
gt_label['task0'][task_name]) gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name) data_sample.task0.set_field(task_sample, task_name)
self.assertIsInstance(data_sample.task0, MultiTaskDataSample) self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
self.assertIsInstance(data_sample.task1, ClsDataSample) self.assertIsInstance(data_sample.task1, DataSample)
self.assertIsInstance(data_sample.task0.task00, ClsDataSample) self.assertIsInstance(data_sample.task0.task00, DataSample)
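A short sketch of the flattened DataSample fields and the nested multi-task layout tested above (all values are illustrative):

import torch
from mmpretrain.structures import DataSample, MultiTaskDataSample

sample = DataSample()
sample.set_gt_label(1)                                # stored directly as a LongTensor field
sample.set_pred_score(torch.tensor([0.1, 0.8, 0.1]))  # stored directly as 'pred_score'
print(sample.gt_label)    # tensor([1])
print(sample.pred_score)  # tensor([0.1000, 0.8000, 0.1000])

# MultiTaskDataSample nests one DataSample per task name via set_field.
multi = MultiTaskDataSample()
multi.set_field(DataSample().set_gt_label(0), 'task0')
multi.set_field(DataSample().set_gt_label(2), 'task1')
print(multi.task0.gt_label)  # tensor([0])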

View File

@ -2,10 +2,9 @@
from unittest import TestCase from unittest import TestCase
import torch import torch
from mmengine.structures import LabelData
from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels, from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split) tensor_split)
class TestStructureUtils(TestCase): class TestStructureUtils(TestCase):
@ -28,11 +27,11 @@ class TestStructureUtils(TestCase):
def test_cat_batch_labels(self): def test_cat_batch_labels(self):
labels = [ labels = [
LabelData(label=torch.tensor([1])), torch.tensor([1]),
LabelData(label=torch.tensor([3, 2])), torch.tensor([3, 2]),
LabelData(label=torch.tensor([0, 1, 4])), torch.tensor([0, 1, 4]),
LabelData(label=torch.tensor([], dtype=torch.int64)), torch.tensor([], dtype=torch.int64),
LabelData(label=torch.tensor([], dtype=torch.int64)), torch.tensor([], dtype=torch.int64),
] ]
batch_label, split_indices = cat_batch_labels(labels) batch_label, split_indices = cat_batch_labels(labels)
@ -45,42 +44,13 @@ class TestStructureUtils(TestCase):
self.assertEqual(labels[3].tolist(), []) self.assertEqual(labels[3].tolist(), [])
self.assertEqual(labels[4].tolist(), []) self.assertEqual(labels[4].tolist(), [])
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_label, split_indices = cat_batch_labels(labels)
self.assertIsNone(batch_label)
self.assertIsNone(split_indices)
def test_stack_batch_scores(self):
labels = [
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
]
batch_score = stack_batch_scores(labels)
self.assertEqual(batch_score.shape, (3, 5))
labels = [
LabelData(label=torch.tensor([1])),
LabelData(label=torch.tensor([3, 2])),
LabelData(label=torch.tensor([0, 1, 4])),
LabelData(label=torch.tensor([], dtype=torch.int64)),
LabelData(label=torch.tensor([], dtype=torch.int64)),
]
batch_score = stack_batch_scores(labels)
self.assertIsNone(batch_score)
def test_batch_label_to_onehot(self): def test_batch_label_to_onehot(self):
labels = [ labels = [
LabelData(label=torch.tensor([1])), torch.tensor([1]),
LabelData(label=torch.tensor([3, 2])), torch.tensor([3, 2]),
LabelData(label=torch.tensor([0, 1, 4])), torch.tensor([0, 1, 4]),
LabelData(label=torch.tensor([], dtype=torch.int64)), torch.tensor([], dtype=torch.int64),
LabelData(label=torch.tensor([], dtype=torch.int64)), torch.tensor([], dtype=torch.int64),
] ]
batch_label, split_indices = cat_batch_labels(labels) batch_label, split_indices = cat_batch_labels(labels)
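A minimal sketch of the plain-tensor batch-label helpers used here; num_classes=5 is illustrative, and the tensor_split / batch_label_to_onehot call signatures are assumptions inferred from their usage in this test file:

import torch
from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels,
                                   tensor_split)

labels = [torch.tensor([1]), torch.tensor([3, 2]), torch.tensor([0, 1, 4])]
batch_label, split_indices = cat_batch_labels(labels)
print(batch_label.tolist())  # [1, 3, 2, 0, 1, 4]

# Recover the per-sample labels; tensor_split behaves like torch.tensor_split.
restored = tensor_split(batch_label, split_indices)
print([x.tolist() for x in restored])  # [[1], [3, 2], [0, 1, 4]]

# Multi-hot encoding per sample, one row per original label tensor.
onehot = batch_label_to_onehot(batch_label, split_indices, 5)
print(onehot.shape)  # torch.Size([3, 5])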

View File

@ -7,7 +7,7 @@ from unittest.mock import patch
import numpy as np import numpy as np
import torch import torch
from mmpretrain.structures import ClsDataSample from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer from mmpretrain.visualization import ClsVisualizer
@ -24,7 +24,7 @@ class TestClsVisualizer(TestCase):
def test_add_datasample(self): def test_add_datasample(self):
image = np.ones((10, 10, 3), np.uint8) image = np.ones((10, 10, 3), np.uint8)
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\ data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
set_pred_score(torch.tensor([0.1, 0.8, 0.1])) set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# Test show # Test show
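As a reference, a minimal sketch of the visualizer call these tests drive, mirroring the arguments used later in this test (image contents and scores are illustrative; with the defaults nothing is shown or written to disk):

import numpy as np
import torch
from mmpretrain.structures import DataSample
from mmpretrain.visualization import ClsVisualizer

vis = ClsVisualizer()
image = np.ones((10, 10, 3), np.uint8)
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
    set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
# draw_gt=False matches the usage in this test; the prediction text is
# rendered from the flat pred_label / pred_score fields.
vis.add_datasample('test', image=image, data_sample=data_sample, draw_gt=False)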
@ -82,7 +82,7 @@ class TestClsVisualizer(TestCase):
'test', image=image, data_sample=data_sample, draw_gt=False) 'test', image=image, data_sample=data_sample, draw_gt=False)
# Test without score # Test without score
del data_sample.pred_label.score del data_sample.pred_score
def test_texts(text, *_, **__): def test_texts(text, *_, **__):
self.assertEqual( self.assertEqual(