[Feature] Support Multi-task. (#1229)

* unit test for multi_task_head

* [Feature] MultiTaskHead (#628, #481)

* [Fix] lint for multi_task_head

* [Feature] Add `MultiTaskDataset` to support multi-task training.

* Update MultiTaskClsHead

* Update docs

* [CI] Add test mim CI. (#879)

* [Fix] Remove duplicated wide-resnet metafile.

* [Feature] Support MPS device. (#894)

* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests

* [Fix] Fix Albu crash bug. (#918)

* Fix Albu bug: using Albu changes the label from array(x) to array([x]) and crashes the training

* Fix common

* Use copy in case of a potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf

* Bump version to 0.23.2. (#937)

* [Improve] Use `forward_dummy` to calculate FLOPS. (#953)

* Update README

* [Docs] Fix typo for wrong reference. (#1036)

* [Doc] Fix typo in tutorial 2 (#1043)

* [Docs] Fix a typo in ImageClassifier (#1050)

* add mask to loss

* add another pipeline

* adapt the pipeline if there is no mask

* switch mask and task

* first version of multi data sample

* fix problem with attribute by using getattr

* rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label'

* training  without evaluation

* first version work

* add others metrics

* delete evaluation from dataset

* fix linter

* fix linter

* multi metrics

* first version of test

* change evaluate metric

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* add tests

* add test for multidatasample

* create a generic test

* create a generic test

* create a generic test

* change multi data sample

* correct test

* test

* add new test

* add test for dataset

* correct test

* correct test

* correct test

* correct test

* fix : #5

* run yapf

* fix linter

* fix linter

* fix linter

* fix isort

* fix isort

* fix docformatter

* fix docformatter

* fix linter

* fix linter

* fix data sample

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update data sample

* update head

* update head

* update multi data sample

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* update head

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix linter

* fix : #2

* fix : linter

* update multi head

* fix linter

* fix linter

* update data sample

* update data sample

* fix: linter

* update test

* test pipeline

* update pipeline

* update test

* update dataset

* update dataset

* fix linter

* fix linter

* update formatting

* add test for multi-task-eval

* update formatting

* fix linter

* update test

* update

* add test

* update metrics

* update metrics

* add doc for functions

* fix linter

* training for multitask 1.x

* fix linter

* run flake8

* run linter

* update test

* add mask in evaluation

* update metric doc

* update metric doc

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update metric doc

* update metric doc

* Fix cannot import name MultiTaskDataSample

* fix test_datasets

* fix test_datasets

* fix linter

* add an example of multitask

* change name of configs dataset

* Refactor the multi-task support

* correct test and metric

* add test to multidatasample

* add test to multidatasample

* correct test

* correct metrics and clshead

* Update mmcls/models/heads/cls_head.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update cls_head.py documentation

* lint

* lint

* fix: lint

* fix linter

* add eval mask

* fix documentation

* fix: single_label.py back to 1.x

* Update mmcls/models/heads/multi_task_head.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Remove multi-task configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: marouaneamz <maroineamil99@gmail.com>
Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
pull/1285/head
Colle 2022-12-30 03:36:00 +01:00 committed by GitHub
parent 5b266d9e7c
commit bac181f393
19 changed files with 1185 additions and 41 deletions
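Taken together, the diffs below thread multi-task support through the dataset, pipeline, head and metric layers. For orientation, here is a minimal config sketch (not part of this commit) showing how the new pieces are intended to compose; the task names, class counts, backbone and data paths are placeholder assumptions borrowed from the docstrings and unit tests further down.

# Hedged sketch only -- task names, paths and class counts are placeholders.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(224, 224)),
    # Packs the per-task labels of one image into a single MultiTaskDataSample.
    dict(type='PackMultiTaskInputs', multi_task_fields=('gt_label', )),
]

train_dataloader = dict(
    batch_size=32,
    dataset=dict(
        type='MultiTaskDataset',
        data_root='data/mydataset',          # assumed layout
        ann_file='annotation/train.json',
        data_prefix='train',
        pipeline=train_pipeline))

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskHead',
        task_heads=dict(
            task0=dict(type='LinearClsHead', num_classes=3),
            task1=dict(type='LinearClsHead', num_classes=6)),
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))

The evaluation side is covered by the new `MultiTasksMetric`; a matching evaluator sketch follows that file below.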

View File

@ -8,6 +8,7 @@ from .dataset_wrappers import KFoldDataset
from .imagenet import ImageNet, ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .multi_task import MultiTaskDataset
from .samplers import * # noqa: F401,F403
from .transforms import * # noqa: F401,F403
from .voc import VOC
@ -15,5 +16,5 @@ from .voc import VOC
__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
'CustomDataset', 'MultiLabelDataset'
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
]

View File

@ -0,0 +1,344 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence
import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import FileClient
from .builder import DATASETS
def expanduser(path):
if isinstance(path, (str, PathLike)):
return osp.expanduser(path)
else:
return path
def isabs(uri):
return osp.isabs(uri) or ('://' in uri)
@DATASETS.register_module()
class MultiTaskDataset:
"""Custom dataset for multi-task dataset.
To use the dataset, please generate and provide an annotation file in the
below format:
.. code-block:: json
{
"metainfo": {
"tasks":
[
'gender'
'wear'
]
},
"data_list": [
{
"img_path": "a.jpg",
"gt_label": {
"gender": 0,
"wear": [1, 0, 1, 0]
}
},
{
"img_path": "b.jpg",
"gt_label": {
"gender": 1,
"wear": [1, 0, 1, 0]
}
}
]
}
Assume we put our dataset in the ``data/mydataset`` folder in the
repository and organize it as the below format: ::
mmclassification/
└── data
    └── mydataset
        ├── annotation
        │   ├── train.json
        │   ├── test.json
        │   └── val.json
        ├── train
        │   ├── a.jpg
        │   └── ...
        ├── test
        │   ├── b.jpg
        │   └── ...
        └── val
            ├── c.jpg
            └── ...
We can use the below config to build datasets:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="annotation/train.json",
... data_root="data/mydataset",
... # The `img_path` field in the train annotation file is relative
... # to the `train` folder.
... data_prefix='train',
... )
>>> train_dataset = build_dataset(train_cfg)
Or we can put all files in the same folder: ::
mmclassification/
└── data
    └── mydataset
        ├── train.json
        ├── test.json
        ├── val.json
        ├── a.jpg
        ├── b.jpg
        ├── c.jpg
        └── ...
And we can use the below config to build datasets:
.. code:: python
>>> from mmcls.datasets import build_dataset
>>> train_cfg = dict(
... type="MultiTaskDataset",
... ann_file="train.json",
... data_root="data/mydataset",
... # the `data_prefix` is not required since all paths are
... # relative to the `data_root`.
... )
>>> train_dataset = build_dataset(train_cfg)
Args:
ann_file (str): The annotation file path. It can be either absolute
path or relative path to the ``data_root``.
metainfo (dict, optional): The extra meta information. It should be
a dict with the same format as the ``"metainfo"`` field in the
annotation file. Defaults to None.
data_root (str, optional): The root path of the data directory. It's
the prefix of the ``data_prefix`` and the ``ann_file``. And it can
be a remote path like "s3://openmmlab/xxx/". Defaults to None.
data_prefix (str, optional): The base folder relative to the
``data_root`` for the ``"img_path"`` field in the annotation file.
Defaults to None.
pipeline (Sequence[dict]): A list of dict, where each element
represents an operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
test_mode (bool): Whether in test mode. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
If None, automatically infer it from the ``data_root``.
Defaults to None.
"""
METAINFO = dict()
def __init__(self,
ann_file: str,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: Optional[str] = None,
pipeline: Sequence = (),
test_mode: bool = False,
file_client_args: Optional[dict] = None):
self.data_root = expanduser(data_root)
# Inference the file client
if self.data_root is not None:
file_client = FileClient.infer_client(
file_client_args, uri=self.data_root)
else:
file_client = FileClient(file_client_args)
self.file_client: FileClient = file_client
self.ann_file = self._join_root(expanduser(ann_file))
self.data_prefix = self._join_root(data_prefix)
self.test_mode = test_mode
self.pipeline = Compose(pipeline)
self.data_list = self.load_data_list(self.ann_file, metainfo)
def _join_root(self, path):
"""Join ``self.data_root`` with the specified path.
If the path is an absolute path, just return the path. And if the
path is None, return ``self.data_root``.
Examples:
>>> self.data_root = 'a/b/c'
>>> self._join_root('d/e/')
'a/b/c/d/e'
>>> self._join_root('https://openmmlab.com')
'https://openmmlab.com'
>>> self._join_root(None)
'a/b/c'
"""
if path is None:
return self.data_root
if isabs(path):
return path
joined_path = self.file_client.join_path(self.data_root, path)
return joined_path
@classmethod
def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
Args:
in_metainfo (dict): Meta information dict.
Returns:
dict: Parsed meta information.
"""
# `cls.METAINFO` will be overwritten by in_metainfo
metainfo = copy.deepcopy(cls.METAINFO)
if in_metainfo is None:
return metainfo
metainfo.update(in_metainfo)
return metainfo
def load_data_list(self, ann_file, metainfo_override=None):
"""Load annotations from an annotation file.
Args:
ann_file (str): Absolute annotation file path if ``self.data_root``
is None, or a path relative to ``self.data_root`` otherwise.
Returns:
list[dict]: A list of annotation.
"""
annotations = mmengine.load(ann_file)
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '
f'should be a dict, but got {type(annotations)}!')
if 'data_list' not in annotations:
raise ValueError('The annotation file must have the `data_list` '
'field.')
metainfo = annotations.get('metainfo', {})
raw_data_list = annotations['data_list']
# Set meta information.
assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
f'annotation file should be a dict, but got {type(metainfo)}'
if metainfo_override is not None:
assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
f'argument should be a dict, but got {type(metainfo_override)}'
metainfo.update(metainfo_override)
self._metainfo = self._get_meta_info(metainfo)
data_list = []
for i, raw_data in enumerate(raw_data_list):
try:
data_list.append(self.parse_data_info(raw_data))
except AssertionError as e:
raise RuntimeError(
f'The format check fails while parsing item {i} of '
f'the annotation file with error: {e}')
return data_list
def parse_data_info(self, raw_data):
"""Parse raw annotation to target format.
This method will return a dict which contains the data information of a
sample.
Args:
raw_data (dict): Raw data information load from ``ann_file``
Returns:
dict: Parsed annotation.
"""
assert isinstance(raw_data, dict), \
f'The item should be a dict, but got {type(raw_data)}'
assert 'img_path' in raw_data, \
"The item doesn't have `img_path` field."
data = dict(
img_path=self._join_root(raw_data['img_path']),
gt_label=raw_data['gt_label'],
)
return data
@property
def metainfo(self) -> dict:
"""Get meta information of dataset.
Returns:
dict: meta information collected from ``cls.METAINFO``,
annotation file and metainfo argument during instantiation.
"""
return copy.deepcopy(self._metainfo)
def prepare_data(self, idx):
"""Get data processed by ``self.pipeline``.
Args:
idx (int): The index of ``data_info``.
Returns:
Any: Depends on ``self.pipeline``.
"""
results = copy.deepcopy(self.data_list[idx])
return self.pipeline(results)
def __len__(self):
"""Get the length of the whole dataset.
Returns:
int: The length of the dataset.
"""
return len(self.data_list)
def __getitem__(self, idx):
"""Get the idx-th image and data information of dataset after
``self.pipeline``.
Args:
idx (int): The index of the data.
Returns:
dict: The idx-th image and data information after
``self.pipeline``.
"""
return self.prepare_data(idx)
def __repr__(self):
"""Print the basic information of the dataset.
Returns:
str: Formatted string.
"""
head = 'Dataset ' + self.__class__.__name__
body = [f'Number of samples: \t{self.__len__()}']
if self.data_root is not None:
body.append(f'Root location: \t{self.data_root}')
body.append(f'Annotation file: \t{self.ann_file}')
if self.data_prefix is not None:
body.append(f'Prefix of images: \t{self.data_prefix}')
# -------------------- extra repr --------------------
tasks = self.metainfo['tasks']
body.append(f'For {len(tasks)} tasks')
for task in tasks:
body.append(f' {task} ')
# ----------------------------------------------------
if len(self.pipeline.transforms) > 0:
body.append('With transforms:')
for t in self.pipeline.transforms:
body.append(f' {t}')
lines = [head] + [' ' * 4 + line for line in body]
return '\n'.join(lines)
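A minimal smoke test of the class above, mirroring the dataset unit test near the end of this diff; the data root and annotation path are assumed, not shipped with the commit.

# Assumed paths; the annotation file must follow the format documented above.
from mmcls.datasets import MultiTaskDataset

dataset = MultiTaskDataset(
    data_root='data/mydataset',
    ann_file='annotation/train.json',
    pipeline=[])

print(dataset.metainfo)   # e.g. {'tasks': ['gender', 'wear']}
print(len(dataset))       # number of entries in `data_list`
print(dataset[0])         # dict with 'img_path' and a per-task 'gt_label' dict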

View File

@ -3,7 +3,8 @@ from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
Brightness, ColorTransform, Contrast, Cutout,
Equalize, Invert, Posterize, RandAugment, Rotate,
Sharpness, Shear, Solarize, SolarizeAdd, Translate)
from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose
from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop, ResizeEdge)
@ -15,5 +16,6 @@ __all__ = [
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform'
'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
'PackMultiTaskInputs'
]

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
import numpy as np
import torch
@ -9,7 +10,7 @@ from mmengine.utils import is_str
from PIL import Image
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample
from mmcls.structures import ClsDataSample, MultiTaskDataSample
def to_tensor(data):
@ -85,12 +86,6 @@ class PackClsInputs(BaseTransform):
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
else:
warnings.warn(
'Cannot get "img" in the input dict of `PackClsInputs`,'
'please make sure `LoadImageFromFile` has been added '
'in the data pipeline or images have been loaded in '
'the dataset.')
data_sample = ClsDataSample()
if 'gt_label' in results:
@ -100,7 +95,6 @@ class PackClsInputs(BaseTransform):
img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self) -> str:
@ -109,6 +103,84 @@ class PackClsInputs(BaseTransform):
return repr_str
@TRANSFORMS.register_module()
class PackMultiTaskInputs(BaseTransform):
"""Convert all image labels of multi-task dataset to a dict of tensor.
Args:
tasks (List[str]): The task names defined in the dataset.
meta_keys(Sequence[str]): The meta keys to be saved in the
``metainfo`` of the packed ``data_samples``.
Defaults to a tuple includes keys:
- ``sample_idx``: The id of the image sample.
- ``img_path``: The path to the image file.
- ``ori_shape``: The original shape of the image as a tuple (H, W).
- ``img_shape``: The shape of the image after the pipeline as a
tuple (H, W).
- ``scale_factor``: The scale factor between the resized image and
the original image.
- ``flip``: A boolean indicating if image flip transform was used.
- ``flip_direction``: The flipping direction.
"""
def __init__(self,
task_handlers=dict(),
multi_task_fields=('gt_label', ),
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')):
self.multi_task_fields = multi_task_fields
self.meta_keys = meta_keys
self.task_handlers = defaultdict(
partial(PackClsInputs, meta_keys=meta_keys))
for task_name, task_handler in task_handlers.items():
self.task_handlers[task_name] = TRANSFORMS.build(
dict(type=task_handler, meta_keys=meta_keys))
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3},
'img': array([[[ 0, 0, 0])
"""
packed_results = dict()
results = results.copy()
if 'img' in results:
img = results.pop('img')
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
task_results = defaultdict(dict)
for field in self.multi_task_fields:
if field in results:
value = results.pop(field)
for k, v in value.items():
task_results[k].update({field: v})
data_sample = MultiTaskDataSample()
for task_name, task_result in task_results.items():
task_handler = self.task_handlers[task_name]
task_pack_result = task_handler({**results, **task_result})
data_sample.set_field(task_pack_result['data_samples'], task_name)
packed_results['data_samples'] = data_sample
return packed_results
def __repr__(self):
repr = self.__class__.__name__
task_handlers = {
name: handler.__class__.__name__
for name, handler in self.task_handlers.items()
}
repr += f'(task_handlers={task_handlers}, '
repr += f'multi_task_fields={self.multi_task_fields}, '
repr += f'meta_keys={self.meta_keys})'
return repr
@TRANSFORMS.register_module()
class Transpose(BaseTransform):
"""Transpose numpy array.

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .single_label import Accuracy, SingleLabelMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
__all__ = [
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
'VOCAveragePrecision', 'VOCMultiLabelMetric'
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric'
]

View File

@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence
from mmengine.evaluator import BaseMetric
from mmcls.registry import METRICS
@METRICS.register_module()
class MultiTasksMetric(BaseMetric):
"""Metrics for MultiTask
Args:
task_metrics(dict): a dictionary in the keys are the names of the tasks
and the values is a list of the metric corresponds to this task
Examples:
>>> import torch
>>> from mmcls.evaluation import MultiTasksMetric
# -------------------- The Basic Usage --------------------
>>> task_metrics = {
'task0': [dict(type='Accuracy', topk=(1, ))],
'task1': [dict(type='Accuracy', topk=(1, 3))]
}
>>> pred = [{
'pred_task': {
'task0': torch.tensor([0.7, 0.0, 0.3]),
'task1': torch.tensor([0.5, 0.2, 0.3])
},
'gt_task': {
'task0': torch.tensor(0),
'task1': torch.tensor(2)
}
}, {
'pred_task': {
'task0': torch.tensor([0.0, 0.0, 1.0]),
'task1': torch.tensor([0.0, 0.0, 1.0])
},
'gt_task': {
'task0': torch.tensor(2),
'task1': torch.tensor(2)
}
}]
>>> metric = MultiTasksMetric(task_metrics)
>>> metric.process(None, pred)
>>> results = metric.evaluate(2)
results = {
'task0_accuracy/top1': 100.0,
'task1_accuracy/top1': 50.0,
'task1_accuracy/top3': 100.0
}
"""
def __init__(self,
task_metrics: Dict,
collect_device: str = 'cpu') -> None:
self.task_metrics = task_metrics
super().__init__(collect_device=collect_device)
self._metrics = {}
for task_name in self.task_metrics.keys():
self._metrics[task_name] = []
for metric in self.task_metrics[task_name]:
self._metrics[task_name].append(METRICS.build(metric))
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for task_name in self.task_metrics.keys():
filtered_data_samples = []
for data_sample in data_samples:
eval_mask = data_sample[task_name]['eval_mask']
if eval_mask:
filtered_data_samples.append(data_sample[task_name])
for metric in self._metrics[task_name]:
metric.process(data_batch, filtered_data_samples)
def compute_metrics(self, results: list) -> dict:
raise NotImplementedError(
'compute metrics should not be used here directly')
def evaluate(self, size):
"""Evaluate the model performance of the whole dataset after processing
all batches.
Args:
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data based on
this size.
Returns:
dict: Evaluation metrics dict on the val dataset. The keys are
"{task_name}_{metric_name}" , and the values
are corresponding results.
"""
metrics = {}
for task_name in self._metrics:
for metric in self._metrics[task_name]:
name = metric.__class__.__name__
if name == 'MultiTasksMetric' or metric.results:
results = metric.evaluate(size)
else:
results = {metric.__class__.__name__: 0}
for key in results:
name = f'{task_name}_{key}'
if name in metrics:
"""Inspired from https://github.com/open-
mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2
747172e/mmengine/evalua tor/evaluator.py#L84-L87."""
raise ValueError(
'There are multiple metric results with the same'
f'metric name {name}. Please make sure all metrics'
'have different prefixes.')
metrics[name] = results[key]
return metrics
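For completeness, a hedged evaluator config sketch built on the metric above; the task names and sub-metrics are placeholders taken from the unit test at the end of this diff. Reported keys are prefixed per task, e.g. 'task0_accuracy/top1' and 'task1_single-label/precision'.

# Illustrative `val_evaluator` config; adapt task names and metrics to your data.
val_evaluator = dict(
    type='MultiTasksMetric',
    task_metrics=dict(
        task0=[dict(type='Accuracy', topk=(1, ))],
        task1=[
            dict(type='Accuracy', topk=(1, 3)),
            dict(type='SingleLabelMetric', items=['precision', 'recall']),
        ]))
test_evaluator = val_evaluator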

View File

@ -8,11 +8,13 @@ from .margin_head import ArcFaceClsHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .multi_task_head import MultiTaskHead
from .stacked_head import StackedLinearClsHead
from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead'
'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead',
'MultiTaskHead'
]

View File

@ -108,9 +108,10 @@ class ClsHead(BaseHead):
return losses
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
self,
feats: Tuple[torch.Tensor],
data_samples: List[Union[ClsDataSample, None]] = None
) -> List[ClsDataSample]:
"""Inference without augmentation.
Args:
@ -118,7 +119,7 @@ class ClsHead(BaseHead):
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample], optional): The annotation
data_samples (List[ClsDataSample | None], optional): The annotation
data of every samples. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
@ -141,14 +142,15 @@ class ClsHead(BaseHead):
pred_scores = F.softmax(cls_score, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
if data_samples is not None:
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
data_sample.set_pred_score(score).set_pred_label(label)
else:
data_samples = []
for score, label in zip(pred_scores, pred_labels):
data_samples.append(ClsDataSample().set_pred_score(
score).set_pred_label(label))
out_data_samples = []
if data_samples is None:
data_samples = [None for _ in range(pred_scores.size(0))]
return data_samples
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = ClsDataSample()
data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)
return out_data_samples
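The predict refactor above also means the `data_samples` list may now contain `None` entries; below is a small behavior sketch under that assumption (the head config itself is illustrative).

import torch
from mmcls.registry import MODELS
from mmcls.utils import register_all_modules

register_all_modules()
head = MODELS.build(
    dict(type='LinearClsHead', num_classes=5, in_channels=10))
feats = (torch.rand(2, 10), )

# A fresh ClsDataSample is created for every None entry.
predictions = head.predict(feats, data_samples=[None, None])
assert len(predictions) == 2
assert predictions[0].pred_label.score.shape == (5, )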

View File

@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from mmengine.model import ModuleDict
from mmcls.registry import MODELS
from mmcls.structures import MultiTaskDataSample
from .base_head import BaseHead
def loss_convertor(loss_func, task_name):
def wrapped(inputs, data_samples, **kwargs):
mask = torch.empty(len(data_samples), dtype=torch.bool)
task_data_samples = []
for i, data_sample in enumerate(data_samples):
assert isinstance(data_sample, MultiTaskDataSample)
sample_mask = task_name in data_sample
mask[i] = sample_mask
if sample_mask:
task_data_samples.append(data_sample.get(task_name))
if len(task_data_samples) == 0:
return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
# Mask the inputs of the task
def mask_inputs(inputs, mask):
if isinstance(inputs, Sequence):
return type(inputs)(
[mask_inputs(input, mask) for input in inputs])
elif isinstance(inputs, torch.Tensor):
return inputs[mask]
masked_inputs = mask_inputs(inputs, mask)
loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
loss_output['mask_size'] = mask.sum().to(torch.float)
return loss_output
return wrapped
@MODELS.register_module()
class MultiTaskHead(BaseHead):
"""Multi task head.
Args:
task_heads (dict): Sub heads to use. The keys are task names and will
be used to prefix the corresponding loss components.
init_cfg (dict, optional): The extra initialization settings.
Defaults to None.
**kwargs: Other keyword arguments used as default arguments when
building the sub heads.
"""
def __init__(self, task_heads, init_cfg=None, **kwargs):
super(MultiTaskHead, self).__init__(init_cfg=init_cfg)
assert isinstance(task_heads, dict), 'The `task_heads` argument ' \
'should be a dict, whose keys are task names and values are ' \
'configs of the head for each task.'
self.task_heads = ModuleDict()
for task_name, sub_head in task_heads.items():
if not isinstance(sub_head, nn.Module):
sub_head = MODELS.build(sub_head, default_args=kwargs)
sub_head.loss = loss_convertor(sub_head.loss, task_name)
self.task_heads[task_name] = sub_head
def forward(self, feats):
"""The forward process."""
return {
task_name: head(feats)
for task_name, head in self.task_heads.items()
}
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
data_samples (List[MultiTaskDataSample]): The annotation data of
every sample.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components, each task loss
key will be prefixed by the task_name like "task1_loss"
"""
losses = dict()
for task_name, head in self.task_heads.items():
head_loss = head.loss(feats, data_samples, **kwargs)
for k, v in head_loss.items():
losses[f'{task_name}_{k}'] = v
return losses
def predict(
self,
feats: Tuple[torch.Tensor],
data_samples: List[MultiTaskDataSample] = None
) -> List[MultiTaskDataSample]:
"""Inference without augmentation.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
data_samples (List[MultiTaskDataSample], optional): The annotation
data of every sample. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[MultiTaskDataSample]: A list of data samples which contains
the predicted results.
"""
predictions_dict = dict()
for task_name, head in self.task_heads.items():
task_samples = head.predict(feats)
batch_size = len(task_samples)
predictions_dict[task_name] = task_samples
if data_samples is None:
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
for task_name, task_samples in predictions_dict.items():
for data_sample, task_sample in zip(data_samples, task_samples):
task_sample.set_field(
task_name in data_sample.tasks,
'eval_mask',
field_type='metainfo')
if task_name in data_sample.tasks:
data_sample.get(task_name).update(task_sample)
else:
data_sample.set_field(task_sample, task_name)
return data_samples
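A hedged sketch of the masking behavior implemented by `loss_convertor` above: samples that do not carry a given task are dropped from that task's loss, and the number of kept samples is reported as `*_mask_size`. Task names and class counts mirror the unit tests below.

import torch
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.utils import register_all_modules

register_all_modules()
head = MODELS.build(
    dict(
        type='MultiTaskHead',
        task_heads=dict(
            task0=dict(type='LinearClsHead', num_classes=3),
            task1=dict(type='LinearClsHead', num_classes=6)),
        in_channels=10,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))

feats = (torch.rand(4, 10), )
data_samples = []
for i in range(4):
    sample = MultiTaskDataSample()
    sample.set_field(ClsDataSample().set_gt_label(1), 'task0')
    if i < 2:  # only the first two samples are labelled for task1
        sample.set_field(ClsDataSample().set_gt_label(2), 'task1')
    data_samples.append(sample)

losses = head.loss(feats, data_samples)
# Expected keys: task0_loss, task0_mask_size (4.0), task1_loss, task1_mask_size (2.0)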

View File

@ -8,7 +8,8 @@ import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
from mmcls.structures import (ClsDataSample, MultiTaskDataSample,
batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
from .batch_augments import RandomBatchAugment
@ -151,7 +152,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
self.pad_value)
data_samples = data.get('data_samples', None)
if data_samples is not None and 'gt_label' in data_samples[0]:
sample_item = data_samples[0] if data_samples is not None else None
if isinstance(sample_item,
ClsDataSample) and 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(
gt_labels, device=self.device)
@ -181,5 +184,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
if batch_score is not None:
for sample, score in zip(data_samples, batch_score):
sample.set_gt_score(score)
elif isinstance(sample_item, MultiTaskDataSample):
data_samples = self.cast_data(data_samples)
return {'inputs': inputs, 'data_samples': data_samples}

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_data_sample import ClsDataSample
from .multi_task_data_sample import MultiTaskDataSample
from .utils import (batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
__all__ = [
'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
'stack_batch_scores', 'tensor_split'
'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample'
]

View File

@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.structures import BaseDataElement
class MultiTaskDataSample(BaseDataElement):
@property
def tasks(self):
return self._data_fields
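The data element is deliberately thin: each task is just a data field holding a per-task sample. A minimal construction sketch follows (task names are illustrative; see the nested variant in the structure test at the end of this diff).

from mmcls.structures import ClsDataSample, MultiTaskDataSample

sample = MultiTaskDataSample()
sample.set_field(ClsDataSample().set_gt_label(0), 'gender')
sample.set_field(ClsDataSample().set_gt_label([1, 0, 1, 0]), 'wear')
print(sample.tasks)   # the set of data fields, e.g. {'gender', 'wear'}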

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1,40 @@
{
"metainfo": {
"tasks": [
"gender",
"wear"
]
},
"data_list": [
{
"img_path": "a/1.JPG",
"gt_label": {
"gender": 0
}
},
{
"img_path": "b/2.jpeg",
"gt_label": {
"gender": 0,
"wear": [
1,
0,
1,
0
]
}
},
{
"img_path": "b/subb/3.jpg",
"gt_label": {
"gender": 1,
"wear": [
0,
1,
0,
1
]
}
}
]
}

View File

@ -79,7 +79,7 @@ class TestBaseDataset(TestCase):
else:
self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))
self.assertIn("Haven't been initialized", repr(dataset))
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
@ -452,7 +452,7 @@ class TestCIFAR10(TestBaseDataset):
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
@ -597,7 +597,7 @@ class TestVOC(TestBaseDataset):
}
dataset = dataset_class(**cfg)
self.assertIn("Haven't been initialized", repr(dataset))
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
@ -770,7 +770,7 @@ class TestMNIST(TestBaseDataset):
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
@ -874,3 +874,70 @@ class TestCUB(TestBaseDataset):
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestMultiTaskDataset(TestCase):
DATASET_TYPE = 'MultiTaskDataset'
DEFAULT_ARGS = dict(
data_root=ASSETS_ROOT,
ann_file=osp.join(ASSETS_ROOT, 'multi-task.json'),
pipeline=[])
def test_metainfo(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
metainfo = {'tasks': ['gender', 'wear']}
self.assertDictEqual(dataset.metainfo, metainfo)
self.assertFalse(dataset.test_mode)
def test_parse_data_info(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.parse_data_info({
'img_path': 'a.jpg',
'gt_label': {
'gender': 0
}
})
self.assertDictContainsSubset(
{
'img_path': os.path.join(ASSETS_ROOT, 'a.jpg'),
'gt_label': {
'gender': 0
}
}, data)
np.testing.assert_equal(data['gt_label']['gender'], 0)
# Test missing path
with self.assertRaisesRegex(AssertionError, 'have `img_path` field'):
dataset.parse_data_info(
{'gt_label': {
'gender': 0,
'wear': [1, 0, 1, 0]
}})
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
task_doc = ('For 2 tasks\n gender \n wear ')
self.assertIn(task_doc, repr(dataset))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
self.assertIsInstance(data, list)
np.testing.assert_equal(len(data), 3)
np.testing.assert_equal(data[0]['gt_label'], {'gender': 0})
np.testing.assert_equal(data[1]['gt_label'], {
'gender': 0,
'wear': [1, 0, 1, 0]
})

View File

@ -10,7 +10,7 @@ from mmengine.structures import LabelData
from PIL import Image
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.utils import register_all_modules
register_all_modules()
@ -51,9 +51,8 @@ class TestPackClsInputs(unittest.TestCase):
# Test without `img` and `gt_label`
del data['img']
del data['gt_label']
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
results = transform(copy.deepcopy(data))
self.assertNotIn('gt_label', results['data_samples'])
results = transform(copy.deepcopy(data))
self.assertNotIn('gt_label', results['data_samples'])
def test_repr(self):
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])
@ -130,3 +129,54 @@ class TestCollect(unittest.TestCase):
cfg = dict(type='Collect', keys=['img'])
transform = TRANSFORMS.build(cfg)
self.assertEqual(repr(transform), "Collect(keys=['img'])")
class TestPackMultiTaskInputs(unittest.TestCase):
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
data = {
'sample_idx': 1,
'img_path': img_path,
'ori_shape': (300, 400),
'img_shape': (300, 400),
'scale_factor': 1.0,
'flip': False,
'img': mmcv.imread(img_path),
'gt_label': {
'task1': 1,
'task3': 3
},
}
cfg = dict(type='PackMultiTaskInputs', )
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertIn('data_samples', results)
self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
self.assertIsInstance(results['data_samples'].task1.gt_label,
LabelData)
# Test grayscale image
data['img'] = data['img'].mean(-1)
results = transform(copy.deepcopy(data))
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertEqual(results['inputs'].shape, (1, 300, 400))
# Test without `img` and `gt_label`
del data['img']
del data['gt_label']
results = transform(copy.deepcopy(data))
self.assertNotIn('gt_label', results['data_samples'])
def test_repr(self):
cfg = dict(type='PackMultiTaskInputs', meta_keys=['img_shape'])
transform = TRANSFORMS.build(cfg)
rep = 'PackMultiTaskInputs(task_handlers={},'
rep += ' multi_task_fields=(\'gt_label\',),'
rep += ' meta_keys=[\'img_shape\'])'
self.assertEqual(repr(transform), rep)

View File

@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcls.evaluation.metrics import MultiTasksMetric
from mmcls.structures import ClsDataSample
class MultiTaskMetric(TestCase):
data_pred = [
{
'task0': torch.tensor([0.7, 0.0, 0.3]),
'task1': torch.tensor([0.5, 0.2, 0.3])
},
{
'task0': torch.tensor([0.0, 0.0, 1.0]),
'task1': torch.tensor([0.0, 0.0, 1.0])
},
]
data_gt = [{'task0': 0, 'task1': 2}, {'task1': 2}]
preds = []
for i, pred in enumerate(data_pred):
sample = {}
for task_name in pred:
task_sample = ClsDataSample().set_pred_score(pred[task_name])
if task_name in data_gt[i]:
task_sample.set_gt_label(data_gt[i][task_name])
task_sample.set_field(True, 'eval_mask', field_type='metainfo')
else:
task_sample.set_field(
False, 'eval_mask', field_type='metainfo')
sample[task_name] = task_sample.to_dict()
preds.append(sample)
data2 = zip([
{
'task0': torch.tensor([0.7, 0.0, 0.3]),
'task1': {
'task10': torch.tensor([0.5, 0.2, 0.3]),
'task11': torch.tensor([0.4, 0.3, 0.3])
}
},
{
'task0': torch.tensor([0.0, 0.0, 1.0]),
'task1': {
'task10': torch.tensor([0.1, 0.6, 0.3]),
'task11': torch.tensor([0.5, 0.2, 0.3])
}
},
], [{
'task0': 0,
'task1': {
'task10': 2,
'task11': 0
}
}, {
'task0': 2,
'task1': {
'task10': 1,
'task11': 0
}
}])
pred2 = []
for score, label in data2:
sample = {}
for task_name in score:
if type(score[task_name]) != dict:
task_sample = ClsDataSample().set_pred_score(score[task_name])
task_sample.set_gt_label(label[task_name])
sample[task_name] = task_sample.to_dict()
sample[task_name]['eval_mask'] = True
else:
sample[task_name] = {}
sample[task_name]['eval_mask'] = True
for task_name2 in score[task_name]:
task_sample = ClsDataSample().set_pred_score(
score[task_name][task_name2])
task_sample.set_gt_label(label[task_name][task_name2])
sample[task_name][task_name2] = task_sample.to_dict()
sample[task_name][task_name2]['eval_mask'] = True
pred2.append(sample)
pred3 = [{'task0': {'eval_mask': False}, 'task1': {'eval_mask': False}}]
task_metrics = {
'task0': [dict(type='Accuracy', topk=(1, ))],
'task1': [
dict(type='Accuracy', topk=(1, 3)),
dict(type='SingleLabelMetric', items=['precision', 'recall'])
]
}
task_metrics2 = {
'task0': [dict(type='Accuracy', topk=(1, ))],
'task1': [
dict(
type='MultiTasksMetric',
task_metrics={
'task10': [
dict(type='Accuracy', topk=(1, 3)),
dict(type='SingleLabelMetric', items=['precision'])
],
'task11': [dict(type='Accuracy', topk=(1, ))]
})
]
}
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
# Test with score (use score instead of label if score exists)
metric = MultiTasksMetric(self.task_metrics)
metric.process(None, self.preds)
results = metric.evaluate(2)
self.assertIsInstance(results, dict)
self.assertAlmostEqual(results['task0_accuracy/top1'], 100)
self.assertGreater(results['task1_single-label/precision'], 0)
# Test nested
metric = MultiTasksMetric(self.task_metrics2)
metric.process(None, self.pred2)
results = metric.evaluate(2)
self.assertIsInstance(results, dict)
self.assertGreater(results['task1_task10_single-label/precision'], 0)
self.assertGreater(results['task1_task11_accuracy/top1'], 0)
# Test without any ground truth value
metric = MultiTasksMetric(self.task_metrics)
metric.process(None, self.pred3)
results = metric.evaluate(2)
self.assertIsInstance(results, dict)
self.assertEqual(results['task0_Accuracy'], 0)

View File

@ -10,7 +10,7 @@ import torch
from mmengine import is_seq_of
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.structures import ClsDataSample, MultiTaskDataSample
from mmcls.utils import register_all_modules
register_all_modules()
@ -484,6 +484,142 @@ class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
head(feats)
class TestMultiTaskHead(TestCase):
DEFAULT_ARGS = dict(
type='MultiTaskHead', # <- Head config, depends on #675
task_heads={
'task0': dict(type='LinearClsHead', num_classes=3),
'task1': dict(type='LinearClsHead', num_classes=6),
},
in_channels=10,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
)
DEFAULT_ARGS2 = dict(
type='MultiTaskHead', # <- Head config, depends on #675
task_heads={
'task0':
dict(
type='MultiTaskHead',
task_heads={
'task00': dict(type='LinearClsHead', num_classes=3),
'task01': dict(type='LinearClsHead', num_classes=6),
}),
'task1':
dict(type='LinearClsHead', num_classes=6)
},
in_channels=10,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
)
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), )
outs = head(feats)
self.assertEqual(outs['task0'].shape, (4, 3))
self.assertEqual(outs['task1'].shape, (4, 6))
self.assertTrue(isinstance(outs, dict))
def test_loss(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(
losses.keys(),
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
self.assertGreater(losses['task0_loss'].item(), 0)
self.assertGreater(losses['task1_loss'].item(), 0)
def test_predict(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
for task_name in self.DEFAULT_ARGS['task_heads']:
task_sample = ClsDataSample().set_gt_label(1)
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for pred in predictions:
self.assertIn('task0', pred)
task0_sample = predictions[0].task0
self.assertIsInstance(task0_sample.pred_label.score, torch.Tensor)
# with data_samples
predictions = head.predict(feats, data_samples)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('task0', pred)
def test_loss_empty_data_sample(self):
feats = (torch.rand(4, 10), )
data_samples = []
for _ in range(4):
data_sample = MultiTaskDataSample()
data_samples.append(data_sample)
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(feats, data_samples)
self.assertEqual(
losses.keys(),
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
self.assertEqual(losses['task0_loss'].item(), 0)
self.assertEqual(losses['task1_loss'].item(), 0)
def test_nested_multi_task_loss(self):
head = MODELS.build(self.DEFAULT_ARGS2)
# return the last item (same as pre_logits)
feats = (torch.rand(4, 10), )
outs = head(feats)
self.assertEqual(outs['task0']['task01'].shape, (4, 6))
self.assertTrue(isinstance(outs, dict))
self.assertTrue(isinstance(outs['task0'], dict))
def test_nested_invalid_sample(self):
feats = (torch.rand(4, 10), )
gt_label = {'task0': 1, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS2)
data_sample = MultiTaskDataSample()
for task_name in gt_label:
task_sample = ClsDataSample().set_gt_label(gt_label[task_name])
data_sample.set_field(task_sample, task_name)
with self.assertRaises(Exception):
head.loss(feats, data_sample)
def test_nested_invalid_sample2(self):
feats = (torch.rand(4, 10), )
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
head = MODELS.build(self.DEFAULT_ARGS)
data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label(
gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name)
with self.assertRaises(Exception):
head.loss(feats, data_sample)
class TestArcFaceClsHead(TestCase):
DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from mmengine.structures import LabelData
from mmcls.structures import ClsDataSample
from mmcls.structures import ClsDataSample, MultiTaskDataSample
class TestClsDataSample(TestCase):
@ -122,3 +122,20 @@ class TestClsDataSample(TestCase):
with self.assertRaisesRegex(AssertionError, 'but got 2'):
data_sample.set_pred_score(
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
class TestMultiTaskDataSample(TestCase):
def test_multi_task_data_sample(self):
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
data_sample = MultiTaskDataSample()
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
data_sample.set_field(task_sample, 'task1')
data_sample.set_field(MultiTaskDataSample(), 'task0')
for task_name in gt_label['task0']:
task_sample = ClsDataSample().set_gt_label(
gt_label['task0'][task_name])
data_sample.task0.set_field(task_sample, task_name)
self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
self.assertIsInstance(data_sample.task1, ClsDataSample)
self.assertIsInstance(data_sample.task0.task00, ClsDataSample)