mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
345 lines
11 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from os import PathLike
from typing import Optional, Sequence

import mmengine
from mmcv.transforms import Compose
from mmengine.fileio import FileClient

from .builder import DATASETS


def expanduser(path):
    """Expand a leading ``~`` in ``path``; return other values unchanged."""
    if isinstance(path, (str, PathLike)):
        return osp.expanduser(path)
    else:
        return path


def isabs(uri):
    """Return True for absolute local paths and for URIs with a scheme."""
    # ``str(uri)`` keeps the check valid for ``PathLike`` inputs as well.
    return osp.isabs(uri) or ('://' in str(uri))


@DATASETS.register_module()
class MultiTaskDataset:
    """Custom dataset for multi-task training.

    To use the dataset, please generate and provide an annotation file in the
    format below:

    .. code-block:: json

        {
            "metainfo": {
                "tasks": [
                    "gender",
                    "wear"
                ]
            },
            "data_list": [
                {
                    "img_path": "a.jpg",
                    "gt_label": {
                        "gender": 0,
                        "wear": [1, 0, 1, 0]
                    }
                },
                {
                    "img_path": "b.jpg",
                    "gt_label": {
                        "gender": 1,
                        "wear": [1, 0, 1, 0]
                    }
                }
            ]
        }

    Assume we put our dataset in the ``data/mydataset`` folder in the
    repository and organize it in the format below: ::

        mmclassification/
        └── data
            └── mydataset
                ├── annotation
                │   ├── train.json
                │   ├── test.json
                │   └── val.json
                ├── train
                │   ├── a.jpg
                │   └── ...
                ├── test
                │   ├── b.jpg
                │   └── ...
                └── val
                    ├── c.jpg
                    └── ...

    We can use the config below to build datasets:

    .. code:: python

        >>> from mmcls.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="annotation/train.json",
        ...     data_root="data/mydataset",
        ...     # The `img_path` field in the train annotation file is relative
        ...     # to the `train` folder.
        ...     data_prefix='train',
        ... )
        >>> train_dataset = build_dataset(train_cfg)

    Or we can put all files in the same folder: ::

        mmclassification/
        └── data
            └── mydataset
                ├── train.json
                ├── test.json
                ├── val.json
                ├── a.jpg
                ├── b.jpg
                ├── c.jpg
                └── ...

    And we can use the config below to build datasets:

    .. code:: python

        >>> from mmcls.datasets import build_dataset
        >>> train_cfg = dict(
        ...     type="MultiTaskDataset",
        ...     ann_file="train.json",
        ...     data_root="data/mydataset",
        ...     # The `data_prefix` is not required since all paths are
        ...     # relative to the `data_root`.
        ... )
        >>> train_dataset = build_dataset(train_cfg)


    Args:
        ann_file (str): The annotation file path. It can be either an
            absolute path or a path relative to the ``data_root``.
        metainfo (dict, optional): The extra meta information. It should be
            a dict with the same format as the ``"metainfo"`` field in the
            annotation file. Defaults to None.
        data_root (str, optional): The root path of the data directory. It's
            the prefix of the ``data_prefix`` and the ``ann_file``, and it
            can be a remote path like "s3://openmmlab/xxx/". Defaults to
            None.
        data_prefix (str, optional): The base folder relative to the
            ``data_root`` for the ``"img_path"`` field in the annotation file.
            Defaults to None.
        pipeline (Sequence[dict]): A sequence of dicts, where each element
            represents an operation defined in
            :mod:`mmcls.datasets.pipelines`. Defaults to an empty tuple.
        test_mode (bool): Whether the dataset is in test mode. Defaults to
            False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            If None, the client is inferred automatically from the
            ``data_root``. Defaults to None.
    """
    METAINFO = dict()

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: Optional[str] = None,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 file_client_args: Optional[dict] = None):

        self.data_root = expanduser(data_root)

        # Infer the file client from the data root when one is given.
        if self.data_root is not None:
            file_client = FileClient.infer_client(
                file_client_args, uri=self.data_root)
        else:
            file_client = FileClient(file_client_args)
        self.file_client: FileClient = file_client

        self.ann_file = self._join_root(expanduser(ann_file))
        self.data_prefix = self._join_root(data_prefix)

        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list(self.ann_file, metainfo)

    def _join_root(self, path):
        """Join ``self.data_root`` with the specified path.

        If the path is an absolute path, just return the path. If the
        path is None, return ``self.data_root``. If ``self.data_root`` is
        None, there is nothing to join, so return the path unchanged.

        Examples:
            >>> self.data_root = 'a/b/c'
            >>> self._join_root('d/e')
            'a/b/c/d/e'
            >>> self._join_root('https://openmmlab.com')
            'https://openmmlab.com'
            >>> self._join_root(None)
            'a/b/c'
        """
        if path is None:
            return self.data_root
        if isabs(path):
            return path
        # Without a data root, joining would fail on ``None``.
        if self.data_root is None:
            return path

        joined_path = self.file_client.join_path(self.data_root, path)
        return joined_path

    @classmethod
    def _get_meta_info(cls, in_metainfo: Optional[dict] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            in_metainfo (dict, optional): Meta information dict. Defaults
                to None.

        Returns:
            dict: Parsed meta information.
        """
        # Entries of `cls.METAINFO` are overwritten by `in_metainfo`.
        metainfo = copy.deepcopy(cls.METAINFO)
        if in_metainfo is None:
            return metainfo

        metainfo.update(in_metainfo)

        return metainfo

    def load_data_list(self, ann_file, metainfo_override=None):
        """Load annotations from an annotation file.

        Args:
            ann_file (str): Path of the annotation file, already joined
                with ``self.data_root`` by the caller.
            metainfo_override (dict, optional): Meta information that
                overrides the ``"metainfo"`` field of the annotation file.
                Defaults to None.

        Returns:
            list[dict]: A list of parsed annotations.
        """
        annotations = mmengine.load(ann_file)
        if not isinstance(annotations, dict):
            raise TypeError('The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_list' not in annotations:
            raise ValueError('The annotation file must have the `data_list` '
                             'field.')
        metainfo = annotations.get('metainfo', {})
        raw_data_list = annotations['data_list']

        # Set meta information.
        assert isinstance(metainfo, dict), 'The `metainfo` field in the '\
            f'annotation file should be a dict, but got {type(metainfo)}'
        if metainfo_override is not None:
            assert isinstance(metainfo_override, dict), 'The `metainfo` ' \
                f'argument should be a dict, but got {type(metainfo_override)}'
            metainfo.update(metainfo_override)
        self._metainfo = self._get_meta_info(metainfo)

        data_list = []
        for i, raw_data in enumerate(raw_data_list):
            try:
                data_list.append(self.parse_data_info(raw_data))
            except AssertionError as e:
                # Chain the original error so the traceback keeps its cause.
                raise RuntimeError(
                    f'The format check failed while parsing item {i} of '
                    f'the annotation file with error: {e}') from e
        return data_list

    def parse_data_info(self, raw_data):
        """Parse raw annotation to target format.

        This method returns a dict which contains the data information of a
        sample.

        Args:
            raw_data (dict): Raw data information loaded from ``ann_file``.

        Returns:
            dict: Parsed annotation.
        """
        assert isinstance(raw_data, dict), \
            f'The item should be a dict, but got {type(raw_data)}'
        assert 'img_path' in raw_data, \
            "The item doesn't have `img_path` field."
        # Check `gt_label` too, so that a missing field is reported as a
        # format error instead of a bare KeyError.
        assert 'gt_label' in raw_data, \
            "The item doesn't have `gt_label` field."
        data = dict(
            img_path=self._join_root(raw_data['img_path']),
            gt_label=raw_data['gt_label'],
        )
        return data

    @property
    def metainfo(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: meta information collected from ``cls.METAINFO``,
            annotation file and metainfo argument during instantiation.
        """
        return copy.deepcopy(self._metainfo)

    def prepare_data(self, idx):
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        results = copy.deepcopy(self.data_list[idx])
        return self.pipeline(results)

    def __len__(self):
        """Get the length of the whole dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data_list)

    def __getitem__(self, idx):
        """Get the idx-th image and data information of dataset after
        ``self.pipeline``.

        Args:
            idx (int): The index of the data.

        Returns:
            dict: The idx-th image and data information after
            ``self.pipeline``.
        """
        return self.prepare_data(idx)

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = [f'Number of samples: \t{self.__len__()}']
        if self.data_root is not None:
            body.append(f'Root location: \t{self.data_root}')
        body.append(f'Annotation file: \t{self.ann_file}')
        if self.data_prefix is not None:
            body.append(f'Prefix of images: \t{self.data_prefix}')
        # -------------------- extra repr --------------------
        # Fall back to an empty list when no `tasks` entry was provided.
        tasks = self.metainfo.get('tasks', [])
        body.append(f'For {len(tasks)} tasks')
        for task in tasks:
            body.append(f' {task} ')
        # ----------------------------------------------------

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f' {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)
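The annotation layout described in the `MultiTaskDataset` docstring can be produced with the standard library alone. The following is a minimal sketch under stated assumptions, not part of mmcls: `build_annotation` and `check_annotation` are hypothetical helper names, and the check that every `gt_label` key appears in `metainfo["tasks"]` is an extra assumption that the dataset class itself does not enforce.

```python
import json
import os
import tempfile


def build_annotation(tasks, samples):
    """Assemble an annotation dict in the layout the docstring describes.

    ``samples`` is an iterable of ``(img_path, gt_label_dict)`` pairs.
    (Hypothetical helper, not an mmcls API.)
    """
    return {
        'metainfo': {'tasks': list(tasks)},
        'data_list': [
            {'img_path': img_path, 'gt_label': dict(gt_label)}
            for img_path, gt_label in samples
        ],
    }


def check_annotation(annotation):
    """Run basic format checks on an annotation dict.

    The unknown-task check is an assumption added for illustration.
    (Hypothetical helper, not an mmcls API.)
    """
    if not isinstance(annotation, dict):
        raise TypeError('The annotation should be a dict.')
    if 'data_list' not in annotation:
        raise ValueError('The annotation must have the `data_list` field.')
    tasks = annotation.get('metainfo', {}).get('tasks', [])
    for i, item in enumerate(annotation['data_list']):
        if 'img_path' not in item or 'gt_label' not in item:
            raise ValueError(f'Item {i} lacks `img_path` or `gt_label`.')
        unknown = set(item['gt_label']) - set(tasks)
        if unknown:
            raise ValueError(f'Item {i} has unknown tasks: {sorted(unknown)}')
    return True


annotation = build_annotation(
    tasks=['gender', 'wear'],
    samples=[
        ('a.jpg', {'gender': 0, 'wear': [1, 0, 1, 0]}),
        ('b.jpg', {'gender': 1, 'wear': [1, 0, 1, 0]}),
    ],
)

# Write the file to a temporary folder and read it back to confirm that
# the JSON round trip preserves the structure.
with tempfile.TemporaryDirectory() as tmp_dir:
    ann_path = os.path.join(tmp_dir, 'train.json')
    with open(ann_path, 'w') as f:
        json.dump(annotation, f, indent=4)
    with open(ann_path) as f:
        reloaded = json.load(f)

check_annotation(reloaded)
```

In a real project the file would be written to a location such as `data/mydataset/annotation/train.json` and referenced through `ann_file` in the dataset config.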