[Refactor] Refactor ClsDatasample to a union DataSample. (#1371)
* [Refactor] Refactor ClsDatasample to a union DataSample. * Add method * Fix docstring * Update docstring.pull/1380/head
parent
4016f1348e
commit
36bea13fca
|
@ -1,13 +1,13 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
.. module:: mmcls.structures
|
||||
.. module:: mmpretrain.structures
|
||||
|
||||
mmcls.structures
|
||||
mmpretrain.structures
|
||||
===================================
|
||||
|
||||
This package includes basic data structures for classification tasks.
|
||||
This package includes basic data structures.
|
||||
|
||||
ClsDataSample
|
||||
DataSample
|
||||
-------------
|
||||
.. autoclass:: ClsDataSample
|
||||
.. autoclass:: DataSample
|
||||
|
|
|
@ -12,7 +12,7 @@ from mmengine.model import BaseModel
|
|||
from mmengine.runner import load_checkpoint
|
||||
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .model import get_model, init_model, list_models
|
||||
|
||||
ModelType = Union[BaseModel, str, Config]
|
||||
|
@ -176,7 +176,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
|
||||
def visualize(self,
|
||||
ori_inputs: List[InputType],
|
||||
preds: List[ClsDataSample],
|
||||
preds: List[DataSample],
|
||||
show: bool = False,
|
||||
rescale_factor: Optional[float] = None,
|
||||
draw_score=True,
|
||||
|
@ -223,7 +223,7 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
return visualization
|
||||
|
||||
def postprocess(self,
|
||||
preds: List[ClsDataSample],
|
||||
preds: List[DataSample],
|
||||
visualization: List[np.ndarray],
|
||||
return_datasamples=False) -> dict:
|
||||
if return_datasamples:
|
||||
|
@ -231,14 +231,13 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
|
||||
results = []
|
||||
for data_sample in preds:
|
||||
prediction = data_sample.pred_label
|
||||
pred_scores = prediction.score.detach().cpu().numpy()
|
||||
pred_score = torch.max(prediction.score).item()
|
||||
pred_label = torch.argmax(prediction.score).item()
|
||||
pred_scores = data_sample.pred_score
|
||||
pred_score = float(torch.max(pred_scores).item())
|
||||
pred_label = torch.argmax(pred_scores).item()
|
||||
result = {
|
||||
'pred_scores': pred_scores,
|
||||
'pred_scores': pred_scores.detach().cpu().numpy(),
|
||||
'pred_label': pred_label,
|
||||
'pred_score': float(pred_score),
|
||||
'pred_score': pred_score,
|
||||
}
|
||||
if self.classes is not None:
|
||||
result['pred_class'] = self.classes[pred_label]
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmengine.utils import is_str
|
|||
from PIL import Image
|
||||
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmpretrain.structures import DataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
|
@ -53,7 +53,7 @@ class PackClsInputs(BaseTransform):
|
|||
**Added Keys:**
|
||||
|
||||
- inputs (:obj:`torch.Tensor`): The forward data of models.
|
||||
- data_samples (:obj:`~mmpretrain.structures.ClsDataSample`): The
|
||||
- data_samples (:obj:`~mmpretrain.structures.DataSample`): The
|
||||
annotation info of the sample.
|
||||
|
||||
Args:
|
||||
|
@ -87,10 +87,11 @@ class PackClsInputs(BaseTransform):
|
|||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
|
||||
data_sample = ClsDataSample()
|
||||
data_sample = DataSample()
|
||||
if 'gt_label' in results:
|
||||
gt_label = results['gt_label']
|
||||
data_sample.set_gt_label(gt_label)
|
||||
data_sample.set_gt_label(results['gt_label'])
|
||||
if 'gt_score' in results:
|
||||
data_sample.set_gt_score(results['gt_score'])
|
||||
|
||||
img_meta = {k: results[k] for k in self.meta_keys if k in results}
|
||||
data_sample.set_metainfo(img_meta)
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, Runner
|
|||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmpretrain.registry import HOOKS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -51,14 +51,14 @@ class VisualizationHook(Hook):
|
|||
def _draw_samples(self,
|
||||
batch_idx: int,
|
||||
data_batch: dict,
|
||||
data_samples: Sequence[ClsDataSample],
|
||||
data_samples: Sequence[DataSample],
|
||||
step: int = 0) -> None:
|
||||
"""Visualize every ``self.interval`` samples from a data batch.
|
||||
|
||||
Args:
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||
outputs (Sequence[:obj:`DataSample`]): Outputs from model.
|
||||
step (int): Global step value to record. Defaults to 0.
|
||||
"""
|
||||
if self.enable is False:
|
||||
|
@ -97,14 +97,14 @@ class VisualizationHook(Hook):
|
|||
)
|
||||
|
||||
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
outputs: Sequence[DataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during validation.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||
outputs (Sequence[:obj:`DataSample`]): Outputs from model.
|
||||
"""
|
||||
if isinstance(runner.train_loop, EpochBasedTrainLoop):
|
||||
step = runner.epoch
|
||||
|
@ -114,7 +114,7 @@ class VisualizationHook(Hook):
|
|||
self._draw_samples(batch_idx, data_batch, outputs, step=step)
|
||||
|
||||
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
outputs: Sequence[DataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during test.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -5,9 +5,9 @@ import numpy as np
|
|||
import torch
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.structures import label_to_onehot
|
||||
from .single_label import _precision_recall_f1_support, to_tensor
|
||||
|
||||
|
||||
|
@ -114,10 +114,10 @@ class MultiLabelMetric(BaseMetric):
|
|||
(tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_sampels = [
|
||||
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
|
||||
... DataSample().set_pred_score(pred).set_gt_score(gt)
|
||||
... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
|
||||
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
|
||||
>>> evaluator.process(data_sampels)
|
||||
|
@ -181,17 +181,15 @@ class MultiLabelMetric(BaseMetric):
|
|||
"""
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
result['pred_score'] = pred_label['score'].clone()
|
||||
result['pred_score'] = data_sample['pred_score'].clone()
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
||||
if 'score' in gt_label:
|
||||
result['gt_score'] = gt_label['score'].clone()
|
||||
if 'gt_score' in data_sample:
|
||||
result['gt_score'] = data_sample['gt_score'].clone()
|
||||
else:
|
||||
result['gt_score'] = LabelData.label_to_onehot(
|
||||
gt_label['label'], num_classes)
|
||||
result['gt_score'] = label_to_onehot(data_sample['gt_label'],
|
||||
num_classes)
|
||||
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
@ -331,8 +329,7 @@ class MultiLabelMetric(BaseMetric):
|
|||
assert num_classes is not None, 'For index-type labels, ' \
|
||||
'please specify `num_classes`.'
|
||||
label = torch.stack([
|
||||
LabelData.label_to_onehot(
|
||||
to_tensor(indices), num_classes)
|
||||
label_to_onehot(indices, num_classes)
|
||||
for indices in label
|
||||
])
|
||||
else:
|
||||
|
@ -479,10 +476,10 @@ class AveragePrecision(BaseMetric):
|
|||
>>> AveragePrecision.calculate(y_pred, y_true)
|
||||
tensor(70.833)
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_pred_score(i).set_gt_score(j)
|
||||
... DataSample().set_pred_score(i).set_gt_score(j)
|
||||
... for i, j in zip(y_pred, y_true)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=AveragePrecision())
|
||||
|
@ -517,17 +514,15 @@ class AveragePrecision(BaseMetric):
|
|||
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
result['pred_score'] = pred_label['score']
|
||||
result['pred_score'] = data_sample['pred_score'].clone()
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
||||
if 'score' in gt_label:
|
||||
result['gt_score'] = gt_label['score']
|
||||
if 'gt_score' in data_sample:
|
||||
result['gt_score'] = data_sample['gt_score'].clone()
|
||||
else:
|
||||
result['gt_score'] = LabelData.label_to_onehot(
|
||||
gt_label['label'], num_classes)
|
||||
result['gt_score'] = label_to_onehot(data_sample['gt_label'],
|
||||
num_classes)
|
||||
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
|
|
@ -5,10 +5,10 @@ import mmengine
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.structures import LabelData
|
||||
from mmengine.utils import is_seq_of
|
||||
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.structures import label_to_onehot
|
||||
from .single_label import to_tensor
|
||||
|
||||
|
||||
|
@ -48,10 +48,10 @@ class RetrievalRecall(BaseMetric):
|
|||
[tensor(9.3000), tensor(48.4000)]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(
|
||||
... DataSample().set_gt_label([0, 1]).set_pred_score(
|
||||
... torch.rand(10))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
|
@ -95,23 +95,21 @@ class RetrievalRecall(BaseMetric):
|
|||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_label']
|
||||
pred_score = data_sample['pred_score'].clone()
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
pred = pred_label['score'].clone()
|
||||
if 'score' in gt_label:
|
||||
target = gt_label['score'].clone()
|
||||
if 'gt_score' in data_sample:
|
||||
target = data_sample.get('gt_score').clone()
|
||||
else:
|
||||
num_classes = pred_label['score'].size()[-1]
|
||||
target = LabelData.label_to_onehot(gt_label['label'],
|
||||
num_classes)
|
||||
num_classes = pred_score.size()[-1]
|
||||
target = label_to_onehot(gt_label, num_classes)
|
||||
|
||||
# Because the retrieval output logit vector will be much larger
|
||||
# compared to the normal classification, to save resources, the
|
||||
# evaluation results are computed each batch here and then reduce
|
||||
# all results at the end.
|
||||
result = RetrievalRecall.calculate(
|
||||
pred.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
|
||||
pred_score.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
|
@ -230,5 +228,5 @@ def _format_target(label, is_indices=False):
|
|||
raise TypeError(f'The pred must be type of torch.tensor, '
|
||||
f'np.ndarray or Sequence but get {type(label)}.')
|
||||
|
||||
indices = [LabelData.onehot_to_label(sample_gt) for sample_gt in label]
|
||||
indices = [sample_gt.nonzero().squeeze(-1) for sample_gt in label]
|
||||
return indices
|
||||
|
|
|
@ -104,10 +104,10 @@ class Accuracy(BaseMetric):
|
|||
[[tensor([9.9000])], [tensor([51.5000])]]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
|
||||
... DataSample().set_gt_label(0).set_pred_score(torch.rand(10))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
|
||||
|
@ -150,13 +150,11 @@ class Accuracy(BaseMetric):
|
|||
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
if 'pred_score' in data_sample:
|
||||
result['pred_score'] = data_sample['pred_score'].cpu()
|
||||
else:
|
||||
result['pred_label'] = pred_label['label'].cpu()
|
||||
result['gt_label'] = gt_label['label'].cpu()
|
||||
result['pred_label'] = data_sample['pred_label'].cpu()
|
||||
result['gt_label'] = data_sample['gt_label'].cpu()
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
|
@ -358,10 +356,10 @@ class SingleLabelMetric(BaseMetric):
|
|||
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
|
||||
... DataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
|
||||
|
@ -418,19 +416,16 @@ class SingleLabelMetric(BaseMetric):
|
|||
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
if 'pred_score' in data_sample:
|
||||
result['pred_score'] = data_sample['pred_score'].cpu()
|
||||
else:
|
||||
num_classes = self.num_classes or data_sample.get(
|
||||
'num_classes')
|
||||
assert num_classes is not None, \
|
||||
'The `num_classes` must be specified if `pred_label` has '\
|
||||
'only `label`.'
|
||||
result['pred_label'] = pred_label['label'].cpu()
|
||||
'The `num_classes` must be specified if no `pred_score`.'
|
||||
result['pred_label'] = data_sample['pred_label'].cpu()
|
||||
result['num_classes'] = num_classes
|
||||
result['gt_label'] = gt_label['label'].cpu()
|
||||
result['gt_label'] = data_sample['gt_label'].cpu()
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
|
@ -641,17 +636,16 @@ class ConfusionMatrix(BaseMetric):
|
|||
|
||||
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
|
||||
for data_sample in data_samples:
|
||||
pred = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']['label']
|
||||
if 'score' in pred:
|
||||
pred_label = pred['score'].argmax(dim=0, keepdim=True)
|
||||
self.num_classes = pred['score'].size(0)
|
||||
if 'pred_score' in data_sample:
|
||||
pred_score = data_sample['pred_score']
|
||||
pred_label = pred_score.argmax(dim=0, keepdim=True)
|
||||
self.num_classes = pred_score.size(0)
|
||||
else:
|
||||
pred_label = pred['label']
|
||||
pred_label = data_sample['pred_label']
|
||||
|
||||
self.results.append({
|
||||
'pred_label': pred_label,
|
||||
'gt_label': gt_label
|
||||
'gt_label': data_sample['gt_label'],
|
||||
})
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.structures import label_to_onehot
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
|
||||
|
||||
|
@ -39,18 +38,16 @@ class VOCMetricMixin:
|
|||
"""
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
gt_label_difficult = data_sample['gt_label_difficult']
|
||||
|
||||
result['pred_score'] = pred_label['score'].clone()
|
||||
result['pred_score'] = data_sample['pred_score'].clone()
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
||||
if 'score' in gt_label:
|
||||
result['gt_score'] = gt_label['score'].clone()
|
||||
if 'gt_score' in data_sample:
|
||||
result['gt_score'] = data_sample['gt_score'].clone()
|
||||
else:
|
||||
result['gt_score'] = LabelData.label_to_onehot(
|
||||
gt_label['label'], num_classes)
|
||||
result['gt_score'] = label_to_onehot(gt_label, num_classes)
|
||||
|
||||
# VOC annotation labels all the objects in a single image
|
||||
# therefore, some categories are appeared both in
|
||||
|
@ -58,7 +55,7 @@ class VOCMetricMixin:
|
|||
# Here we reckon those labels which are only exists in difficult
|
||||
# objects as difficult labels.
|
||||
difficult_label = set(gt_label_difficult) - (
|
||||
set(gt_label_difficult) & set(gt_label['label'].tolist()))
|
||||
set(gt_label_difficult) & set(gt_label.tolist()))
|
||||
|
||||
# set difficult label for better eval
|
||||
if self.difficult_as_positive is None:
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .base import BaseClassifier
|
||||
|
||||
|
||||
|
@ -122,14 +122,14 @@ class HuggingFaceClassifier(BaseClassifier):
|
|||
raise NotImplementedError(
|
||||
"The HuggingFaceClassifier doesn't support extract feature yet.")
|
||||
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
|
||||
**kwargs):
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments of the loss module.
|
||||
|
||||
|
@ -144,14 +144,14 @@ class HuggingFaceClassifier(BaseClassifier):
|
|||
return losses
|
||||
|
||||
def _get_loss(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample], **kwargs):
|
||||
data_samples: List[DataSample], **kwargs):
|
||||
"""Unpack data samples and compute loss."""
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
if 'gt_score' in data_samples[0]:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
target = torch.stack([i.gt_label.score for i in data_samples])
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
target = torch.cat([i.gt_label.label for i in data_samples])
|
||||
target = torch.cat([i.gt_label for i in data_samples])
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
|
@ -163,17 +163,17 @@ class HuggingFaceClassifier(BaseClassifier):
|
|||
|
||||
def predict(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None):
|
||||
data_samples: Optional[List[DataSample]] = None):
|
||||
"""Predict results from a batch of inputs.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: The prediction results.
|
||||
List[DataSample]: The prediction results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
cls_score = self.model(inputs).logits
|
||||
|
@ -197,8 +197,8 @@ class HuggingFaceClassifier(BaseClassifier):
|
|||
else:
|
||||
data_samples = []
|
||||
for score, label in zip(pred_scores, pred_labels):
|
||||
data_samples.append(ClsDataSample().set_pred_score(
|
||||
score).set_pred_label(label))
|
||||
data_samples.append(
|
||||
DataSample().set_pred_score(score).set_pred_label(label))
|
||||
|
||||
return data_samples
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .base import BaseClassifier
|
||||
|
||||
|
||||
|
@ -76,7 +76,7 @@ class ImageClassifier(BaseClassifier):
|
|||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
|
@ -85,7 +85,7 @@ class ImageClassifier(BaseClassifier):
|
|||
- "tensor": Forward the whole network and return tensor or tuple of
|
||||
tensor without any post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`ClsDataSample`.
|
||||
processed to a list of :obj:`DataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
|
@ -95,7 +95,7 @@ class ImageClassifier(BaseClassifier):
|
|||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. It's required if ``mode="loss"``.
|
||||
Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
@ -105,7 +105,7 @@ class ImageClassifier(BaseClassifier):
|
|||
|
||||
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
|
||||
- If ``mode="predict"``, return a list of
|
||||
:obj:`mmpretrain.structures.ClsDataSample`.
|
||||
:obj:`mmpretrain.structures.DataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'tensor':
|
||||
|
@ -209,13 +209,13 @@ class ImageClassifier(BaseClassifier):
|
|||
return self.head.pre_logits(x)
|
||||
|
||||
def loss(self, inputs: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]) -> dict:
|
||||
data_samples: List[DataSample]) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
|
||||
Returns:
|
||||
|
@ -226,14 +226,14 @@ class ImageClassifier(BaseClassifier):
|
|||
|
||||
def predict(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
**kwargs) -> List[ClsDataSample]:
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
**kwargs) -> List[DataSample]:
|
||||
"""Predict results from a batch of inputs.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
**kwargs: Other keyword arguments accepted by the ``predict``
|
||||
method of :attr:`head`.
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .base import BaseClassifier
|
||||
|
||||
|
||||
|
@ -110,14 +110,14 @@ class TimmClassifier(BaseClassifier):
|
|||
f"The model {type(self.model)} doesn't support extract "
|
||||
"feature because it don't have `forward_features` method.")
|
||||
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[ClsDataSample],
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
|
||||
**kwargs):
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments of the loss module.
|
||||
|
||||
|
@ -132,14 +132,14 @@ class TimmClassifier(BaseClassifier):
|
|||
return losses
|
||||
|
||||
def _get_loss(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample], **kwargs):
|
||||
data_samples: List[DataSample], **kwargs):
|
||||
"""Unpack data samples and compute loss."""
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
if 'gt_score' in data_samples[0]:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
target = torch.stack([i.gt_label.score for i in data_samples])
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
target = torch.cat([i.gt_label.label for i in data_samples])
|
||||
target = torch.cat([i.gt_label for i in data_samples])
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
|
@ -150,17 +150,17 @@ class TimmClassifier(BaseClassifier):
|
|||
|
||||
def predict(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None):
|
||||
data_samples: Optional[List[DataSample]] = None):
|
||||
"""Predict results from a batch of inputs.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: The prediction results.
|
||||
List[DataSample]: The prediction results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
cls_score = self(inputs)
|
||||
|
@ -184,8 +184,8 @@ class TimmClassifier(BaseClassifier):
|
|||
else:
|
||||
data_samples = []
|
||||
for score, label in zip(pred_scores, pred_labels):
|
||||
data_samples.append(ClsDataSample().set_pred_score(
|
||||
score).set_pred_label(label))
|
||||
data_samples.append(
|
||||
DataSample().set_pred_score(score).set_pred_label(label))
|
||||
|
||||
return data_samples
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
|
||||
from mmpretrain.evaluation.metrics import Accuracy
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
|
@ -57,8 +57,8 @@ class ClsHead(BaseHead):
|
|||
# just return the unpacked inputs.
|
||||
return pre_logits
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
|
||||
**kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
|
@ -66,7 +66,7 @@ class ClsHead(BaseHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
|
@ -81,14 +81,14 @@ class ClsHead(BaseHead):
|
|||
return losses
|
||||
|
||||
def _get_loss(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample], **kwargs):
|
||||
data_samples: List[DataSample], **kwargs):
|
||||
"""Unpack data samples and compute loss."""
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
if 'gt_score' in data_samples[0]:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
target = torch.stack([i.gt_label.score for i in data_samples])
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
target = torch.cat([i.gt_label.label for i in data_samples])
|
||||
target = torch.cat([i.gt_label for i in data_samples])
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
|
@ -110,8 +110,8 @@ class ClsHead(BaseHead):
|
|||
def predict(
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[Union[ClsDataSample, None]] = None
|
||||
) -> List[ClsDataSample]:
|
||||
data_samples: Optional[List[Optional[DataSample]]] = None
|
||||
) -> List[DataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
|
@ -119,12 +119,12 @@ class ClsHead(BaseHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample | None], optional): The annotation
|
||||
data_samples (List[DataSample | None], optional): The annotation
|
||||
data of every samples. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: A list of data samples which contains the
|
||||
List[DataSample]: A list of data samples which contains the
|
||||
predicted results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
|
@ -149,7 +149,7 @@ class ClsHead(BaseHead):
|
|||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
if data_sample is None:
|
||||
data_sample = ClsDataSample()
|
||||
data_sample = DataSample()
|
||||
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
out_data_samples.append(data_sample)
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
|
||||
from mmpretrain.evaluation.metrics import Accuracy
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
|
@ -64,10 +64,9 @@ class ConformerHead(ClsHead):
|
|||
|
||||
return conv_cls_score, tran_cls_score
|
||||
|
||||
def predict(
|
||||
self,
|
||||
feats: Tuple[List[torch.Tensor]],
|
||||
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
|
||||
def predict(self,
|
||||
feats: Tuple[List[torch.Tensor]],
|
||||
data_samples: List[DataSample] = None) -> List[DataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
|
@ -75,12 +74,12 @@ class ConformerHead(ClsHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: A list of data samples which contains the
|
||||
List[DataSample]: A list of data samples which contains the
|
||||
predicted results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
|
@ -92,14 +91,14 @@ class ConformerHead(ClsHead):
|
|||
return predictions
|
||||
|
||||
def _get_loss(self, cls_score: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
data_samples: List[DataSample], **kwargs) -> dict:
|
||||
"""Unpack data samples and compute loss."""
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
if 'gt_score' in data_samples[0]:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
target = torch.stack([i.gt_label.score for i in data_samples])
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
target = torch.cat([i.gt_label.label for i in data_samples])
|
||||
target = torch.cat([i.gt_label for i in data_samples])
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
|
@ -65,8 +65,8 @@ class EfficientFormerClsHead(ClsHead):
|
|||
# after unpacking.
|
||||
return feats[-1]
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
|
||||
**kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
|
@ -74,7 +74,7 @@ class EfficientFormerClsHead(ClsHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmengine.utils import is_seq_of
|
|||
|
||||
from mmpretrain.models.losses import convert_to_one_hot
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from .cls_head import ClsHead
|
||||
|
||||
|
||||
|
@ -264,8 +264,8 @@ class ArcFaceClsHead(ClsHead):
|
|||
|
||||
return self.scale * logit
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
|
||||
**kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
|
@ -273,7 +273,7 @@ class ArcFaceClsHead(ClsHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
|
@ -281,12 +281,11 @@ class ArcFaceClsHead(ClsHead):
|
|||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
# Unpack data samples and pack targets
|
||||
label_target = torch.cat([i.gt_label.label for i in data_samples])
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
label_target = torch.cat([i.gt_label for i in data_samples])
|
||||
if 'gt_score' in data_samples[0]:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
target = torch.stack([i.gt_label.score for i in data_samples])
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
# change the labels to to one-hot format scores.
|
||||
target = label_target
|
||||
|
||||
# the index format target would be used
|
||||
|
|
|
@ -3,10 +3,9 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample, label_to_onehot
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
|
@ -65,8 +64,8 @@ class MultiLabelClsHead(BaseHead):
|
|||
# just return the unpacked inputs.
|
||||
return pre_logits
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample], **kwargs) -> dict:
|
||||
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
|
||||
**kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
|
@ -74,7 +73,7 @@ class MultiLabelClsHead(BaseHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
|
@ -89,19 +88,16 @@ class MultiLabelClsHead(BaseHead):
|
|||
return losses
|
||||
|
||||
def _get_loss(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample], **kwargs):
|
||||
data_samples: List[DataSample], **kwargs):
|
||||
"""Unpack data samples and compute loss."""
|
||||
num_classes = cls_score.size()[-1]
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
target = torch.stack(
|
||||
[i.gt_label.score.float() for i in data_samples])
|
||||
if 'gt_score' in data_samples[0]:
|
||||
target = torch.stack([i.gt_score for i in data_samples])
|
||||
else:
|
||||
target = torch.stack([
|
||||
LabelData.label_to_onehot(i.gt_label.label,
|
||||
num_classes).float()
|
||||
for i in data_samples
|
||||
])
|
||||
label_to_onehot(i.gt_label, num_classes) for i in data_samples
|
||||
]).float()
|
||||
|
||||
# compute loss
|
||||
losses = dict()
|
||||
|
@ -111,10 +107,9 @@ class MultiLabelClsHead(BaseHead):
|
|||
|
||||
return losses
|
||||
|
||||
def predict(
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
|
||||
def predict(self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[DataSample] = None) -> List[DataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
|
@ -122,12 +117,12 @@ class MultiLabelClsHead(BaseHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: A list of data samples which contains the
|
||||
List[DataSample]: A list of data samples which contains the
|
||||
predicted results.
|
||||
"""
|
||||
# The part can be traced by torch.fx
|
||||
|
@ -138,7 +133,7 @@ class MultiLabelClsHead(BaseHead):
|
|||
return predictions
|
||||
|
||||
def _get_predictions(self, cls_score: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]):
|
||||
data_samples: List[DataSample]):
|
||||
"""Post-process the output of head.
|
||||
|
||||
Including softmax and set ``pred_label`` of data samples.
|
||||
|
@ -146,7 +141,7 @@ class MultiLabelClsHead(BaseHead):
|
|||
pred_scores = torch.sigmoid(cls_score)
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [ClsDataSample() for _ in range(cls_score.size(0))]
|
||||
data_samples = [DataSample() for _ in range(cls_score.size(0))]
|
||||
|
||||
for data_sample, score in zip(data_samples, pred_scores):
|
||||
if self.thr is not None:
|
||||
|
|
|
@ -63,7 +63,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
|
|||
- "tensor": Forward the whole network and return tensor without any
|
||||
post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`ClsDataSample`.
|
||||
processed to a list of :obj:`DataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
|
@ -73,7 +73,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
|
|||
Args:
|
||||
inputs (torch.Tensor, tuple): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. It's required if ``mode="loss"``.
|
||||
Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
@ -83,7 +83,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
|
|||
|
||||
- If ``mode="tensor"``, return a tensor.
|
||||
- If ``mode="predict"``, return a list of
|
||||
:obj:`mmpretrain.structures.ClsDataSample`.
|
||||
:obj:`mmpretrain.structures.DataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
pass
|
||||
|
@ -107,7 +107,7 @@ class BaseRetriever(BaseModel, metaclass=ABCMeta):
|
|||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmengine.runner import Runner
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.utils import track_on_main_process
|
||||
from .base import BaseRetriever
|
||||
|
||||
|
@ -114,7 +114,7 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
|
@ -123,7 +123,7 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
- "tensor": Forward the whole network and return tensor without any
|
||||
post-processing, same as a common nn.Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`ClsDataSample`.
|
||||
processed to a list of :obj:`DataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
|
@ -133,7 +133,7 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
Args:
|
||||
inputs (torch.Tensor, tuple): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. It's required if ``mode="loss"``.
|
||||
Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
@ -143,7 +143,7 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
|
||||
- If ``mode="tensor"``, return a tensor.
|
||||
- If ``mode="predict"``, return a list of
|
||||
:obj:`mmpretrain.structures.ClsDataSample`.
|
||||
:obj:`mmpretrain.structures.DataSample`.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'tensor':
|
||||
|
@ -169,13 +169,13 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
return feat
|
||||
|
||||
def loss(self, inputs: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]) -> dict:
|
||||
data_samples: List[DataSample]) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every samples.
|
||||
|
||||
Returns:
|
||||
|
@ -200,18 +200,18 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
|
||||
def predict(self,
|
||||
inputs: tuple,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
**kwargs) -> List[ClsDataSample]:
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
**kwargs) -> List[DataSample]:
|
||||
"""Predict results from the extracted features.
|
||||
|
||||
Args:
|
||||
inputs (tuple): The features extracted from the backbone.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
**kwargs: Other keyword arguments accepted by the ``predict``
|
||||
method of :attr:`head`.
|
||||
Returns:
|
||||
List[ClsDataSample]: the raw data_samples with
|
||||
List[DataSample]: the raw data_samples with
|
||||
the predicted results
|
||||
"""
|
||||
if not self.prototype_inited:
|
||||
|
@ -240,8 +240,8 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
else:
|
||||
data_samples = []
|
||||
for score, label in zip(pred_scores, pred_labels):
|
||||
data_samples.append(ClsDataSample().set_pred_score(
|
||||
score).set_pred_label(label))
|
||||
data_samples.append(
|
||||
DataSample().set_pred_score(score).set_pred_label(label))
|
||||
return data_samples
|
||||
|
||||
def _get_prototype_vecs_from_dataloader(self, data_loader):
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import List
|
|||
from mmengine.model import BaseTTAModel
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -12,16 +12,16 @@ class AverageClsScoreTTA(BaseTTAModel):
|
|||
|
||||
def merge_preds(
|
||||
self,
|
||||
data_samples_list: List[List[ClsDataSample]],
|
||||
) -> List[ClsDataSample]:
|
||||
data_samples_list: List[List[DataSample]],
|
||||
) -> List[DataSample]:
|
||||
"""Merge predictions of enhanced data to one prediction.
|
||||
|
||||
Args:
|
||||
data_samples_list (List[List[ClsDataSample]]): List of predictions
|
||||
data_samples_list (List[List[DataSample]]): List of predictions
|
||||
of all enhanced data.
|
||||
|
||||
Returns:
|
||||
List[ClsDataSample]: Merged prediction.
|
||||
List[DataSample]: Merged prediction.
|
||||
"""
|
||||
merged_data_samples = []
|
||||
for data_samples in data_samples_list:
|
||||
|
@ -29,8 +29,8 @@ class AverageClsScoreTTA(BaseTTAModel):
|
|||
return merged_data_samples
|
||||
|
||||
def _merge_single_sample(self, data_samples):
|
||||
merged_data_sample: ClsDataSample = data_samples[0].new()
|
||||
merged_score = sum(data_sample.pred_label.score
|
||||
merged_data_sample: DataSample = data_samples[0].new()
|
||||
merged_score = sum(data_sample.pred_score
|
||||
for data_sample in data_samples) / len(data_samples)
|
||||
merged_data_sample.set_pred_score(merged_score)
|
||||
return merged_data_sample
|
||||
|
|
|
@ -8,9 +8,9 @@ import torch.nn.functional as F
|
|||
from mmengine.model import BaseDataPreprocessor, stack_batch
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import (ClsDataSample, MultiTaskDataSample,
|
||||
from mmpretrain.structures import (DataSample, MultiTaskDataSample,
|
||||
batch_label_to_onehot, cat_batch_labels,
|
||||
stack_batch_scores, tensor_split)
|
||||
tensor_split)
|
||||
from .batch_augments import RandomBatchAugment
|
||||
|
||||
|
||||
|
@ -153,23 +153,28 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
|
||||
data_samples = data.get('data_samples', None)
|
||||
sample_item = data_samples[0] if data_samples is not None else None
|
||||
if isinstance(sample_item,
|
||||
ClsDataSample) and 'gt_label' in sample_item:
|
||||
gt_labels = [sample.gt_label for sample in data_samples]
|
||||
batch_label, label_indices = cat_batch_labels(
|
||||
gt_labels, device=self.device)
|
||||
|
||||
batch_score = stack_batch_scores(gt_labels, device=self.device)
|
||||
if batch_score is None and self.to_onehot:
|
||||
if isinstance(sample_item, DataSample):
|
||||
batch_label = None
|
||||
batch_score = None
|
||||
|
||||
if 'gt_label' in sample_item:
|
||||
gt_labels = [sample.gt_label for sample in data_samples]
|
||||
batch_label, label_indices = cat_batch_labels(gt_labels)
|
||||
batch_label = batch_label.to(self.device)
|
||||
if 'gt_score' in sample_item:
|
||||
gt_scores = [sample.gt_score for sample in data_samples]
|
||||
batch_score = torch.stack(gt_scores).to(self.device)
|
||||
elif self.to_onehot:
|
||||
assert batch_label is not None, \
|
||||
'Cannot generate onehot format labels because no labels.'
|
||||
num_classes = self.num_classes or data_samples[0].get(
|
||||
num_classes = self.num_classes or sample_item.get(
|
||||
'num_classes')
|
||||
assert num_classes is not None, \
|
||||
'Cannot generate one-hot format labels because not set ' \
|
||||
'`num_classes` in `data_preprocessor`.'
|
||||
batch_score = batch_label_to_onehot(batch_label, label_indices,
|
||||
num_classes)
|
||||
batch_score = batch_label_to_onehot(
|
||||
batch_label, label_indices, num_classes).to(self.device)
|
||||
|
||||
# ----- Batch Augmentations ----
|
||||
if training and self.batch_augments is not None:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .cls_data_sample import ClsDataSample
|
||||
from .data_sample import DataSample
|
||||
from .multi_task_data_sample import MultiTaskDataSample
|
||||
from .utils import (batch_label_to_onehot, cat_batch_labels,
|
||||
stack_batch_scores, tensor_split)
|
||||
from .utils import (batch_label_to_onehot, cat_batch_labels, label_to_onehot,
|
||||
tensor_split)
|
||||
|
||||
__all__ = [
|
||||
'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
|
||||
'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample'
|
||||
'DataSample', 'batch_label_to_onehot', 'cat_batch_labels', 'tensor_split',
|
||||
'MultiTaskDataSample', 'label_to_onehot'
|
||||
]
|
||||
|
|
|
@ -1,235 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from numbers import Number
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import BaseDataElement, LabelData
|
||||
from mmengine.utils import is_str
|
||||
|
||||
|
||||
def format_label(
|
||||
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
|
||||
"""Convert various python types to label-format tensor.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`, :class:`int`.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor`: The foramtted label tensor.
|
||||
"""
|
||||
|
||||
# Handle single number
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
|
||||
value = int(value.item())
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value).to(torch.long)
|
||||
elif isinstance(value, Sequence) and not is_str(value):
|
||||
value = torch.tensor(value).to(torch.long)
|
||||
elif isinstance(value, int):
|
||||
value = torch.LongTensor([value])
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'Type {type(value)} is not an available label type.')
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def format_score(
|
||||
value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
|
||||
"""Convert various python types to score-format tensor.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor`: The foramtted score tensor.
|
||||
"""
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value).float()
|
||||
elif isinstance(value, Sequence) and not is_str(value):
|
||||
value = torch.tensor(value).float()
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'Type {type(value)} is not an available label type.')
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ClsDataSample(BaseDataElement):
|
||||
"""A data structure interface of classification task.
|
||||
|
||||
It's used as interfaces between different components.
|
||||
|
||||
Meta fields:
|
||||
img_shape (Tuple): The shape of the corresponding input image.
|
||||
Used for visualization.
|
||||
ori_shape (Tuple): The original shape of the corresponding image.
|
||||
Used for visualization.
|
||||
num_classes (int): The number of all categories.
|
||||
Used for label format conversion.
|
||||
|
||||
Data fields:
|
||||
gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth
|
||||
label.
|
||||
pred_label (:obj:`~mmengine.structures.LabelData`): The predicted
|
||||
label.
|
||||
scores (torch.Tensor): The outputs of model.
|
||||
logits (torch.Tensor): The outputs of model without softmax nor
|
||||
sigmoid.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>>
|
||||
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
|
||||
>>> data_sample = ClsDataSample(metainfo=img_meta)
|
||||
>>> data_sample.set_gt_label(3)
|
||||
>>> print(data_sample)
|
||||
<ClsDataSample(
|
||||
META INFORMATION
|
||||
num_classes = 5
|
||||
img_shape = (960, 720)
|
||||
DATA FIELDS
|
||||
gt_label: <LabelData(
|
||||
META INFORMATION
|
||||
num_classes: 5
|
||||
DATA FIELDS
|
||||
label: tensor([3])
|
||||
) at 0x7f21fb1b9190>
|
||||
) at 0x7f21fb1b9880>
|
||||
>>> # For multi-label data
|
||||
>>> data_sample.set_gt_label([0, 1, 4])
|
||||
>>> print(data_sample.gt_label)
|
||||
<LabelData(
|
||||
META INFORMATION
|
||||
num_classes: 5
|
||||
DATA FIELDS
|
||||
label: tensor([0, 1, 4])
|
||||
) at 0x7fd7d1b41970>
|
||||
>>> # Set one-hot format score
|
||||
>>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
>>> data_sample.set_pred_score(score)
|
||||
>>> print(data_sample.pred_label)
|
||||
<LabelData(
|
||||
META INFORMATION
|
||||
num_classes: 5
|
||||
DATA FIELDS
|
||||
score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
) at 0x7fd7d1b41970>
|
||||
"""
|
||||
|
||||
def set_gt_label(
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
|
||||
) -> 'ClsDataSample':
|
||||
"""Set label of ``gt_label``."""
|
||||
label_data = getattr(self, '_gt_label', LabelData())
|
||||
label_data.label = format_label(value)
|
||||
self.gt_label = label_data
|
||||
return self
|
||||
|
||||
def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
|
||||
"""Set score of ``gt_label``."""
|
||||
label_data = getattr(self, '_gt_label', LabelData())
|
||||
label_data.score = format_score(value)
|
||||
if hasattr(self, 'num_classes'):
|
||||
assert len(label_data.score) == self.num_classes, \
|
||||
f'The length of score {len(label_data.score)} should be '\
|
||||
f'equal to the num_classes {self.num_classes}.'
|
||||
else:
|
||||
self.set_field(
|
||||
name='num_classes',
|
||||
value=len(label_data.score),
|
||||
field_type='metainfo')
|
||||
self.gt_label = label_data
|
||||
return self
|
||||
|
||||
def set_pred_label(
|
||||
self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
|
||||
) -> 'ClsDataSample':
|
||||
"""Set label of ``pred_label``."""
|
||||
label_data = getattr(self, '_pred_label', LabelData())
|
||||
label_data.label = format_label(value)
|
||||
self.pred_label = label_data
|
||||
return self
|
||||
|
||||
def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
|
||||
"""Set score of ``pred_label``."""
|
||||
label_data = getattr(self, '_pred_label', LabelData())
|
||||
label_data.score = format_score(value)
|
||||
if hasattr(self, 'num_classes'):
|
||||
assert len(label_data.score) == self.num_classes, \
|
||||
f'The length of score {len(label_data.score)} should be '\
|
||||
f'equal to the num_classes {self.num_classes}.'
|
||||
else:
|
||||
self.set_field(
|
||||
name='num_classes',
|
||||
value=len(label_data.score),
|
||||
field_type='metainfo')
|
||||
self.pred_label = label_data
|
||||
return self
|
||||
|
||||
@property
|
||||
def gt_label(self):
|
||||
return self._gt_label
|
||||
|
||||
@gt_label.setter
|
||||
def gt_label(self, value: LabelData):
|
||||
self.set_field(value, '_gt_label', dtype=LabelData)
|
||||
|
||||
@gt_label.deleter
|
||||
def gt_label(self):
|
||||
del self._gt_label
|
||||
|
||||
@property
|
||||
def pred_label(self):
|
||||
return self._pred_label
|
||||
|
||||
@pred_label.setter
|
||||
def pred_label(self, value: LabelData):
|
||||
self.set_field(value, '_pred_label', dtype=LabelData)
|
||||
|
||||
@pred_label.deleter
|
||||
def pred_label(self):
|
||||
del self._pred_label
|
||||
|
||||
|
||||
def _reduce_cls_datasample(data_sample):
|
||||
"""reduce ClsDataSample."""
|
||||
attr_dict = data_sample.__dict__
|
||||
convert_keys = []
|
||||
for k, v in attr_dict.items():
|
||||
if isinstance(v, LabelData):
|
||||
attr_dict[k] = v.numpy()
|
||||
convert_keys.append(k)
|
||||
return _rebuild_cls_datasample, (attr_dict, convert_keys)
|
||||
|
||||
|
||||
def _rebuild_cls_datasample(attr_dict, convert_keys):
|
||||
"""rebuild ClsDataSample."""
|
||||
data_sample = ClsDataSample()
|
||||
for k in convert_keys:
|
||||
attr_dict[k] = attr_dict[k].to_tensor()
|
||||
data_sample.__dict__ = attr_dict
|
||||
return data_sample
|
||||
|
||||
|
||||
# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume
|
||||
# many file descriptors because it contains multiple LabelData with tensors.
|
||||
# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and
|
||||
# convert these tensors to np.ndarray during pickling. It may influence the
|
||||
# performance of dataloader, but slightly because these tensors in LabelData
|
||||
# are very small.
|
||||
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)
|
|
@ -0,0 +1,167 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
from .utils import LABEL_TYPE, SCORE_TYPE, format_label, format_score
|
||||
|
||||
|
||||
class DataSample(BaseDataElement):
|
||||
"""A general data structure interface.
|
||||
|
||||
It's used as the interface between different components.
|
||||
|
||||
The following fields are convention names in MMPretrain, and we will set or
|
||||
get these fields in data transforms, models, and metrics if needed. You can
|
||||
also set any new fields for your need.
|
||||
|
||||
Meta fields:
|
||||
img_shape (Tuple): The shape of the corresponding input image.
|
||||
ori_shape (Tuple): The original shape of the corresponding image.
|
||||
sample_idx (int): The index of the sample in the dataset.
|
||||
num_classes (int): The number of all categories.
|
||||
|
||||
Data fields:
|
||||
gt_label (tensor): The ground truth label.
|
||||
gt_score (tensor): The ground truth score.
|
||||
pred_label (tensor): The predicted label.
|
||||
pred_score (tensor): The predicted score.
|
||||
mask (tensor): The mask used in masked image modeling.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>>
|
||||
>>> img_meta = dict(img_shape=(960, 720), num_classes=5)
|
||||
>>> data_sample = DataSample(metainfo=img_meta)
|
||||
>>> data_sample.set_gt_label(3)
|
||||
>>> print(data_sample)
|
||||
<DataSample(
|
||||
META INFORMATION
|
||||
num_classes: 5
|
||||
img_shape: (960, 720)
|
||||
DATA FIELDS
|
||||
gt_label: tensor([3])
|
||||
) at 0x7ff64c1c1d30>
|
||||
>>>
|
||||
>>> # For multi-label data
|
||||
>>> data_sample = DataSample().set_gt_label([0, 1, 4])
|
||||
>>> print(data_sample)
|
||||
<DataSample(
|
||||
DATA FIELDS
|
||||
gt_label: tensor([0, 1, 4])
|
||||
) at 0x7ff5b490e100>
|
||||
>>>
|
||||
>>> # Set one-hot format score
|
||||
>>> data_sample = DataSample().set_pred_score([0.1, 0.1, 0.6, 0.1])
|
||||
>>> print(data_sample)
|
||||
<DataSample(
|
||||
META INFORMATION
|
||||
num_classes: 4
|
||||
DATA FIELDS
|
||||
pred_score: tensor([0.1000, 0.1000, 0.6000, 0.1000])
|
||||
) at 0x7ff5b48ef6a0>
|
||||
>>>
|
||||
>>> # Set custom field
|
||||
>>> data_sample = DataSample()
|
||||
>>> data_sample.my_field = [1, 2, 3]
|
||||
>>> print(data_sample)
|
||||
<DataSample(
|
||||
DATA FIELDS
|
||||
my_field: [1, 2, 3]
|
||||
) at 0x7f8e9603d3a0>
|
||||
>>> print(data_sample.my_field)
|
||||
[1, 2, 3]
|
||||
"""
|
||||
|
||||
def set_gt_label(self, value: LABEL_TYPE) -> 'DataSample':
|
||||
"""Set ``gt_label``."""
|
||||
self.set_field(format_label(value), 'gt_label', dtype=torch.Tensor)
|
||||
return self
|
||||
|
||||
def set_gt_score(self, value: SCORE_TYPE) -> 'DataSample':
|
||||
"""Set ``gt_score``."""
|
||||
score = format_score(value)
|
||||
self.set_field(score, 'gt_score', dtype=torch.Tensor)
|
||||
if hasattr(self, 'num_classes'):
|
||||
assert len(score) == self.num_classes, \
|
||||
f'The length of score {len(score)} should be '\
|
||||
f'equal to the num_classes {self.num_classes}.'
|
||||
else:
|
||||
self.set_field(
|
||||
name='num_classes', value=len(score), field_type='metainfo')
|
||||
return self
|
||||
|
||||
def set_pred_label(self, value: LABEL_TYPE) -> 'DataSample':
|
||||
"""Set ``pred_label``."""
|
||||
self.set_field(format_label(value), 'pred_label', dtype=torch.Tensor)
|
||||
return self
|
||||
|
||||
def set_pred_score(self, value: SCORE_TYPE):
|
||||
"""Set ``pred_label``."""
|
||||
score = format_score(value)
|
||||
self.set_field(score, 'pred_score', dtype=torch.Tensor)
|
||||
if hasattr(self, 'num_classes'):
|
||||
assert len(score) == self.num_classes, \
|
||||
f'The length of score {len(score)} should be '\
|
||||
f'equal to the num_classes {self.num_classes}.'
|
||||
else:
|
||||
self.set_field(
|
||||
name='num_classes', value=len(score), field_type='metainfo')
|
||||
return self
|
||||
|
||||
def set_mask(self, value: Union[torch.Tensor, np.ndarray]):
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value)
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'Invalid mask type {type(value)}')
|
||||
self.set_field(value, 'mask', dtype=torch.Tensor)
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Represent the object."""
|
||||
|
||||
def dump_items(items, prefix=''):
|
||||
return '\n'.join(f'{prefix}{k}: {v}' for k, v in items)
|
||||
|
||||
repr_ = ''
|
||||
if len(self._metainfo_fields) > 0:
|
||||
repr_ += '\n\nMETA INFORMATION\n'
|
||||
repr_ += dump_items(self.metainfo_items(), prefix=' ' * 4)
|
||||
if len(self._data_fields) > 0:
|
||||
repr_ += '\n\nDATA FIELDS\n'
|
||||
repr_ += dump_items(self.items(), prefix=' ' * 4)
|
||||
|
||||
repr_ = f'<{self.__class__.__name__}({repr_}\n\n) at {hex(id(self))}>'
|
||||
return repr_
|
||||
|
||||
|
||||
def _reduce_datasample(data_sample):
|
||||
"""reduce DataSample."""
|
||||
attr_dict = data_sample.__dict__
|
||||
convert_keys = []
|
||||
for k, v in attr_dict.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
attr_dict[k] = v.numpy()
|
||||
convert_keys.append(k)
|
||||
return _rebuild_datasample, (attr_dict, convert_keys)
|
||||
|
||||
|
||||
def _rebuild_datasample(attr_dict, convert_keys):
|
||||
"""rebuild DataSample."""
|
||||
data_sample = DataSample()
|
||||
for k in convert_keys:
|
||||
attr_dict[k] = torch.from_numpy(attr_dict[k])
|
||||
data_sample.__dict__ = attr_dict
|
||||
return data_sample
|
||||
|
||||
|
||||
# Due to the multi-processing strategy of PyTorch, DataSample may consume many
|
||||
# file descriptors because it contains multiple tensors. Here we overwrite the
|
||||
# reduce function of DataSample in ForkingPickler and convert these tensors to
|
||||
# np.ndarray during pickling. It may slightly influence the performance of
|
||||
# dataloader.
|
||||
ForkingPickler.register(DataSample, _reduce_datasample)
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.structures import LabelData
|
||||
from mmengine.utils import is_str
|
||||
|
||||
if hasattr(torch, 'tensor_split'):
|
||||
tensor_split = torch.tensor_split
|
||||
|
@ -16,30 +17,82 @@ else:
|
|||
return outs
|
||||
|
||||
|
||||
def cat_batch_labels(elements: List[LabelData], device=None):
|
||||
"""Concat the ``label`` of a batch of :obj:`LabelData` to a tensor.
|
||||
LABEL_TYPE = Union[torch.Tensor, np.ndarray, Sequence, int]
|
||||
SCORE_TYPE = Union[torch.Tensor, np.ndarray, Sequence]
|
||||
|
||||
|
||||
def format_label(value: LABEL_TYPE) -> torch.Tensor:
|
||||
"""Convert various python types to label-format tensor.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`, :class:`int`.
|
||||
|
||||
Args:
|
||||
elements (List[LabelData]): A batch of :obj`LabelData`.
|
||||
device (torch.device, optional): The output device of the batch label.
|
||||
Defaults to None.
|
||||
value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor`: The foramtted label tensor.
|
||||
"""
|
||||
|
||||
# Handle single number
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
|
||||
value = int(value.item())
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value).to(torch.long)
|
||||
elif isinstance(value, Sequence) and not is_str(value):
|
||||
value = torch.tensor(value).to(torch.long)
|
||||
elif isinstance(value, int):
|
||||
value = torch.LongTensor([value])
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'Type {type(value)} is not an available label type.')
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def format_score(value: SCORE_TYPE) -> torch.Tensor:
|
||||
"""Convert various python types to score-format tensor.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor | numpy.ndarray | Sequence): Score values.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.Tensor`: The foramtted score tensor.
|
||||
"""
|
||||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value).float()
|
||||
elif isinstance(value, Sequence) and not is_str(value):
|
||||
value = torch.tensor(value).float()
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'Type {type(value)} is not an available label type.')
|
||||
assert value.ndim == 1, \
|
||||
f'The dims of value should be 1, but got {value.ndim}.'
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def cat_batch_labels(elements: List[torch.Tensor]):
|
||||
"""Concat a batch of label tensor to one tensor.
|
||||
|
||||
Args:
|
||||
elements (List[tensor]): A batch of labels.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, List[int]]: The first item is the concated label
|
||||
tensor, and the second item is the split indices of every sample.
|
||||
"""
|
||||
item = elements[0]
|
||||
if 'label' not in item._data_fields:
|
||||
return None, None
|
||||
|
||||
labels = []
|
||||
splits = [0]
|
||||
for element in elements:
|
||||
labels.append(element.label)
|
||||
splits.append(splits[-1] + element.label.size(0))
|
||||
labels.append(element)
|
||||
splits.append(splits[-1] + element.size(0))
|
||||
batch_label = torch.cat(labels)
|
||||
if device is not None:
|
||||
batch_label = batch_label.to(device=device)
|
||||
return batch_label, splits[1:-1]
|
||||
|
||||
|
||||
|
@ -75,22 +128,26 @@ def batch_label_to_onehot(batch_label, split_indices, num_classes):
|
|||
return torch.stack(onehot_list)
|
||||
|
||||
|
||||
def stack_batch_scores(elements, device=None):
|
||||
"""Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.
|
||||
def label_to_onehot(label: LABEL_TYPE, num_classes: int):
|
||||
"""Convert a label to onehot format tensor.
|
||||
|
||||
Args:
|
||||
elements (List[LabelData]): A batch of :obj`LabelData`.
|
||||
device (torch.device, optional): The output device of the batch label.
|
||||
Defaults to None.
|
||||
label (LABEL_TYPE): Label value.
|
||||
num_classes (int): The number of classes.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The stacked score tensor.
|
||||
"""
|
||||
item = elements[0]
|
||||
if 'score' not in item._data_fields:
|
||||
return None
|
||||
torch.Tensor: The onehot format label tensor.
|
||||
|
||||
batch_score = torch.stack([element.score for element in elements])
|
||||
if device is not None:
|
||||
batch_score = batch_score.to(device)
|
||||
return batch_score
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmpretrain.structures import label_to_onehot
|
||||
>>> # Single-label
|
||||
>>> label_to_onehot(1, num_classes=5)
|
||||
tensor([0, 1, 0, 0, 0])
|
||||
>>> # Multi-label
|
||||
>>> label_to_onehot([0, 2, 3], num_classes=5)
|
||||
tensor([1, 0, 1, 1, 0])
|
||||
"""
|
||||
label = format_label(label)
|
||||
sparse_onehot = F.one_hot(label, num_classes)
|
||||
return sparse_onehot.sum(0)
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmengine.dist import master_only
|
|||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmpretrain.registry import VISUALIZERS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
def _get_adaptive_scale(img_shape: Tuple[int, int],
|
||||
|
@ -57,11 +57,11 @@ class ClsVisualizer(Visualizer):
|
|||
>>> import mmcv
|
||||
>>> from pathlib import Path
|
||||
>>> from mmpretrain.visualization import ClsVisualizer
|
||||
>>> from mmpretrain.structures import ClsDataSample
|
||||
>>> from mmpretrain.structures import DataSample
|
||||
>>> # Example image
|
||||
>>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
|
||||
>>> # Example annotation
|
||||
>>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
|
||||
>>> data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
|
||||
... set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
|
||||
>>> # Setup the visualizer
|
||||
>>> vis = ClsVisualizer(
|
||||
|
@ -84,7 +84,7 @@ class ClsVisualizer(Visualizer):
|
|||
def add_datasample(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
data_sample: Optional[ClsDataSample] = None,
|
||||
data_sample: Optional[DataSample] = None,
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
draw_score: bool = True,
|
||||
|
@ -104,7 +104,7 @@ class ClsVisualizer(Visualizer):
|
|||
Args:
|
||||
name (str): The image identifier.
|
||||
image (np.ndarray): The image to draw.
|
||||
data_sample (:obj:`ClsDataSample`, optional): The annotation of the
|
||||
data_sample (:obj:`DataSample`, optional): The annotation of the
|
||||
image. Defaults to None.
|
||||
draw_gt (bool): Whether to draw ground truth labels.
|
||||
Defaults to True.
|
||||
|
@ -137,8 +137,7 @@ class ClsVisualizer(Visualizer):
|
|||
self.set_image(image)
|
||||
|
||||
if draw_gt and 'gt_label' in data_sample:
|
||||
gt_label = data_sample.gt_label
|
||||
idx = gt_label.label.tolist()
|
||||
idx = data_sample.gt_label.tolist()
|
||||
class_labels = [''] * len(idx)
|
||||
if classes is not None:
|
||||
class_labels = [f' ({classes[i]})' for i in idx]
|
||||
|
@ -147,13 +146,12 @@ class ClsVisualizer(Visualizer):
|
|||
texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
|
||||
|
||||
if draw_pred and 'pred_label' in data_sample:
|
||||
pred_label = data_sample.pred_label
|
||||
idx = pred_label.label.tolist()
|
||||
idx = data_sample.pred_label.tolist()
|
||||
score_labels = [''] * len(idx)
|
||||
class_labels = [''] * len(idx)
|
||||
if draw_score and 'score' in pred_label:
|
||||
if draw_score and 'pred_score' in data_sample:
|
||||
score_labels = [
|
||||
f', {pred_label.score[i].item():.2f}' for i in idx
|
||||
f', {data_sample.pred_score[i].item():.2f}' for i in idx
|
||||
]
|
||||
|
||||
if classes is not None:
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmcv.image import imread
|
|||
from mmpretrain.apis import (ImageClassificationInferencer, ModelHub,
|
||||
get_model, inference_model)
|
||||
from mmpretrain.models import MobileNetV3
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.visualization import ClsVisualizer
|
||||
|
||||
MODEL = 'mobilenet-v3-small-050_3rdparty_in1k'
|
||||
|
@ -58,7 +58,7 @@ class TestImageClassificationInferencer(TestCase):
|
|||
|
||||
# test return_datasample=True
|
||||
results = inferencer(img, return_datasamples=True)[0]
|
||||
self.assertIsInstance(results, ClsDataSample)
|
||||
self.assertIsInstance(results, DataSample)
|
||||
|
||||
def test_visualize(self):
|
||||
img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
|
||||
|
|
|
@ -6,11 +6,10 @@ import unittest
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import LabelData
|
||||
from PIL import Image
|
||||
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmpretrain.structures import DataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
class TestPackClsInputs(unittest.TestCase):
|
||||
|
@ -34,9 +33,9 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], ClsDataSample)
|
||||
self.assertIsInstance(results['data_samples'], DataSample)
|
||||
self.assertIn('flip', results['data_samples'].metainfo_keys())
|
||||
self.assertIsInstance(results['data_samples'].gt_label, LabelData)
|
||||
self.assertIsInstance(results['data_samples'].gt_label, torch.Tensor)
|
||||
|
||||
# Test grayscale image
|
||||
data['img'] = data['img'].mean(-1)
|
||||
|
@ -155,7 +154,7 @@ class TestPackMultiTaskInputs(unittest.TestCase):
|
|||
self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
|
||||
self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
|
||||
self.assertIsInstance(results['data_samples'].task1.gt_label,
|
||||
LabelData)
|
||||
torch.Tensor)
|
||||
|
||||
# Test grayscale image
|
||||
data['img'] = data['img'].mean(-1)
|
||||
|
|
|
@ -14,7 +14,7 @@ from mmengine.runner import Runner
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmpretrain.registry import HOOKS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
@ -36,7 +36,7 @@ class MockDataPreprocessor(BaseDataPreprocessor):
|
|||
|
||||
def forward(self, data, training):
|
||||
|
||||
return data['imgs'], ClsDataSample()
|
||||
return data['imgs'], DataSample()
|
||||
|
||||
|
||||
class ExampleModel(BaseModel):
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop
|
|||
|
||||
from mmpretrain.engine import VisualizationHook
|
||||
from mmpretrain.registry import HOOKS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.visualization import ClsVisualizer
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ class TestVisualizationHook(TestCase):
|
|||
def setUp(self) -> None:
|
||||
ClsVisualizer.get_instance('visualizer')
|
||||
|
||||
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
|
||||
data_sample = DataSample().set_gt_label(1).set_pred_label(2)
|
||||
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
self.data_batch = {
|
||||
'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
|
||||
|
@ -53,7 +53,7 @@ class TestVisualizationHook(TestCase):
|
|||
cfg = dict(type='VisualizationHook', enable=True)
|
||||
hook: VisualizationHook = HOOKS.build(cfg)
|
||||
with patch.object(hook._visualizer, 'add_datasample') as mock:
|
||||
outputs = [ClsDataSample()] * 10
|
||||
outputs = [DataSample()] * 10
|
||||
hook._draw_samples(0, self.data_batch, outputs, step=0)
|
||||
mock.assert_called_once_with(
|
||||
'0', image=ANY, data_sample=outputs[0], step=0, show=False)
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmengine.evaluator import Evaluator
|
|||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmpretrain.evaluation.metrics import AveragePrecision, MultiLabelMetric
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
init_default_scope('mmpretrain')
|
||||
|
||||
|
@ -152,7 +152,7 @@ class TestMultiLabel(TestCase):
|
|||
])
|
||||
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
for i, j in zip(y_pred_score, y_true)
|
||||
]
|
||||
|
||||
|
@ -261,7 +261,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with gt_score
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
|
||||
for i, j in zip(y_pred_score, y_true_binary)
|
||||
]
|
||||
|
||||
|
@ -304,7 +304,7 @@ class TestAveragePrecision(TestCase):
|
|||
])
|
||||
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
|
||||
for i, j in zip(y_pred, y_true)
|
||||
]
|
||||
|
||||
|
@ -328,7 +328,7 @@ class TestAveragePrecision(TestCase):
|
|||
|
||||
# Test with gt_label without score
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
|
||||
]
|
||||
evaluator = Evaluator(dict(type='AveragePrecision'))
|
||||
|
|
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
|||
import torch
|
||||
|
||||
from mmpretrain.evaluation.metrics import MultiTasksMetric
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class MultiTaskMetric(TestCase):
|
||||
|
@ -24,7 +24,7 @@ class MultiTaskMetric(TestCase):
|
|||
for i, pred in enumerate(data_pred):
|
||||
sample = {}
|
||||
for task_name in pred:
|
||||
task_sample = ClsDataSample().set_pred_score(pred[task_name])
|
||||
task_sample = DataSample().set_pred_score(pred[task_name])
|
||||
if task_name in data_gt[i]:
|
||||
task_sample.set_gt_label(data_gt[i][task_name])
|
||||
task_sample.set_field(True, 'eval_mask', field_type='metainfo')
|
||||
|
@ -68,7 +68,7 @@ class MultiTaskMetric(TestCase):
|
|||
sample = {}
|
||||
for task_name in score:
|
||||
if type(score[task_name]) != dict:
|
||||
task_sample = ClsDataSample().set_pred_score(score[task_name])
|
||||
task_sample = DataSample().set_pred_score(score[task_name])
|
||||
task_sample.set_gt_label(label[task_name])
|
||||
sample[task_name] = task_sample.to_dict()
|
||||
sample[task_name]['eval_mask'] = True
|
||||
|
@ -76,7 +76,7 @@ class MultiTaskMetric(TestCase):
|
|||
sample[task_name] = {}
|
||||
sample[task_name]['eval_mask'] = True
|
||||
for task_name2 in score[task_name]:
|
||||
task_sample = ClsDataSample().set_pred_score(
|
||||
task_sample = DataSample().set_pred_score(
|
||||
score[task_name][task_name2])
|
||||
task_sample.set_gt_label(label[task_name][task_name2])
|
||||
sample[task_name][task_name2] = task_sample.to_dict()
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from mmpretrain.evaluation.metrics import RetrievalRecall
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class TestRetrievalRecall(TestCase):
|
||||
|
@ -14,7 +14,7 @@ class TestRetrievalRecall(TestCase):
|
|||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict()
|
||||
DataSample().set_pred_score(i).set_gt_label(k).to_dict()
|
||||
for i, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
from mmpretrain.evaluation.metrics import (Accuracy, ConfusionMatrix,
|
||||
SingleLabelMetric)
|
||||
from mmpretrain.registry import METRICS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class TestAccuracy(TestCase):
|
||||
|
@ -16,7 +16,7 @@ class TestAccuracy(TestCase):
|
|||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
k).to_dict() for i, j, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
|
@ -52,7 +52,7 @@ class TestAccuracy(TestCase):
|
|||
|
||||
# Test with label
|
||||
for sample in pred:
|
||||
del sample['pred_label']['score']
|
||||
del sample['pred_score']
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
|
||||
metric.process(None, pred)
|
||||
acc = metric.evaluate(6)
|
||||
|
@ -123,7 +123,7 @@ class TestSingleLabel(TestCase):
|
|||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
k).to_dict() for i, j, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
|
@ -212,7 +212,7 @@ class TestSingleLabel(TestCase):
|
|||
# Test with label, the thrs will be ignored
|
||||
pred_no_score = copy.deepcopy(pred)
|
||||
for sample in pred_no_score:
|
||||
del sample['pred_label']['score']
|
||||
del sample['pred_score']
|
||||
del sample['num_classes']
|
||||
metric = METRICS.build(
|
||||
dict(type='SingleLabelMetric', thrs=(0., 0.6), num_classes=3))
|
||||
|
@ -304,7 +304,7 @@ class TestConfusionMatrix(TestCase):
|
|||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
DataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
k).to_dict() for i, j, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
|
@ -330,7 +330,7 @@ class TestConfusionMatrix(TestCase):
|
|||
|
||||
# Test with label
|
||||
for sample in pred:
|
||||
del sample['pred_label']['score']
|
||||
del sample['pred_score']
|
||||
metric = METRICS.build(dict(type='ConfusionMatrix'))
|
||||
metric.process(None, pred)
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
init_default_scope('mmpretrain')
|
||||
|
||||
|
@ -27,7 +27,7 @@ class TestVOCMultiLabel(TestCase):
|
|||
|
||||
# generate data samples
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
for i, j in zip(y_pred_score, y_true_label)
|
||||
]
|
||||
for sample, difficult_label in zip(pred, y_true_difficult):
|
||||
|
@ -155,7 +155,7 @@ class TestVOCAveragePrecision(TestCase):
|
|||
|
||||
# generate data samples
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(
|
||||
DataSample(num_classes=4).set_pred_score(i).set_gt_score(
|
||||
j).set_gt_label(k)
|
||||
for i, j, k in zip(y_pred_score, y_true, y_true_label)
|
||||
]
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmengine import ConfigDict
|
|||
|
||||
from mmpretrain.models import ImageClassifier
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
def has_timm() -> bool:
|
||||
|
@ -134,7 +134,7 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = model.loss(inputs, data_samples)
|
||||
|
@ -142,21 +142,21 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
predictions = model.predict(inputs)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
predictions = model.predict(inputs, data_samples)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
def test_forward(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test pure forward
|
||||
|
@ -169,13 +169,13 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
# test forward test
|
||||
predictions = model(inputs, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
predictions = model(inputs, data_samples, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
# test forward with invalid mode
|
||||
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
|
||||
|
@ -190,7 +190,7 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
optim_wrapper = MagicMock()
|
||||
|
@ -207,11 +207,11 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.val_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
def test_test_step(self):
|
||||
cfg = {
|
||||
|
@ -222,11 +222,11 @@ class TestImageClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.test_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
|
||||
@unittest.skipIf(not has_timm(), 'timm is not installed.')
|
||||
|
@ -255,7 +255,7 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = model.loss(inputs, data_samples)
|
||||
|
@ -263,21 +263,21 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
predictions = model.predict(inputs)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
predictions = model.predict(inputs, data_samples)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
def test_forward(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test pure forward
|
||||
|
@ -290,13 +290,13 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
# test forward test
|
||||
predictions = model(inputs, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
predictions = model(inputs, data_samples, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
# test forward with invalid mode
|
||||
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
|
||||
|
@ -311,7 +311,7 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
optim_wrapper = MagicMock()
|
||||
|
@ -328,11 +328,11 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.val_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
def test_test_step(self):
|
||||
cfg = {
|
||||
|
@ -343,11 +343,11 @@ class TestTimmClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.test_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
|
||||
@unittest.skipIf(not has_huggingface(), 'huggingface is not installed.')
|
||||
|
@ -376,7 +376,7 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = model.loss(inputs, data_samples)
|
||||
|
@ -384,21 +384,21 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
predictions = model.predict(inputs)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
predictions = model.predict(inputs, data_samples)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
def test_forward(self):
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
model: ImageClassifier = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test pure forward
|
||||
|
@ -411,13 +411,13 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
# test forward test
|
||||
predictions = model(inputs, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
predictions = model(inputs, data_samples, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (1000, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
# test forward with invalid mode
|
||||
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
|
||||
|
@ -432,7 +432,7 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
optim_wrapper = MagicMock()
|
||||
|
@ -449,11 +449,11 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.val_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
||||
def test_test_step(self):
|
||||
cfg = {
|
||||
|
@ -464,8 +464,8 @@ class TestHuggingFaceClassifier(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.test_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (1000, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (1000, ))
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
|||
from mmengine import is_seq_of
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmpretrain.structures import DataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
|
@ -43,7 +43,7 @@ class TestClsHead(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
feats = self.FAKE_FEATS
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
|
||||
# with cal_acc = False
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
@ -75,23 +75,23 @@ class TestClsHead(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# with without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for pred in predictions:
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
# with with data_samples
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
|
||||
class TestLinearClsHead(TestCase):
|
||||
|
@ -224,7 +224,7 @@ class TestConformerHead(TestCase):
|
|||
self.assertEqual(outs[1].shape, (4, 5))
|
||||
|
||||
def test_loss(self):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
|
||||
# with cal_acc = False
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
@ -255,23 +255,23 @@ class TestConformerHead(TestCase):
|
|||
head.loss(self.fake_feats, data_samples)
|
||||
|
||||
def test_predict(self):
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# with without data_samples
|
||||
predictions = head.predict(self.fake_feats)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for pred in predictions:
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
# with with data_samples
|
||||
predictions = head.predict(self.fake_feats, data_samples)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
|
||||
class TestStackedLinearClsHead(TestCase):
|
||||
|
@ -338,7 +338,7 @@ class TestMultiLabelClsHead(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label([0, 3]) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label([0, 3]) for _ in range(4)]
|
||||
|
||||
# Test with thr and topk are all None
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
@ -383,7 +383,7 @@ class TestMultiLabelClsHead(TestCase):
|
|||
|
||||
# Test with gt_lable with score
|
||||
data_samples = [
|
||||
ClsDataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
|
||||
DataSample().set_gt_score(torch.rand((10, ))) for _ in range(4)
|
||||
]
|
||||
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
@ -395,23 +395,23 @@ class TestMultiLabelClsHead(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2]) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label([1, 2]) for _ in range(4)]
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# with without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for pred in predictions:
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
# with with data_samples
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
# Test with topk
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
|
@ -419,11 +419,11 @@ class TestMultiLabelClsHead(TestCase):
|
|||
head = MODELS.build(cfg)
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertEqual(head.thr, None)
|
||||
self.assertTrue(is_seq_of(predictions, ClsDataSample))
|
||||
self.assertTrue(is_seq_of(predictions, DataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('label', pred.pred_label)
|
||||
self.assertIn('score', pred.pred_label)
|
||||
self.assertIn('pred_label', pred)
|
||||
self.assertIn('pred_score', pred)
|
||||
|
||||
|
||||
class EfficientFormerClsHead(TestClsHead):
|
||||
|
@ -454,7 +454,7 @@ class EfficientFormerClsHead(TestClsHead):
|
|||
|
||||
def test_loss(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
|
||||
# test with distillation head
|
||||
cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||
|
@ -525,7 +525,7 @@ class TestMultiTaskHead(TestCase):
|
|||
for _ in range(4):
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in self.DEFAULT_ARGS['task_heads']:
|
||||
task_sample = ClsDataSample().set_gt_label(1)
|
||||
task_sample = DataSample().set_gt_label(1)
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
# with cal_acc = False
|
||||
|
@ -545,7 +545,7 @@ class TestMultiTaskHead(TestCase):
|
|||
for _ in range(4):
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in self.DEFAULT_ARGS['task_heads']:
|
||||
task_sample = ClsDataSample().set_gt_label(1)
|
||||
task_sample = DataSample().set_gt_label(1)
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
@ -555,7 +555,7 @@ class TestMultiTaskHead(TestCase):
|
|||
for pred in predictions:
|
||||
self.assertIn('task0', pred)
|
||||
task0_sample = predictions[0].task0
|
||||
self.assertTrue(type(task0_sample.pred_label.score), 'torch.tensor')
|
||||
self.assertTrue(type(task0_sample.pred_score), 'torch.tensor')
|
||||
|
||||
# with with data_samples
|
||||
predictions = head.predict(feats, data_samples)
|
||||
|
@ -596,7 +596,7 @@ class TestMultiTaskHead(TestCase):
|
|||
head = MODELS.build(self.DEFAULT_ARGS2)
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in gt_label:
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label[task_name])
|
||||
task_sample = DataSample().set_gt_label(gt_label[task_name])
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
with self.assertRaises(Exception):
|
||||
head.loss(feats, data_sample)
|
||||
|
@ -606,11 +606,11 @@ class TestMultiTaskHead(TestCase):
|
|||
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
data_sample = MultiTaskDataSample()
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
|
||||
task_sample = DataSample().set_gt_label(gt_label['task1'])
|
||||
data_sample.set_field(task_sample, 'task1')
|
||||
data_sample.set_field(MultiTaskDataSample(), 'task0')
|
||||
for task_name in gt_label['task0']:
|
||||
task_sample = ClsDataSample().set_gt_label(
|
||||
task_sample = DataSample().set_gt_label(
|
||||
gt_label['task0'][task_name])
|
||||
data_sample.task0.set_field(task_sample, task_name)
|
||||
with self.assertRaises(Exception):
|
||||
|
@ -694,7 +694,7 @@ class TestArcFaceClsHead(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
|
||||
data_samples = [DataSample().set_gt_label(1) for _ in range(4)]
|
||||
|
||||
# test loss with used='before'
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
|
|
@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset
|
|||
|
||||
from mmpretrain.datasets.transforms import PackClsInputs
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
@ -125,7 +125,7 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
def test_loss(self):
|
||||
inputs = torch.rand(1, 3, 64, 64)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
|
||||
model = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = model.loss(inputs, data_samples)
|
||||
|
@ -172,32 +172,32 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
def test_predict(self):
|
||||
inputs = torch.rand(1, 3, 64, 64)
|
||||
data_samples = [ClsDataSample().set_gt_label([1, 2, 6])]
|
||||
data_samples = [DataSample().set_gt_label([1, 2, 6])]
|
||||
# default
|
||||
model = MODELS.build(self.DEFAULT_ARGS)
|
||||
predictions = model.predict(inputs)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
predictions = model.predict(inputs, data_samples)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
# k is not -1
|
||||
cfg = {**self.DEFAULT_ARGS, 'topk': 2}
|
||||
model = MODELS.build(cfg)
|
||||
|
||||
predictions = model.predict(inputs)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
predictions = model.predict(inputs, data_samples)
|
||||
assert predictions is data_samples
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
|
||||
|
||||
def test_forward(self):
|
||||
inputs = torch.rand(1, 3, 64, 64)
|
||||
data_samples = [ClsDataSample().set_gt_label(1)]
|
||||
data_samples = [DataSample().set_gt_label(1)]
|
||||
model = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
# test pure forward
|
||||
|
@ -213,13 +213,13 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
# test forward test
|
||||
predictions = model(inputs, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
predictions = model(inputs, data_samples, mode='predict')
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_label.score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_label.score,
|
||||
predictions[0].pred_label.score)
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
self.assertEqual(data_samples[0].pred_score.shape, (10, ))
|
||||
torch.testing.assert_allclose(data_samples[0].pred_score,
|
||||
predictions[0].pred_score)
|
||||
|
||||
# test forward with invalid mode
|
||||
with self.assertRaisesRegex(RuntimeError, 'Invalid mode "unknown"'):
|
||||
|
@ -234,7 +234,7 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
optim_wrapper = MagicMock()
|
||||
|
@ -251,11 +251,11 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.val_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
||||
def test_test_step(self):
|
||||
cfg = {
|
||||
|
@ -266,8 +266,8 @@ class TestImageToImageRetriever(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 64, 64)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.test_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
self.assertEqual(predictions[0].pred_score.shape, (10, ))
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmengine.registry import init_default_scope
|
|||
|
||||
from mmpretrain.models import AverageClsScoreTTA, ImageClassifier
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
init_default_scope('mmpretrain')
|
||||
|
||||
|
@ -48,20 +48,20 @@ class TestAverageClsScoreTTA(TestCase):
|
|||
img2 = torch.randint(0, 256, (1, 3, 224, 224))
|
||||
data1 = {
|
||||
'inputs': img1,
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
data2 = {
|
||||
'inputs': img2,
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
data_tta = {
|
||||
'inputs': [img1, img2],
|
||||
'data_samples': [[ClsDataSample().set_gt_label(1)],
|
||||
[ClsDataSample().set_gt_label(1)]]
|
||||
'data_samples': [[DataSample().set_gt_label(1)],
|
||||
[DataSample().set_gt_label(1)]]
|
||||
}
|
||||
|
||||
score1 = model.module.test_step(data1)[0].pred_label.score
|
||||
score2 = model.module.test_step(data2)[0].pred_label.score
|
||||
score_tta = model.test_step(data_tta)[0].pred_label.score
|
||||
score1 = model.module.test_step(data1)[0].pred_score
|
||||
score2 = model.module.test_step(data2)[0].pred_score
|
||||
score_tta = model.test_step(data_tta)[0].pred_score
|
||||
|
||||
torch.testing.assert_allclose(score_tta, (score1 + score2) / 2)
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
|
||||
from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class TestClsDataPreprocessor(TestCase):
|
||||
|
@ -16,15 +16,14 @@ class TestClsDataPreprocessor(TestCase):
|
|||
|
||||
data = {
|
||||
'inputs': [torch.randint(0, 256, (3, 224, 224))],
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
processed_data = processor(data)
|
||||
inputs = processed_data['inputs']
|
||||
data_samples = processed_data['data_samples']
|
||||
self.assertEqual(inputs.shape, (1, 3, 224, 224))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
self.assertTrue(
|
||||
(data_samples[0].gt_label.label == torch.tensor([1])).all())
|
||||
self.assertTrue((data_samples[0].gt_label == torch.tensor([1])).all())
|
||||
|
||||
def test_padding(self):
|
||||
cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16)
|
||||
|
@ -87,7 +86,7 @@ class TestClsDataPreprocessor(TestCase):
|
|||
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
|
||||
data = {
|
||||
'inputs': [torch.randint(0, 256, (3, 224, 224))],
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
'data_samples': [DataSample().set_gt_label(1)]
|
||||
}
|
||||
processed_data = processor(data, training=True)
|
||||
self.assertIn('inputs', processed_data)
|
||||
|
|
|
@ -3,58 +3,51 @@ from unittest import TestCase
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmpretrain.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmpretrain.structures import DataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
class TestClsDataSample(TestCase):
|
||||
class TestDataSample(TestCase):
|
||||
|
||||
def _test_set_label(self, key):
|
||||
data_sample = ClsDataSample()
|
||||
data_sample = DataSample()
|
||||
method = getattr(data_sample, 'set_' + key)
|
||||
# Test number
|
||||
method(1)
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
self.assertIsInstance(label, torch.LongTensor)
|
||||
|
||||
# Test tensor with single number
|
||||
method(torch.tensor(2))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
self.assertIsInstance(label, torch.LongTensor)
|
||||
|
||||
# Test array with single number
|
||||
method(np.array(3))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.label, torch.LongTensor)
|
||||
self.assertIsInstance(label, torch.LongTensor)
|
||||
|
||||
# Test tensor
|
||||
method(torch.tensor([1, 2, 3]))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertIsInstance(label.label, torch.Tensor)
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
self.assertIsInstance(label, torch.Tensor)
|
||||
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test array
|
||||
method(np.array([1, 2, 3]))
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test Sequence
|
||||
method([1, 2, 3])
|
||||
self.assertIn(key, data_sample)
|
||||
label = getattr(data_sample, key)
|
||||
self.assertIsInstance(label, LabelData)
|
||||
self.assertTrue((label.label == torch.tensor([1, 2, 3])).all())
|
||||
self.assertTrue((label == torch.tensor([1, 2, 3])).all())
|
||||
|
||||
# Test unavailable type
|
||||
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
|
||||
|
@ -66,34 +59,13 @@ class TestClsDataSample(TestCase):
|
|||
def test_set_pred_label(self):
|
||||
self._test_set_label('pred_label')
|
||||
|
||||
def test_del_gt_label(self):
|
||||
data_sample = ClsDataSample()
|
||||
self.assertNotIn('gt_label', data_sample)
|
||||
data_sample.set_gt_label(1)
|
||||
self.assertIn('gt_label', data_sample)
|
||||
del data_sample.gt_label
|
||||
self.assertNotIn('gt_label', data_sample)
|
||||
|
||||
def test_del_pred_label(self):
|
||||
data_sample = ClsDataSample()
|
||||
self.assertNotIn('pred_label', data_sample)
|
||||
data_sample.set_pred_label(1)
|
||||
self.assertIn('pred_label', data_sample)
|
||||
del data_sample.pred_label
|
||||
self.assertNotIn('pred_label', data_sample)
|
||||
|
||||
def test_set_gt_score(self):
|
||||
data_sample = ClsDataSample()
|
||||
data_sample = DataSample()
|
||||
data_sample.set_gt_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertIn('score', data_sample.gt_label)
|
||||
torch.testing.assert_allclose(data_sample.gt_label.score,
|
||||
self.assertIn('gt_score', data_sample)
|
||||
torch.testing.assert_allclose(data_sample.gt_score,
|
||||
[0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
|
||||
# Test set again
|
||||
data_sample.set_gt_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
|
||||
torch.testing.assert_allclose(data_sample.gt_label.score,
|
||||
[0.2, 0.1, 0.5, 0.1, 0.1])
|
||||
|
||||
# Test invalid length
|
||||
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
|
||||
data_sample.set_gt_score([1, 2])
|
||||
|
@ -103,17 +75,12 @@ class TestClsDataSample(TestCase):
|
|||
data_sample.set_gt_score(torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
|
||||
|
||||
def test_set_pred_score(self):
|
||||
data_sample = ClsDataSample()
|
||||
data_sample = DataSample()
|
||||
data_sample.set_pred_score(torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]))
|
||||
self.assertIn('score', data_sample.pred_label)
|
||||
torch.testing.assert_allclose(data_sample.pred_label.score,
|
||||
self.assertIn('pred_score', data_sample)
|
||||
torch.testing.assert_allclose(data_sample.pred_score,
|
||||
[0.1, 0.1, 0.6, 0.1, 0.1])
|
||||
|
||||
# Test set again
|
||||
data_sample.set_pred_score(torch.tensor([0.2, 0.1, 0.5, 0.1, 0.1]))
|
||||
torch.testing.assert_allclose(data_sample.pred_label.score,
|
||||
[0.2, 0.1, 0.5, 0.1, 0.1])
|
||||
|
||||
# Test invalid length
|
||||
with self.assertRaisesRegex(AssertionError, 'should be equal to'):
|
||||
data_sample.set_gt_score([1, 2])
|
||||
|
@ -129,13 +96,13 @@ class TestMultiTaskDataSample(TestCase):
|
|||
def test_multi_task_data_sample(self):
|
||||
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
|
||||
data_sample = MultiTaskDataSample()
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
|
||||
task_sample = DataSample().set_gt_label(gt_label['task1'])
|
||||
data_sample.set_field(task_sample, 'task1')
|
||||
data_sample.set_field(MultiTaskDataSample(), 'task0')
|
||||
for task_name in gt_label['task0']:
|
||||
task_sample = ClsDataSample().set_gt_label(
|
||||
task_sample = DataSample().set_gt_label(
|
||||
gt_label['task0'][task_name])
|
||||
data_sample.task0.set_field(task_sample, task_name)
|
||||
self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
|
||||
self.assertIsInstance(data_sample.task1, ClsDataSample)
|
||||
self.assertIsInstance(data_sample.task0.task00, ClsDataSample)
|
||||
self.assertIsInstance(data_sample.task1, DataSample)
|
||||
self.assertIsInstance(data_sample.task0.task00, DataSample)
|
||||
|
|
|
@ -2,10 +2,9 @@
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmpretrain.structures import (batch_label_to_onehot, cat_batch_labels,
|
||||
stack_batch_scores, tensor_split)
|
||||
tensor_split)
|
||||
|
||||
|
||||
class TestStructureUtils(TestCase):
|
||||
|
@ -28,11 +27,11 @@ class TestStructureUtils(TestCase):
|
|||
|
||||
def test_cat_batch_labels(self):
|
||||
labels = [
|
||||
LabelData(label=torch.tensor([1])),
|
||||
LabelData(label=torch.tensor([3, 2])),
|
||||
LabelData(label=torch.tensor([0, 1, 4])),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
torch.tensor([1]),
|
||||
torch.tensor([3, 2]),
|
||||
torch.tensor([0, 1, 4]),
|
||||
torch.tensor([], dtype=torch.int64),
|
||||
torch.tensor([], dtype=torch.int64),
|
||||
]
|
||||
|
||||
batch_label, split_indices = cat_batch_labels(labels)
|
||||
|
@ -45,42 +44,13 @@ class TestStructureUtils(TestCase):
|
|||
self.assertEqual(labels[3].tolist(), [])
|
||||
self.assertEqual(labels[4].tolist(), [])
|
||||
|
||||
labels = [
|
||||
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
|
||||
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
|
||||
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
|
||||
]
|
||||
batch_label, split_indices = cat_batch_labels(labels)
|
||||
self.assertIsNone(batch_label)
|
||||
self.assertIsNone(split_indices)
|
||||
|
||||
def test_stack_batch_scores(self):
|
||||
labels = [
|
||||
LabelData(score=torch.tensor([0, 1, 0, 0, 1])),
|
||||
LabelData(score=torch.tensor([0, 0, 1, 0, 0])),
|
||||
LabelData(score=torch.tensor([1, 0, 0, 1, 0])),
|
||||
]
|
||||
|
||||
batch_score = stack_batch_scores(labels)
|
||||
self.assertEqual(batch_score.shape, (3, 5))
|
||||
|
||||
labels = [
|
||||
LabelData(label=torch.tensor([1])),
|
||||
LabelData(label=torch.tensor([3, 2])),
|
||||
LabelData(label=torch.tensor([0, 1, 4])),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
]
|
||||
batch_score = stack_batch_scores(labels)
|
||||
self.assertIsNone(batch_score)
|
||||
|
||||
def test_batch_label_to_onehot(self):
|
||||
labels = [
|
||||
LabelData(label=torch.tensor([1])),
|
||||
LabelData(label=torch.tensor([3, 2])),
|
||||
LabelData(label=torch.tensor([0, 1, 4])),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
LabelData(label=torch.tensor([], dtype=torch.int64)),
|
||||
torch.tensor([1]),
|
||||
torch.tensor([3, 2]),
|
||||
torch.tensor([0, 1, 4]),
|
||||
torch.tensor([], dtype=torch.int64),
|
||||
torch.tensor([], dtype=torch.int64),
|
||||
]
|
||||
|
||||
batch_label, split_indices = cat_batch_labels(labels)
|
||||
|
|
|
@ -7,7 +7,7 @@ from unittest.mock import patch
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmpretrain.structures import ClsDataSample
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.visualization import ClsVisualizer
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ class TestClsVisualizer(TestCase):
|
|||
|
||||
def test_add_datasample(self):
|
||||
image = np.ones((10, 10, 3), np.uint8)
|
||||
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
|
||||
data_sample = DataSample().set_gt_label(1).set_pred_label(1).\
|
||||
set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
|
||||
|
||||
# Test show
|
||||
|
@ -82,7 +82,7 @@ class TestClsVisualizer(TestCase):
|
|||
'test', image=image, data_sample=data_sample, draw_gt=False)
|
||||
|
||||
# Test without score
|
||||
del data_sample.pred_label.score
|
||||
del data_sample.pred_score
|
||||
|
||||
def test_texts(text, *_, **__):
|
||||
self.assertEqual(
|
||||
|
|
Loading…
Reference in New Issue