[Feature] Support Multi-task. (#1229)
* unit test for multi_task_head * [Feature] MultiTaskHead (#628, #481) * [Fix] lint for multi_task_head * [Feature] Add `MultiTaskDataset` to support multi-task training. * Update MultiTaskClsHead * Update docs * [CI] Add test mim CI. (#879) * [Fix] Remove duplicated wide-resnet metafile. * [Feature] Support MPS device. (#894) * [Feature] Support MPS device. * Add `auto_select_device` * Add unit tests * [Fix] Fix Albu crash bug. (#918) * Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning * Fix common * Using copy incase potential bug in multi-label tasks * Improve coding * Improve code logic * Add unit test * Fix typo * Fix yapf * Bump version to 0.23.2. (#937) * [Improve] Use `forward_dummy` to calculate FLOPS. (#953) * Update README * [Docs] Fix typo for wrong reference. (#1036) * [Doc] Fix typo in tutorial 2 (#1043) * [Docs] Fix a typo in ImageClassifier (#1050) * add mask to loss * add another pipeline * adpat the pipeline if there is no mask * switch mask and task * first version of multi data smaple * fix problem with attribut by getattr * rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label' * training without evaluation * first version work * add others metrics * delete evaluation from dataset * fix linter * fix linter * multi metrics * first version of test * change evaluate metric * Update tests/test_models/test_heads.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update tests/test_models/test_heads.py Co-authored-by: Colle <piercus@users.noreply.github.com> * add tests * add test for multidatasample * create a generic test * create a generic test * create a generic test * change multi data sample * correct test * test * add new test * add test for dataset * correct test * correct test * correct test * correct test * fix : #5 * run yapf * fix linter * fix linter * fix linter * fix isort * fix isort * fix docformmater * fix docformmater * fix linter * fix linter * fix data sample * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update tests/test_structures/test_datasample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/structures/multi_task_data_sample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update tests/test_structures/test_datasample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update tests/test_structures/test_datasample.py Co-authored-by: Colle <piercus@users.noreply.github.com> * update data sample * update head * update head * update multi data sample * fix linter * fix linter * fix linter * fix linter * fix linter * fix linter * update head * fix problem we don't set pred or gt * fix problem we don't set pred or gt * fix problem we don't set pred or gt * fix linter * fix : #2 * fix : linter * update multi head * fix linter * fix linter * update data sample * update data sample * fix ; linter * update test * test pipeline * update 
pipeline * update test * update dataset * update dataset * fix linter * fix linter * update formatting * add test for multi-task-eval * update formatting * fix linter * update test * update * add test * update metrics * update metrics * add doc for functions * fix linter * training for multitask 1.x * fix linter * run flake8 * run linter * update test * add mask in evaluation * update metric doc * update metric doc * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * Update mmcls/evaluation/metrics/multi_task.py Co-authored-by: Colle <piercus@users.noreply.github.com> * update metric doc * update metric doc * Fix cannot import name MultiTaskDataSample * fix test_datasets * fix test_datasets * fix linter * add an example of multitask * change name of configs dataset * Refactor the multi-task support * correct test and metric * add test to multidatasample * add test to multidatasample * correct test * correct metrics and clshead * Update mmcls/models/heads/cls_head.py Co-authored-by: Colle <piercus@users.noreply.github.com> * update cls_head.py documentation * lint * lint * fix: lint * fix linter * add eval mask * fix documentation * fix: single_label.py back to 1.x * Update mmcls/models/heads/multi_task_head.py Co-authored-by: Ma Zerun <mzr1996@163.com> * Remove multi-task configs. Co-authored-by: mzr1996 <mzr1996@163.com> Co-authored-by: HinGwenWoong <peterhuang0323@qq.com> Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com> Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com> Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com> Co-authored-by: marouaneamz <maroineamil99@gmail.com> Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
parent 5b266d9e7c, commit bac181f393
@@ -8,6 +8,7 @@ from .dataset_wrappers import KFoldDataset
 from .imagenet import ImageNet, ImageNet21k
 from .mnist import MNIST, FashionMNIST
 from .multi_label import MultiLabelDataset
+from .multi_task import MultiTaskDataset
 from .samplers import *  # noqa: F401,F403
 from .transforms import *  # noqa: F401,F403
 from .voc import VOC
@@ -15,5 +16,5 @@ from .voc import VOC
 __all__ = [
     'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
     'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
-    'CustomDataset', 'MultiLabelDataset'
+    'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
 ]
@ -0,0 +1,344 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from os import PathLike
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import mmengine
|
||||
from mmcv.transforms import Compose
|
||||
from mmengine.fileio import FileClient
|
||||
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
if isinstance(path, (str, PathLike)):
|
||||
return osp.expanduser(path)
|
||||
else:
|
||||
return path
|
||||
|
||||
|
||||
def isabs(uri):
|
||||
return osp.isabs(uri) or ('://' in uri)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MultiTaskDataset:
|
||||
"""Custom dataset for multi-task dataset.
|
||||
|
||||
To use the dataset, please generate and provide an annotation file in the
|
||||
below format:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"metainfo": {
|
||||
"tasks":
|
||||
[
|
||||
"gender",
"wear"
|
||||
]
|
||||
},
|
||||
"data_list": [
|
||||
{
|
||||
"img_path": "a.jpg",
|
||||
"gt_label": {
|
||||
"gender": 0,
|
||||
"wear": [1, 0, 1, 0]
|
||||
}
|
||||
},
|
||||
{
|
||||
"img_path": "b.jpg",
|
||||
"gt_label": {
|
||||
"gender": 1,
|
||||
"wear": [1, 0, 1, 0]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Assume we put our dataset in the ``data/mydataset`` folder in the
|
||||
repository and organize it as the below format: ::
|
||||
|
||||
mmclassification/
|
||||
└── data
|
||||
└── mydataset
|
||||
├── annotation
|
||||
│ ├── train.json
|
||||
│ ├── test.json
|
||||
│ └── val.json
|
||||
├── train
|
||||
│ ├── a.jpg
|
||||
│ └── ...
|
||||
├── test
|
||||
│ ├── b.jpg
|
||||
│ └── ...
|
||||
└── val
|
||||
├── c.jpg
|
||||
└── ...
|
||||
|
||||
We can use the below config to build datasets:
|
||||
|
||||
.. code:: python
|
||||
|
||||
>>> from mmcls.datasets import build_dataset
|
||||
>>> train_cfg = dict(
|
||||
... type="MultiTaskDataset",
|
||||
... ann_file="annotation/train.json",
|
||||
... data_root="data/mydataset",
|
||||
... # The `img_path` field in the train annotation file is relative
|
||||
... # to the `train` folder.
|
||||
... data_prefix='train',
|
||||
... )
|
||||
>>> train_dataset = build_dataset(train_cfg)
|
||||
|
||||
Or we can put all files in the same folder: ::
|
||||
|
||||
mmclassification/
|
||||
└── data
|
||||
└── mydataset
|
||||
├── train.json
|
||||
├── test.json
|
||||
├── val.json
|
||||
├── a.jpg
|
||||
├── b.jpg
|
||||
├── c.jpg
|
||||
└── ...
|
||||
|
||||
And we can use the below config to build datasets:
|
||||
|
||||
.. code:: python
|
||||
|
||||
>>> from mmcls.datasets import build_dataset
|
||||
>>> train_cfg = dict(
|
||||
... type="MultiTaskDataset",
|
||||
... ann_file="train.json",
|
||||
... data_root="data/mydataset",
|
||||
... # the `data_prefix` is not required since all paths are
|
||||
... # relative to the `data_root`.
|
||||
... )
|
||||
>>> train_dataset = build_dataset(train_cfg)
|
||||
|
||||
|
||||
Args:
|
||||
ann_file (str): The annotation file path. It can be either absolute
|
||||
path or relative path to the ``data_root``.
|
||||
metainfo (dict, optional): The extra meta information. It should be
|
||||
a dict with the same format as the ``"metainfo"`` field in the
|
||||
annotation file. Defaults to None.
|
||||
data_root (str, optional): The root path of the data directory. It's
|
||||
the prefix of the ``data_prefix`` and the ``ann_file``. And it can
|
||||
be a remote path like "s3://openmmlab/xxx/". Defaults to None.
|
||||
data_prefix (str, optional): The base folder relative to the
|
||||
``data_root`` for the ``"img_path"`` field in the annotation file.
|
||||
Defaults to None.
|
||||
pipeline (Sequence[dict]): A list of dict, where each element
|
||||
represents an operation defined in :mod:`mmcls.datasets.pipelines`.
|
||||
Defaults to an empty tuple.
|
||||
test_mode (bool): Whether the dataset is in test mode. Defaults to False.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
||||
If None, automatically infer it from the ``data_root``.
|
||||
Defaults to None.
|
||||
"""
|
||||
METAINFO = dict()
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str,
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: Optional[str] = None,
|
||||
pipeline: Sequence = (),
|
||||
test_mode: bool = False,
|
||||
file_client_args: Optional[dict] = None):
|
||||
|
||||
self.data_root = expanduser(data_root)
|
||||
|
||||
# Infer the file client
|
||||
if self.data_root is not None:
|
||||
file_client = FileClient.infer_client(
|
||||
file_client_args, uri=self.data_root)
|
||||
else:
|
||||
file_client = FileClient(file_client_args)
|
||||
self.file_client: FileClient = file_client
|
||||
|
||||
self.ann_file = self._join_root(expanduser(ann_file))
|
||||
self.data_prefix = self._join_root(data_prefix)
|
||||
|
||||
self.test_mode = test_mode
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.data_list = self.load_data_list(self.ann_file, metainfo)
|
||||
|
||||
def _join_root(self, path):
|
||||
"""Join ``self.data_root`` with the specified path.
|
||||
|
||||
If the path is an absolute path, just return the path. And if the
|
||||
path is None, return ``self.data_root``.
|
||||
|
||||
Examples:
|
||||
>>> self.data_root = 'a/b/c'
|
||||
>>> self._join_root('d/e/')
|
||||
'a/b/c/d/e'
|
||||
>>> self._join_root('https://openmmlab.com')
|
||||
'https://openmmlab.com'
|
||||
>>> self._join_root(None)
|
||||
'a/b/c'
|
||||
"""
|
||||
if path is None:
|
||||
return self.data_root
|
||||
if isabs(path):
|
||||
return path
|
||||
|
||||
joined_path = self.file_client.join_path(self.data_root, path)
|
||||
return joined_path
|
||||
|
||||
@classmethod
|
||||
def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
|
||||
"""Collect meta information from the dictionary of meta.
|
||||
|
||||
Args:
|
||||
in_metainfo (dict): Meta information dict.
|
||||
|
||||
Returns:
|
||||
dict: Parsed meta information.
|
||||
"""
|
||||
# `cls.METAINFO` will be overwritten by in_meta
|
||||
metainfo = copy.deepcopy(cls.METAINFO)
|
||||
if in_metainfo is None:
|
||||
return metainfo
|
||||
|
||||
metainfo.update(in_metainfo)
|
||||
|
||||
return metainfo
|
||||
|
||||
def load_data_list(self, ann_file, metainfo_override=None):
|
||||
"""Load annotations from an annotation file.
|
||||
|
||||
Args:
|
||||
ann_file (str): Absolute annotation file path if ``self.data_root`` is
    None, or a path relative to ``self.data_root`` otherwise.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of annotations.
|
||||
"""
|
||||
annotations = mmengine.load(ann_file)
|
||||
if not isinstance(annotations, dict):
|
||||
raise TypeError(f'The annotations loaded from annotation file '
|
||||
f'should be a dict, but got {type(annotations)}!')
|
||||
if 'data_list' not in annotations:
|
||||
raise ValueError('The annotation file must have the `data_list` '
|
||||
'field.')
|
||||
metainfo = annotations.get('metainfo', {})
|
||||
raw_data_list = annotations['data_list']
|
||||
|
||||
# Set meta information.
|
||||
assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
|
||||
f'annotation file should be a dict, but got {type(metainfo)}'
|
||||
if metainfo_override is not None:
|
||||
assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
|
||||
f'argument should be a dict, but got {type(metainfo_override)}'
|
||||
metainfo.update(metainfo_override)
|
||||
self._metainfo = self._get_meta_info(metainfo)
|
||||
|
||||
data_list = []
|
||||
for i, raw_data in enumerate(raw_data_list):
|
||||
try:
|
||||
data_list.append(self.parse_data_info(raw_data))
|
||||
except AssertionError as e:
|
||||
raise RuntimeError(
|
||||
f'The format check fails while parsing item {i} of '
|
||||
f'the annotation file with error: {e}')
|
||||
return data_list
|
||||
|
||||
def parse_data_info(self, raw_data):
|
||||
"""Parse raw annotation to target format.
|
||||
|
||||
This method will return a dict which contains the data information of a
|
||||
sample.
|
||||
|
||||
Args:
|
||||
raw_data (dict): Raw data information load from ``ann_file``
|
||||
|
||||
Returns:
|
||||
dict: Parsed annotation.
|
||||
"""
|
||||
assert isinstance(raw_data, dict), \
|
||||
f'The item should be a dict, but got {type(raw_data)}'
|
||||
assert 'img_path' in raw_data, \
|
||||
"The item doesn't have `img_path` field."
|
||||
data = dict(
|
||||
img_path=self._join_root(raw_data['img_path']),
|
||||
gt_label=raw_data['gt_label'],
|
||||
)
|
||||
return data
|
||||
|
||||
@property
|
||||
def metainfo(self) -> dict:
|
||||
"""Get meta information of dataset.
|
||||
|
||||
Returns:
|
||||
dict: meta information collected from ``cls.METAINFO``,
|
||||
annotation file and metainfo argument during instantiation.
|
||||
"""
|
||||
return copy.deepcopy(self._metainfo)
|
||||
|
||||
def prepare_data(self, idx):
|
||||
"""Get data processed by ``self.pipeline``.
|
||||
|
||||
Args:
|
||||
idx (int): The index of ``data_info``.
|
||||
|
||||
Returns:
|
||||
Any: Depends on ``self.pipeline``.
|
||||
"""
|
||||
results = copy.deepcopy(self.data_list[idx])
|
||||
return self.pipeline(results)
|
||||
|
||||
def __len__(self):
|
||||
"""Get the length of the whole dataset.
|
||||
|
||||
Returns:
|
||||
int: The length of the dataset.
|
||||
"""
|
||||
return len(self.data_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get the idx-th image and data information of dataset after
|
||||
``self.pipeline``.
|
||||
|
||||
Args:
|
||||
idx (int): The index of the data.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th image and data information after
|
||||
``self.pipeline``.
|
||||
"""
|
||||
return self.prepare_data(idx)
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the basic information of the dataset.
|
||||
|
||||
Returns:
|
||||
str: Formatted string.
|
||||
"""
|
||||
head = 'Dataset ' + self.__class__.__name__
|
||||
body = [f'Number of samples: \t{self.__len__()}']
|
||||
if self.data_root is not None:
|
||||
body.append(f'Root location: \t{self.data_root}')
|
||||
body.append(f'Annotation file: \t{self.ann_file}')
|
||||
if self.data_prefix is not None:
|
||||
body.append(f'Prefix of images: \t{self.data_prefix}')
|
||||
# -------------------- extra repr --------------------
|
||||
tasks = self.metainfo['tasks']
|
||||
body.append(f'For {len(tasks)} tasks')
|
||||
for task in tasks:
|
||||
body.append(f' {task} ')
|
||||
# ----------------------------------------------------
|
||||
|
||||
if len(self.pipeline.transforms) > 0:
|
||||
body.append('With transforms:')
|
||||
for t in self.pipeline.transforms:
|
||||
body.append(f' {t}')
|
||||
|
||||
lines = [head] + [' ' * 4 + line for line in body]
|
||||
return '\n'.join(lines)
|
|
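To make the annotation format above concrete, the short sketch below instantiates the new dataset directly and inspects the parsed entries. The `data/mydataset` paths are illustrative assumptions taken from the docstring, not files added by this commit.

# Hedged usage sketch (not part of the commit); the paths are assumptions.
from mmcls.datasets import MultiTaskDataset

dataset = MultiTaskDataset(
    ann_file='annotation/train.json',  # assumed location
    data_root='data/mydataset',        # assumed location
    pipeline=[])

print(dataset.metainfo)      # e.g. {'tasks': ['gender', 'wear']}
print(len(dataset))          # number of entries in the `data_list` field
print(dataset.data_list[0])  # {'img_path': '...', 'gt_label': {'gender': 0, ...}}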
@@ -3,7 +3,8 @@ from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform,
                            Brightness, ColorTransform, Contrast, Cutout,
                            Equalize, Invert, Posterize, RandAugment, Rotate,
                            Sharpness, Shear, Solarize, SolarizeAdd, Translate)
-from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose
+from .formatting import (Collect, PackClsInputs, PackMultiTaskInputs, ToNumpy,
+                         ToPIL, Transpose)
 from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
                          EfficientNetRandomCrop, Lighting, RandomCrop,
                          RandomErasing, RandomResizedCrop, ResizeEdge)
@@ -15,5 +16,6 @@ __all__ = [
     'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
     'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
     'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
-    'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform'
+    'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform',
+    'PackMultiTaskInputs'
 ]
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -9,7 +10,7 @@ from mmengine.utils import is_str
|
|||
from PIL import Image
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
|
@ -85,12 +86,6 @@ class PackClsInputs(BaseTransform):
|
|||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
else:
|
||||
warnings.warn(
|
||||
'Cannot get "img" in the input dict of `PackClsInputs`,'
|
||||
'please make sure `LoadImageFromFile` has been added '
|
||||
'in the data pipeline or images have been loaded in '
|
||||
'the dataset.')
|
||||
|
||||
data_sample = ClsDataSample()
|
||||
if 'gt_label' in results:
|
||||
|
@ -100,7 +95,6 @@ class PackClsInputs(BaseTransform):
|
|||
img_meta = {k: results[k] for k in self.meta_keys if k in results}
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -109,6 +103,84 @@ class PackClsInputs(BaseTransform):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PackMultiTaskInputs(BaseTransform):
|
||||
"""Convert all image labels of multi-task dataset to a dict of tensor.
|
||||
|
||||
Args:
|
||||
task_handlers (dict): A dict whose keys are task names and whose values
    are the pack transforms used for those tasks. Tasks without a handler
    fall back to ``PackClsInputs``. Defaults to an empty dict.
multi_task_fields (Sequence[str]): The fields to split into per-task
    values. Defaults to ``('gt_label', )``.
|
||||
meta_keys(Sequence[str]): The meta keys to be saved in the
|
||||
``metainfo`` of the packed ``data_samples``.
|
||||
Defaults to a tuple that includes the following keys:
|
||||
|
||||
- ``sample_idx``: The id of the image sample.
|
||||
- ``img_path``: The path to the image file.
|
||||
- ``ori_shape``: The original shape of the image as a tuple (H, W).
|
||||
- ``img_shape``: The shape of the image after the pipeline as a
|
||||
tuple (H, W).
|
||||
- ``scale_factor``: The scale factor between the resized image and
|
||||
the original image.
|
||||
- ``flip``: A boolean indicating if image flip transform was used.
|
||||
- ``flip_direction``: The flipping direction.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
task_handlers=dict(),
|
||||
multi_task_fields=('gt_label', ),
|
||||
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'flip', 'flip_direction')):
|
||||
self.multi_task_fields = multi_task_fields
|
||||
self.meta_keys = meta_keys
|
||||
self.task_handlers = defaultdict(
|
||||
partial(PackClsInputs, meta_keys=meta_keys))
|
||||
for task_name, task_handler in task_handlers.items():
|
||||
self.task_handlers[task_name] = TRANSFORMS.build(
|
||||
dict(type=task_handler, meta_keys=meta_keys))
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Method to pack the input data.
|
||||
|
||||
result = {'img_path': 'a.png', 'gt_label': {'task1': 1, 'task3': 3},
          'img': array(...)}
|
||||
"""
|
||||
packed_results = dict()
|
||||
results = results.copy()
|
||||
|
||||
if 'img' in results:
|
||||
img = results.pop('img')
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
|
||||
task_results = defaultdict(dict)
|
||||
for field in self.multi_task_fields:
|
||||
if field in results:
|
||||
value = results.pop(field)
|
||||
for k, v in value.items():
|
||||
task_results[k].update({field: v})
|
||||
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name, task_result in task_results.items():
|
||||
task_handler = self.task_handlers[task_name]
|
||||
task_pack_result = task_handler({**results, **task_result})
|
||||
data_sample.set_field(task_pack_result['data_samples'], task_name)
|
||||
|
||||
packed_results['data_samples'] = data_sample
|
||||
return packed_results
|
||||
|
||||
def __repr__(self):
|
||||
repr = self.__class__.__name__
|
||||
task_handlers = {
|
||||
name: handler.__class__.__name__
|
||||
for name, handler in self.task_handlers.items()
|
||||
}
|
||||
repr += f'(task_handlers={task_handlers}, '
|
||||
repr += f'multi_task_fields={self.multi_task_fields}, '
|
||||
repr += f'meta_keys={self.meta_keys})'
|
||||
return repr
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Transpose(BaseTransform):
|
||||
"""Transpose numpy array.
|
||||
|
|
|
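As a rough illustration of the packing step, the sketch below feeds a hand-built result dict through `PackMultiTaskInputs`; the field values are assumptions, and the trailing comments describe the expected output rather than anything asserted by the diff itself.

# Hedged sketch (not part of the commit); the input values are illustrative.
import numpy as np
from mmcls.datasets.transforms import PackMultiTaskInputs

pack = PackMultiTaskInputs(multi_task_fields=('gt_label', ))
results = {
    'img': np.zeros((224, 224, 3), dtype=np.uint8),
    'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]},
    'img_shape': (224, 224),
}
packed = pack(results)
# packed['inputs'] is a (3, 224, 224) tensor, and packed['data_samples'] is a
# MultiTaskDataSample whose 'gender' and 'wear' fields are ClsDataSample
# objects carrying the per-task gt_label.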
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .multi_label import AveragePrecision, MultiLabelMetric
+from .multi_task import MultiTasksMetric
 from .single_label import Accuracy, SingleLabelMetric
 from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric

 __all__ = [
     'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
-    'VOCAveragePrecision', 'VOCMultiLabelMetric'
+    'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric'
 ]
@ -0,0 +1,120 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Sequence
|
||||
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class MultiTasksMetric(BaseMetric):
|
||||
"""Metrics for MultiTask
|
||||
Args:
|
||||
task_metrics(dict): a dictionary in the keys are the names of the tasks
|
||||
and the values is a list of the metric corresponds to this task
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import MultiTasksMetric
|
||||
# -------------------- The Basic Usage --------------------
|
||||
>>> task_metrics = {
|
||||
'task0': [dict(type='Accuracy', topk=(1, ))],
|
||||
'task1': [dict(type='Accuracy', topk=(1, 3))]
|
||||
}
|
||||
>>> pred = [{
|
||||
'pred_task': {
|
||||
'task0': torch.tensor([0.7, 0.0, 0.3]),
|
||||
'task1': torch.tensor([0.5, 0.2, 0.3])
|
||||
},
|
||||
'gt_task': {
|
||||
'task0': torch.tensor(0),
|
||||
'task1': torch.tensor(2)
|
||||
}
|
||||
}, {
|
||||
'pred_task': {
|
||||
'task0': torch.tensor([0.0, 0.0, 1.0]),
|
||||
'task1': torch.tensor([0.0, 0.0, 1.0])
|
||||
},
|
||||
'gt_task': {
|
||||
'task0': torch.tensor(2),
|
||||
'task1': torch.tensor(2)
|
||||
}
|
||||
}]
|
||||
>>> metric = MultiTasksMetric(task_metrics)
|
||||
>>> metric.process(None, pred)
|
||||
>>> results = metric.evaluate(2)
|
||||
results = {
|
||||
'task0_accuracy/top1': 100.0,
|
||||
'task1_accuracy/top1': 50.0,
|
||||
'task1_accuracy/top3': 100.0
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
task_metrics: Dict,
|
||||
collect_device: str = 'cpu') -> None:
|
||||
self.task_metrics = task_metrics
|
||||
super().__init__(collect_device=collect_device)
|
||||
|
||||
self._metrics = {}
|
||||
for task_name in self.task_metrics.keys():
|
||||
self._metrics[task_name] = []
|
||||
for metric in self.task_metrics[task_name]:
|
||||
self._metrics[task_name].append(METRICS.build(metric))
|
||||
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to compute the metrics when all batches have been processed.
|
||||
Args:
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for task_name in self.task_metrics.keys():
|
||||
filtered_data_samples = []
|
||||
for data_sample in data_samples:
|
||||
eval_mask = data_sample[task_name]['eval_mask']
|
||||
if eval_mask:
|
||||
filtered_data_samples.append(data_sample[task_name])
|
||||
for metric in self._metrics[task_name]:
|
||||
metric.process(data_batch, filtered_data_samples)
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
raise NotImplementedError(
|
||||
'compute metrics should not be used here directly')
|
||||
|
||||
def evaluate(self, size):
|
||||
"""Evaluate the model performance of the whole dataset after processing
|
||||
all batches.
|
||||
|
||||
Args:
|
||||
size (int): Length of the entire validation dataset. When batch
|
||||
size > 1, the dataloader may pad some data samples to make
|
||||
sure all ranks have the same length of dataset slice. The
|
||||
``collect_results`` function will drop the padded data based on
|
||||
this size.
|
||||
Returns:
|
||||
dict: Evaluation metrics dict on the val dataset. The keys are
|
||||
"{task_name}_{metric_name}" , and the values
|
||||
are corresponding results.
|
||||
"""
|
||||
metrics = {}
|
||||
for task_name in self._metrics:
|
||||
for metric in self._metrics[task_name]:
|
||||
name = metric.__class__.__name__
|
||||
if name == 'MultiTasksMetric' or metric.results:
|
||||
results = metric.evaluate(size)
|
||||
else:
|
||||
results = {metric.__class__.__name__: 0}
|
||||
for key in results:
|
||||
name = f'{task_name}_{key}'
|
||||
if name in results:
|
||||
"""Inspired from https://github.com/open-
|
||||
mmlab/mmengine/ bl ob/ed20a9cba52ceb371f7c825131636b9e2
|
||||
747172e/mmengine/evalua tor/evaluator.py#L84-L87."""
|
||||
raise ValueError(
|
||||
'There are multiple metric results with the same'
|
||||
f'metric name {name}. Please make sure all metrics'
|
||||
'have different prefixes.')
|
||||
metrics[name] = results[key]
|
||||
return metrics
|
|
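A minimal sketch of plugging `MultiTasksMetric` into an evaluator config follows; the task names and sub-metrics are borrowed from the unit tests further down and are illustrative assumptions, not a config shipped by this commit.

# Hedged sketch (not part of the commit); task names are assumptions.
val_evaluator = dict(
    type='MultiTasksMetric',
    task_metrics={
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [
            dict(type='Accuracy', topk=(1, 3)),
            dict(type='SingleLabelMetric', items=['precision', 'recall']),
        ],
    })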
@@ -8,11 +8,13 @@ from .margin_head import ArcFaceClsHead
 from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_csra_head import CSRAClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
+from .multi_task_head import MultiTaskHead
 from .stacked_head import StackedLinearClsHead
 from .vision_transformer_head import VisionTransformerClsHead

 __all__ = [
     'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
     'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
-    'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead'
+    'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead',
+    'MultiTaskHead'
 ]
@ -108,9 +108,10 @@ class ClsHead(BaseHead):
|
|||
return losses
|
||||
|
||||
def predict(
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[Union[ClsDataSample, None]] = None
|
||||
) -> List[ClsDataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
|
@ -118,7 +119,7 @@ class ClsHead(BaseHead):
|
|||
Multiple stage inputs are acceptable but only the last stage
|
||||
will be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data_samples (List[ClsDataSample | None], optional): The annotation
|
||||
data of every sample. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
|
@ -141,14 +142,15 @@ class ClsHead(BaseHead):
|
|||
pred_scores = F.softmax(cls_score, dim=1)
|
||||
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
|
||||
|
||||
if data_samples is not None:
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
else:
|
||||
data_samples = []
|
||||
for score, label in zip(pred_scores, pred_labels):
|
||||
data_samples.append(ClsDataSample().set_pred_score(
|
||||
score).set_pred_label(label))
|
||||
out_data_samples = []
|
||||
if data_samples is None:
|
||||
data_samples = [None for _ in range(pred_scores.size(0))]
|
||||
|
||||
return data_samples
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
if data_sample is None:
|
||||
data_sample = ClsDataSample()
|
||||
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
out_data_samples.append(data_sample)
|
||||
return out_data_samples
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ModuleDict
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import MultiTaskDataSample
|
||||
from .base_head import BaseHead
|
||||
|
||||
|
||||
def loss_convertor(loss_func, task_name):
|
||||
|
||||
def wrapped(inputs, data_samples, **kwargs):
|
||||
mask = torch.empty(len(data_samples), dtype=torch.bool)
|
||||
task_data_samples = []
|
||||
for i, data_sample in enumerate(data_samples):
|
||||
assert isinstance(data_sample, MultiTaskDataSample)
|
||||
sample_mask = task_name in data_sample
|
||||
mask[i] = sample_mask
|
||||
if sample_mask:
|
||||
task_data_samples.append(data_sample.get(task_name))
|
||||
|
||||
if len(task_data_samples) == 0:
|
||||
return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
|
||||
|
||||
# Mask the inputs of the task
|
||||
def mask_inputs(inputs, mask):
|
||||
if isinstance(inputs, Sequence):
|
||||
return type(inputs)(
|
||||
[mask_inputs(input, mask) for input in inputs])
|
||||
elif isinstance(inputs, torch.Tensor):
|
||||
return inputs[mask]
|
||||
|
||||
masked_inputs = mask_inputs(inputs, mask)
|
||||
loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
|
||||
loss_output['mask_size'] = mask.sum().to(torch.float)
|
||||
return loss_output
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MultiTaskHead(BaseHead):
|
||||
"""Multi task head.
|
||||
|
||||
Args:
|
||||
task_heads (dict): Sub heads to use, the key will be used to rename the
|
||||
loss components.
|
||||
common_cfg (dict): The common settings for all heads. Defaults to an
|
||||
empty dict.
|
||||
init_cfg (dict, optional): The extra initialization settings.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, task_heads, init_cfg=None, **kwargs):
|
||||
super(MultiTaskHead, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
assert isinstance(task_heads, dict), 'The `task_heads` argument ' \
    "should be a dict, whose keys are task names and values are " \
    'configs of head for the task.'
|
||||
|
||||
self.task_heads = ModuleDict()
|
||||
|
||||
for task_name, sub_head in task_heads.items():
|
||||
if not isinstance(sub_head, nn.Module):
|
||||
sub_head = MODELS.build(sub_head, default_args=kwargs)
|
||||
sub_head.loss = loss_convertor(sub_head.loss, task_name)
|
||||
self.task_heads[task_name] = sub_head
|
||||
|
||||
def forward(self, feats):
|
||||
"""The forward process."""
|
||||
return {
|
||||
task_name: head(feats)
|
||||
for task_name, head in self.task_heads.items()
|
||||
}
|
||||
|
||||
def loss(self, feats: Tuple[torch.Tensor],
|
||||
data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
|
||||
"""Calculate losses from the classification score.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): The features extracted from the backbone.
|
||||
data_samples (List[MultiTaskDataSample]): The annotation data of
|
||||
every sample.
|
||||
**kwargs: Other keyword arguments to forward the loss module.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components, each task loss
|
||||
key will be prefixed by the task_name like "task1_loss"
|
||||
"""
|
||||
losses = dict()
|
||||
for task_name, head in self.task_heads.items():
|
||||
head_loss = head.loss(feats, data_samples, **kwargs)
|
||||
for k, v in head_loss.items():
|
||||
losses[f'{task_name}_{k}'] = v
|
||||
return losses
|
||||
|
||||
def predict(
|
||||
self,
|
||||
feats: Tuple[torch.Tensor],
|
||||
data_samples: List[MultiTaskDataSample] = None
|
||||
) -> List[MultiTaskDataSample]:
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
feats (tuple[Tensor]): The features extracted from the backbone.
|
||||
data_samples (List[MultiTaskDataSample], optional): The annotation
|
||||
data of every sample. If not None, set ``pred_label`` of
|
||||
the input data samples. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[MultiTaskDataSample]: A list of data samples which contains
|
||||
the predicted results.
|
||||
"""
|
||||
predictions_dict = dict()
|
||||
|
||||
for task_name, head in self.task_heads.items():
|
||||
task_samples = head.predict(feats)
|
||||
batch_size = len(task_samples)
|
||||
predictions_dict[task_name] = task_samples
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
|
||||
|
||||
for task_name, task_samples in predictions_dict.items():
|
||||
for data_sample, task_sample in zip(data_samples, task_samples):
|
||||
task_sample.set_field(
|
||||
task_name in data_sample.tasks,
|
||||
'eval_mask',
|
||||
field_type='metainfo')
|
||||
|
||||
if task_name in data_sample.tasks:
|
||||
data_sample.get(task_name).update(task_sample)
|
||||
else:
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
|
||||
return data_samples
|
|
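Since the example configs were removed in this commit ("Remove multi-task configs."), the snippet below is only a hedged sketch of how `MultiTaskHead` might sit inside an `ImageClassifier` config; the ResNet-18 backbone, task names and class counts are assumptions mirroring the head tests later in this diff.

# Hedged sketch (not part of the commit); backbone and numbers are assumptions.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='MultiTaskHead',
        task_heads={
            'task0': dict(type='LinearClsHead', num_classes=3),
            'task1': dict(type='LinearClsHead', num_classes=6),
        },
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))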
@@ -8,7 +8,8 @@ import torch.nn.functional as F
 from mmengine.model import BaseDataPreprocessor, stack_batch

 from mmcls.registry import MODELS
-from mmcls.structures import (batch_label_to_onehot, cat_batch_labels,
+from mmcls.structures import (ClsDataSample, MultiTaskDataSample,
+                              batch_label_to_onehot, cat_batch_labels,
                               stack_batch_scores, tensor_split)
 from .batch_augments import RandomBatchAugment

@@ -151,7 +152,9 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
                                  self.pad_value)

         data_samples = data.get('data_samples', None)
-        if data_samples is not None and 'gt_label' in data_samples[0]:
+        sample_item = data_samples[0] if data_samples is not None else None
+        if isinstance(sample_item,
+                      ClsDataSample) and 'gt_label' in sample_item:
             gt_labels = [sample.gt_label for sample in data_samples]
             batch_label, label_indices = cat_batch_labels(
                 gt_labels, device=self.device)
@@ -181,5 +184,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
         if batch_score is not None:
             for sample, score in zip(data_samples, batch_score):
                 sample.set_gt_score(score)
+        elif isinstance(sample_item, MultiTaskDataSample):
+            data_samples = self.cast_data(data_samples)

         return {'inputs': inputs, 'data_samples': data_samples}
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cls_data_sample import ClsDataSample
+from .multi_task_data_sample import MultiTaskDataSample
 from .utils import (batch_label_to_onehot, cat_batch_labels,
                     stack_batch_scores, tensor_split)

 __all__ = [
     'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels',
-    'stack_batch_scores', 'tensor_split'
+    'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample'
 ]
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmengine.structures import BaseDataElement
+
+
+class MultiTaskDataSample(BaseDataElement):
+
+    @property
+    def tasks(self):
+        return self._data_fields
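For orientation, here is a small sketch of composing the new structure by hand, mirroring the unit tests near the end of this diff; the task names and label values are illustrative.

# Hedged sketch (not part of the commit); label values are illustrative.
from mmcls.structures import ClsDataSample, MultiTaskDataSample

sample = MultiTaskDataSample()
sample.set_field(ClsDataSample().set_gt_label(1), 'task0')
sample.set_field(ClsDataSample().set_gt_label(2), 'task1')
print(sample.tasks)  # the task names stored as data fields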
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
@@ -0,0 +1,40 @@
+{
+    "metainfo": {
+        "tasks": [
+            "gender",
+            "wear"
+        ]
+    },
+    "data_list": [
+        {
+            "img_path": "a/1.JPG",
+            "gt_label": {
+                "gender": 0
+            }
+        },
+        {
+            "img_path": "b/2.jpeg",
+            "gt_label": {
+                "gender": 0,
+                "wear": [
+                    1,
+                    0,
+                    1,
+                    0
+                ]
+            }
+        },
+        {
+            "img_path": "b/subb/3.jpg",
+            "gt_label": {
+                "gender": 1,
+                "wear": [
+                    0,
+                    1,
+                    0,
+                    1
+                ]
+            }
+        }
+    ]
+}
@ -79,7 +79,7 @@ class TestBaseDataset(TestCase):
|
|||
else:
|
||||
self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))
|
||||
|
||||
self.assertIn("Haven't been initialized", repr(dataset))
|
||||
self.assertIn('Haven\'t been initialized', repr(dataset))
|
||||
dataset.full_init()
|
||||
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
|
||||
|
||||
|
@ -452,7 +452,7 @@ class TestCIFAR10(TestBaseDataset):
|
|||
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
|
||||
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
|
||||
repr(dataset))
|
||||
|
||||
@classmethod
|
||||
|
@ -597,7 +597,7 @@ class TestVOC(TestBaseDataset):
|
|||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
self.assertIn("Haven't been initialized", repr(dataset))
|
||||
self.assertIn('Haven\'t been initialized', repr(dataset))
|
||||
dataset.full_init()
|
||||
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
|
||||
|
||||
|
@ -770,7 +770,7 @@ class TestMNIST(TestBaseDataset):
|
|||
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
|
||||
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
|
||||
repr(dataset))
|
||||
|
||||
@classmethod
|
||||
|
@ -874,3 +874,70 @@ class TestCUB(TestBaseDataset):
|
|||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
class TestMultiTaskDataset(TestCase):
|
||||
DATASET_TYPE = 'MultiTaskDataset'
|
||||
|
||||
DEFAULT_ARGS = dict(
|
||||
data_root=ASSETS_ROOT,
|
||||
ann_file=osp.join(ASSETS_ROOT, 'multi-task.json'),
|
||||
pipeline=[])
|
||||
|
||||
def test_metainfo(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
metainfo = {'tasks': ['gender', 'wear']}
|
||||
self.assertDictEqual(dataset.metainfo, metainfo)
|
||||
self.assertFalse(dataset.test_mode)
|
||||
|
||||
def test_parse_data_info(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
data = dataset.parse_data_info({
|
||||
'img_path': 'a.jpg',
|
||||
'gt_label': {
|
||||
'gender': 0
|
||||
}
|
||||
})
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
'img_path': os.path.join(ASSETS_ROOT, 'a.jpg'),
|
||||
'gt_label': {
|
||||
'gender': 0
|
||||
}
|
||||
}, data)
|
||||
np.testing.assert_equal(data['gt_label']['gender'], 0)
|
||||
|
||||
# Test missing path
|
||||
with self.assertRaisesRegex(AssertionError, 'have `img_path` field'):
|
||||
dataset.parse_data_info(
|
||||
{'gt_label': {
|
||||
'gender': 0,
|
||||
'wear': [1, 0, 1, 0]
|
||||
}})
|
||||
|
||||
def test_repr(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
task_doc = ('For 2 tasks\n gender \n wear ')
|
||||
self.assertIn(task_doc, repr(dataset))
|
||||
|
||||
def test_load_data_list(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
|
||||
self.assertIsInstance(data, list)
|
||||
np.testing.assert_equal(len(data), 3)
|
||||
np.testing.assert_equal(data[0]['gt_label'], {'gender': 0})
|
||||
np.testing.assert_equal(data[1]['gt_label'], {
|
||||
'gender': 0,
|
||||
'wear': [1, 0, 1, 0]
|
||||
})
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmengine.structures import LabelData
|
|||
from PIL import Image
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
@ -51,9 +51,8 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
# Test without `img` and `gt_label`
|
||||
del data['img']
|
||||
del data['gt_label']
|
||||
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertNotIn('gt_label', results['data_samples'])
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertNotIn('gt_label', results['data_samples'])
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])
|
||||
|
@ -130,3 +129,54 @@ class TestCollect(unittest.TestCase):
|
|||
cfg = dict(type='Collect', keys=['img'])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(repr(transform), "Collect(keys=['img'])")
|
||||
|
||||
|
||||
class TestPackMultiTaskInputs(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
|
||||
data = {
|
||||
'sample_idx': 1,
|
||||
'img_path': img_path,
|
||||
'ori_shape': (300, 400),
|
||||
'img_shape': (300, 400),
|
||||
'scale_factor': 1.0,
|
||||
'flip': False,
|
||||
'img': mmcv.imread(img_path),
|
||||
'gt_label': {
|
||||
'task1': 1,
|
||||
'task3': 3
|
||||
},
|
||||
}
|
||||
|
||||
cfg = dict(type='PackMultiTaskInputs', )
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
|
||||
self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
|
||||
self.assertIsInstance(results['data_samples'].task1.gt_label,
|
||||
LabelData)
|
||||
|
||||
# Test grayscale image
|
||||
data['img'] = data['img'].mean(-1)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertEqual(results['inputs'].shape, (1, 300, 400))
|
||||
|
||||
# Test without `img` and `gt_label`
|
||||
del data['img']
|
||||
del data['gt_label']
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertNotIn('gt_label', results['data_samples'])
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='PackMultiTaskInputs', meta_keys=['img_shape'])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
rep = 'PackMultiTaskInputs(task_handlers={},'
|
||||
rep += ' multi_task_fields=(\'gt_label\',),'
|
||||
rep += ' meta_keys=[\'img_shape\'])'
|
||||
self.assertEqual(repr(transform), rep)
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmcls.evaluation.metrics import MultiTasksMetric
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
||||
|
||||
class MultiTaskMetric(TestCase):
|
||||
data_pred = [
|
||||
{
|
||||
'task0': torch.tensor([0.7, 0.0, 0.3]),
|
||||
'task1': torch.tensor([0.5, 0.2, 0.3])
|
||||
},
|
||||
{
|
||||
'task0': torch.tensor([0.0, 0.0, 1.0]),
|
||||
'task1': torch.tensor([0.0, 0.0, 1.0])
|
||||
},
|
||||
]
|
||||
data_gt = [{'task0': 0, 'task1': 2}, {'task1': 2}]
|
||||
|
||||
preds = []
|
||||
for i, pred in enumerate(data_pred):
|
||||
sample = {}
|
||||
for task_name in pred:
|
||||
task_sample = ClsDataSample().set_pred_score(pred[task_name])
|
||||
if task_name in data_gt[i]:
|
||||
task_sample.set_gt_label(data_gt[i][task_name])
|
||||
task_sample.set_field(True, 'eval_mask', field_type='metainfo')
|
||||
else:
|
||||
task_sample.set_field(
|
||||
False, 'eval_mask', field_type='metainfo')
|
||||
sample[task_name] = task_sample.to_dict()
|
||||
|
||||
preds.append(sample)
|
||||
data2 = zip([
|
||||
{
|
||||
'task0': torch.tensor([0.7, 0.0, 0.3]),
|
||||
'task1': {
|
||||
'task10': torch.tensor([0.5, 0.2, 0.3]),
|
||||
'task11': torch.tensor([0.4, 0.3, 0.3])
|
||||
}
|
||||
},
|
||||
{
|
||||
'task0': torch.tensor([0.0, 0.0, 1.0]),
|
||||
'task1': {
|
||||
'task10': torch.tensor([0.1, 0.6, 0.3]),
|
||||
'task11': torch.tensor([0.5, 0.2, 0.3])
|
||||
}
|
||||
},
|
||||
], [{
|
||||
'task0': 0,
|
||||
'task1': {
|
||||
'task10': 2,
|
||||
'task11': 0
|
||||
}
|
||||
}, {
|
||||
'task0': 2,
|
||||
'task1': {
|
||||
'task10': 1,
|
||||
'task11': 0
|
||||
}
|
||||
}])
|
||||
|
||||
pred2 = []
|
||||
for score, label in data2:
|
||||
sample = {}
|
||||
for task_name in score:
|
||||
if type(score[task_name]) != dict:
|
||||
task_sample = ClsDataSample().set_pred_score(score[task_name])
|
||||
task_sample.set_gt_label(label[task_name])
|
||||
sample[task_name] = task_sample.to_dict()
|
||||
sample[task_name]['eval_mask'] = True
|
||||
else:
|
||||
sample[task_name] = {}
|
||||
sample[task_name]['eval_mask'] = True
|
||||
for task_name2 in score[task_name]:
|
||||
task_sample = ClsDataSample().set_pred_score(
|
||||
score[task_name][task_name2])
|
||||
task_sample.set_gt_label(label[task_name][task_name2])
|
||||
sample[task_name][task_name2] = task_sample.to_dict()
|
||||
sample[task_name][task_name2]['eval_mask'] = True
|
||||
|
||||
pred2.append(sample)
|
||||
|
||||
pred3 = [{'task0': {'eval_mask': False}, 'task1': {'eval_mask': False}}]
|
||||
task_metrics = {
|
||||
'task0': [dict(type='Accuracy', topk=(1, ))],
|
||||
'task1': [
|
||||
dict(type='Accuracy', topk=(1, 3)),
|
||||
dict(type='SingleLabelMetric', items=['precision', 'recall'])
|
||||
]
|
||||
}
|
||||
task_metrics2 = {
|
||||
'task0': [dict(type='Accuracy', topk=(1, ))],
|
||||
'task1': [
|
||||
dict(
|
||||
type='MultiTasksMetric',
|
||||
task_metrics={
|
||||
'task10': [
|
||||
dict(type='Accuracy', topk=(1, 3)),
|
||||
dict(type='SingleLabelMetric', items=['precision'])
|
||||
],
|
||||
'task11': [dict(type='Accuracy', topk=(1, ))]
|
||||
})
|
||||
]
|
||||
}
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
metric = MultiTasksMetric(self.task_metrics)
|
||||
metric.process(None, self.preds)
|
||||
results = metric.evaluate(2)
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertAlmostEqual(results['task0_accuracy/top1'], 100)
|
||||
self.assertGreater(results['task1_single-label/precision'], 0)
|
||||
|
||||
# Test nested
|
||||
metric = MultiTasksMetric(self.task_metrics2)
|
||||
metric.process(None, self.pred2)
|
||||
results = metric.evaluate(2)
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertGreater(results['task1_task10_single-label/precision'], 0)
|
||||
self.assertGreater(results['task1_task11_accuracy/top1'], 0)
|
||||
|
||||
# Test without any ground truth value
|
||||
metric = MultiTasksMetric(self.task_metrics)
|
||||
metric.process(None, self.pred3)
|
||||
results = metric.evaluate(2)
|
||||
self.assertIsInstance(results, dict)
|
||||
self.assertEqual(results['task0_Accuracy'], 0)
|
|
@ -10,7 +10,7 @@ import torch
|
|||
from mmengine import is_seq_of
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
@ -484,6 +484,142 @@ class TestMultiLabelLinearClsHead(TestMultiLabelClsHead):
|
|||
head(feats)
|
||||
|
||||
|
||||
class TestMultiTaskHead(TestCase):
|
||||
DEFAULT_ARGS = dict(
|
||||
type='MultiTaskHead', # <- Head config, depends on #675
|
||||
task_heads={
|
||||
'task0': dict(type='LinearClsHead', num_classes=3),
|
||||
'task1': dict(type='LinearClsHead', num_classes=6),
|
||||
},
|
||||
in_channels=10,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
)
|
||||
|
||||
DEFAULT_ARGS2 = dict(
|
||||
type='MultiTaskHead', # <- Head config, depends on #675
|
||||
task_heads={
|
||||
'task0':
|
||||
dict(
|
||||
type='MultiTaskHead',
|
||||
task_heads={
|
||||
'task00': dict(type='LinearClsHead', num_classes=3),
|
||||
'task01': dict(type='LinearClsHead', num_classes=6),
|
||||
}),
|
||||
'task1':
|
||||
dict(type='LinearClsHead', num_classes=6)
|
||||
},
|
||||
in_channels=10,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
)
|
||||
|
||||
def test_forward(self):
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
# return the last item (same as pre_logits)
|
||||
feats = (torch.rand(4, 10), )
|
||||
outs = head(feats)
|
||||
self.assertEqual(outs['task0'].shape, (4, 3))
|
||||
self.assertEqual(outs['task1'].shape, (4, 6))
|
||||
self.assertTrue(isinstance(outs, dict))
|
||||
|
||||
def test_loss(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = []
|
||||
|
||||
for _ in range(4):
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in self.DEFAULT_ARGS['task_heads']:
|
||||
task_sample = ClsDataSample().set_gt_label(1)
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
# with cal_acc = False
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
|
||||
losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(
|
||||
losses.keys(),
|
||||
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
|
||||
self.assertGreater(losses['task0_loss'].item(), 0)
|
||||
self.assertGreater(losses['task1_loss'].item(), 0)
|
||||
|
||||
def test_predict(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = []
|
||||
|
||||
for _ in range(4):
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in self.DEFAULT_ARGS['task_heads']:
|
||||
task_sample = ClsDataSample().set_gt_label(1)
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
# without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for pred in predictions:
|
||||
self.assertIn('task0', pred)
|
||||
task0_sample = predictions[0].task0
|
||||
self.assertIsInstance(task0_sample.pred_label.score, torch.Tensor)
|
||||
|
||||
# with data_samples
|
||||
predictions = head.predict(feats, data_samples)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for sample, pred in zip(data_samples, predictions):
|
||||
self.assertIs(sample, pred)
|
||||
self.assertIn('task0', pred)
|
||||
|
||||
def test_loss_empty_data_sample(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = []
|
||||
|
||||
for _ in range(4):
|
||||
data_sample = MultiTaskDataSample()
|
||||
data_samples.append(data_sample)
|
||||
# with cal_acc = False
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
losses = head.loss(feats, data_samples)
|
||||
self.assertEqual(
|
||||
losses.keys(),
|
||||
{'task0_loss', 'task0_mask_size', 'task1_loss', 'task1_mask_size'})
|
||||
self.assertEqual(losses['task0_loss'].item(), 0)
|
||||
self.assertEqual(losses['task1_loss'].item(), 0)
|
||||
|
||||
def test_nested_multi_task_loss(self):
|
||||
|
||||
head = MODELS.build(self.DEFAULT_ARGS2)
|
||||
# return the last item (same as pre_logits)
|
||||
feats = (torch.rand(4, 10), )
|
||||
outs = head(feats)
|
||||
self.assertEqual(outs['task0']['task01'].shape, (4, 6))
|
||||
self.assertTrue(isinstance(outs, dict))
|
||||
self.assertTrue(isinstance(outs['task0'], dict))
|
||||
|
||||
def test_nested_invalid_sample(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
gt_label = {'task0': 1, 'task1': 1}
|
||||
head = MODELS.build(self.DEFAULT_ARGS2)
|
||||
data_sample = MultiTaskDataSample()
|
||||
for task_name in gt_label:
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label[task_name])
|
||||
data_sample.set_field(task_sample, task_name)
|
||||
with self.assertRaises(Exception):
|
||||
head.loss(feats, data_sample)
|
||||
|
||||
def test_nested_invalid_sample2(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
data_sample = MultiTaskDataSample()
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
|
||||
data_sample.set_field(task_sample, 'task1')
|
||||
data_sample.set_field(MultiTaskDataSample(), 'task0')
|
||||
for task_name in gt_label['task0']:
|
||||
task_sample = ClsDataSample().set_gt_label(
|
||||
gt_label['task0'][task_name])
|
||||
data_sample.task0.set_field(task_sample, task_name)
|
||||
with self.assertRaises(Exception):
|
||||
head.loss(feats, data_sample)
|
||||
|
||||
|
||||
class TestArcFaceClsHead(TestCase):
|
||||
DEFAULT_ARGS = dict(type='ArcFaceClsHead', in_channels=10, num_classes=5)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import torch
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.structures import ClsDataSample, MultiTaskDataSample
|
||||
|
||||
|
||||
class TestClsDataSample(TestCase):
|
||||
|
@ -122,3 +122,20 @@ class TestClsDataSample(TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, 'but got 2'):
|
||||
data_sample.set_pred_score(
|
||||
torch.tensor([[0.1, 0.1, 0.6, 0.1, 0.1]]))
|
||||
|
||||
|
||||
class TestMultiTaskDataSample(TestCase):
|
||||
|
||||
def test_multi_task_data_sample(self):
|
||||
gt_label = {'task0': {'task00': 1, 'task01': 1}, 'task1': 1}
|
||||
data_sample = MultiTaskDataSample()
|
||||
task_sample = ClsDataSample().set_gt_label(gt_label['task1'])
|
||||
data_sample.set_field(task_sample, 'task1')
|
||||
data_sample.set_field(MultiTaskDataSample(), 'task0')
|
||||
for task_name in gt_label['task0']:
|
||||
task_sample = ClsDataSample().set_gt_label(
|
||||
gt_label['task0'][task_name])
|
||||
data_sample.task0.set_field(task_sample, task_name)
|
||||
self.assertIsInstance(data_sample.task0, MultiTaskDataSample)
|
||||
self.assertIsInstance(data_sample.task1, ClsDataSample)
|
||||
self.assertIsInstance(data_sample.task0.task00, ClsDataSample)