[Feature] Add `BaseDataset`, `CustomDataset`, `ImageNet` and `ImageNet21k`

pull/913/head
mzr1996 2022-05-18 15:55:28 +00:00
parent 98377df512
commit 27e685fe10
9 changed files with 1486 additions and 1755 deletions

View File

@ -7,8 +7,7 @@ from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .imagenet import ImageNet, ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler, RepeatAugSampler

View File

@ -1,58 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import List
from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
from torch.utils.data import Dataset
from mmengine.dataset import BaseDataset as _BaseDataset
from mmcls.core.evaluation import precision_recall_f1, support
from mmcls.models.losses import accuracy
from .pipelines import Compose
from .builder import DATASETS
def expanduser(path):
"""Expand ~ and ~user constructions.
If user or $HOME is unknown, do nothing.
"""
if isinstance(path, (str, PathLike)):
return osp.expanduser(path)
else:
return path
class BaseDataset(Dataset, metaclass=ABCMeta):
"""Base dataset.
@DATASETS.register_module()
class BaseDataset(_BaseDataset):
"""Base dataset for image classification task.
This dataset supports annotation files in the `OpenMMLab 2.0 style
annotation format`_.
.. _OpenMMLab 2.0 style annotation format:
https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
Compared with :class:`mmengine.BaseDataset`, this class implements
several useful methods.
Args:
data_prefix (str): the prefix of data path
pipeline (list): a list of dict, where each element represents
a operation defined in `mmcls.datasets.pipelines`
ann_file (str | None): the annotation file. When ann_file is str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix
test_mode (bool): in train mode or test mode
"""
ann_file (str): Annotation file path.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (str | dict, optional): Prefix for training data. Defaults
to None.
filter_cfg (dict, optional): Config for filtering data. Defaults to None.
indices (int or Sequence[int], optional): Support using only the first
few data items in the annotation file to facilitate training/testing
on a smaller dataset. Defaults to None, which means using all
``data_infos``.
serialize_data (bool, optional): Whether to hold memory using
serialized objects. When enabled, data loader workers can use
shared RAM from the master process instead of making a copy.
Defaults to True.
pipeline (list, optional): Processing pipeline. Defaults to [].
test_mode (bool, optional): ``test_mode=True`` means the dataset is
used in the test phase. Defaults to False.
lazy_init (bool, optional): Whether to load the annotation file lazily,
i.e. skip loading it during instantiation. In some cases, such as
visualization, only the meta information of the dataset is needed,
so it is not necessary to load the annotation file. ``BaseDataset``
can skip loading annotations to save time by setting
``lazy_init=True``. Defaults to False.
max_refetch (int, optional): The maximum number of extra cycles to get
a valid image if ``BaseDataset.prepare_data`` gets a ``None`` image.
Defaults to 1000.
classes (str | Sequence[str], optional): Specify names of classes.
CLASSES = None
- If it is a string, it should be a file path, and every line of
the file is the name of a class.
- If it is a sequence of strings, every item is the name of a class.
- If it is None, use the category information from the ``metainfo``
argument, the annotation file, or the class attribute ``METAINFO``.
Defaults to None.
""" # noqa: E501
def __init__(self,
data_prefix,
pipeline,
classes=None,
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__()
self.data_prefix = expanduser(data_prefix)
self.pipeline = Compose(pipeline)
self.CLASSES = self.get_classes(classes)
self.ann_file = expanduser(ann_file)
self.test_mode = test_mode
self.data_infos = self.load_annotations()
ann_file: str,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: Union[str, dict, None] = None,
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: Sequence = (),
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000,
classes: Union[str, Sequence[str], None] = None):
if isinstance(data_prefix, str):
data_prefix = dict(img_path=expanduser(data_prefix))
elif data_prefix is None:
data_prefix = dict(img_path=None)
@abstractmethod
def load_annotations(self):
pass
ann_file = expanduser(ann_file)
metainfo = self._compat_classes(metainfo, classes)
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
filter_cfg=filter_cfg,
indices=indices,
serialize_data=serialize_data,
pipeline=pipeline,
test_mode=test_mode,
lazy_init=lazy_init,
max_refetch=max_refetch)
@property
def img_prefix(self):
"""The prefix of images."""
return self.data_prefix['img_path']
@property
def CLASSES(self):
"""Return all categories names."""
return self._metainfo.get('CLASSES', None)
@property
def class_to_idx(self):
@ -62,7 +124,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
dict: mapping from class name to class index.
"""
return {_class: i for i, _class in enumerate(self.CLASSES)}
return {cat: i for i, cat in enumerate(self.CLASSES)}
def get_gt_labels(self):
"""Get all ground-truth labels (categories).
@ -71,7 +133,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
np.ndarray: categories for all images.
"""
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
gt_labels = np.array(
[self.get_data_info(i)['gt_label'] for i in range(len(self))])
return gt_labels
def get_cat_ids(self, idx: int) -> List[int]:
@ -84,138 +147,64 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
cat_ids (List[int]): Image category of specified index.
"""
return [int(self.data_infos[idx]['gt_label'])]
def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)
def __len__(self):
return len(self.data_infos)
def __getitem__(self, idx):
return self.prepare_data(idx)
@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
return cls.CLASSES
return [int(self.get_data_info(idx)['gt_label'])]
def _compat_classes(self, metainfo, classes):
"""Merge the old style ``classes`` arguments to ``metainfo``."""
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(expanduser(classes))
class_names = mmengine.list_from_file(expanduser(classes))
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
elif classes is not None:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
return class_names
if metainfo is None:
metainfo = {}
def evaluate(self,
results,
metric='accuracy',
metric_options=None,
indices=None,
logger=None):
"""Evaluate the dataset.
if classes is not None:
metainfo = {'CLASSES': tuple(class_names), **metainfo}
return metainfo
def full_init(self):
"""Load annotation file and set ``BaseDataset._fully_initialized`` to
True."""
super().full_init()
# To support the standard OpenMMLab 2.0 annotation format. Generate
# metainfo in internal format from standard metainfo format.
if 'categories' in self._metainfo and 'CLASSES' not in self._metainfo:
categories = sorted(
self._metainfo['categories'], key=lambda x: x['id'])
self._metainfo['CLASSES'] = tuple(
[cat['category_name'] for cat in categories])
def __repr__(self):
"""Print the basic information of the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
Default value is `accuracy`.
metric_options (dict, optional): Options for calculating metrics.
Allowed keys are 'topk', 'thrs' and 'average_mode'.
Defaults to None.
indices (list, optional): The indices of samples corresponding to
the results. Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
dict: evaluation results
str: Formatted string.
"""
if metric_options is None:
metric_options = {'topk': (1, 5)}
if isinstance(metric, str):
metrics = [metric]
head = 'Dataset ' + self.__class__.__name__
body = []
if self._fully_initialized:
body.append(f'Number of samples: \t{self.__len__()}')
else:
metrics = metric
allowed_metrics = [
'accuracy', 'precision', 'recall', 'f1_score', 'support'
]
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
body.append("Haven't been initialized")
invalid_metrics = set(metrics) - set(allowed_metrics)
if len(invalid_metrics) != 0:
raise ValueError(f'metric {invalid_metrics} is not supported.')
if self.CLASSES is not None:
body.append(f'Number of categories: \t{len(self.CLASSES)}')
else:
body.append('The `CLASSES` meta info is not set.')
topk = metric_options.get('topk', (1, 5))
thrs = metric_options.get('thrs')
average_mode = metric_options.get('average_mode', 'macro')
body.append(f'Annotation file: \t{self.ann_file}')
body.append(f'Prefix of images: \t{self.img_prefix}')
if 'accuracy' in metrics:
if thrs is not None:
acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
else:
acc = accuracy(results, gt_labels, topk=topk)
if isinstance(topk, tuple):
eval_results_ = {
f'accuracy_top-{k}': a
for k, a in zip(topk, acc)
}
else:
eval_results_ = {'accuracy': acc}
if isinstance(thrs, tuple):
for key, values in eval_results_.items():
eval_results.update({
f'{key}_thr_{thr:.2f}': value.item()
for thr, value in zip(thrs, values)
})
else:
eval_results.update(
{k: v.item()
for k, v in eval_results_.items()})
if len(self.pipeline.transforms) > 0:
body.append('With transforms:')
for t in self.pipeline.transforms:
body.append(f' {t}')
if 'support' in metrics:
support_value = support(
results, gt_labels, average_mode=average_mode)
eval_results['support'] = support_value
precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
if thrs is not None:
precision_recall_f1_values = precision_recall_f1(
results, gt_labels, average_mode=average_mode, thrs=thrs)
else:
precision_recall_f1_values = precision_recall_f1(
results, gt_labels, average_mode=average_mode)
for key, values in zip(precision_recall_f1_keys,
precision_recall_f1_values):
if key in metrics:
if isinstance(thrs, tuple):
eval_results.update({
f'{key}_thr_{thr:.2f}': value
for thr, value in zip(thrs, values)
})
else:
eval_results[key] = values
return eval_results
lines = [head] + [' ' * 4 + line for line in body]
return '\n'.join(lines)
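
The refactored `BaseDataset` above can be instantiated directly with an OpenMMLab 2.0 style annotation file. A minimal usage sketch, assuming hypothetical paths and class names that are not part of this commit:

```python
# Minimal sketch of the refactored BaseDataset; the data paths and class
# names below are hypothetical placeholders.
from mmcls.datasets import BaseDataset

dataset = BaseDataset(
    ann_file='ann.json',          # OpenMMLab 2.0 style annotation file
    data_root='data/my_dataset',  # joined with ann_file and data_prefix
    data_prefix='images',         # converted to dict(img_path=...) internally
    classes=('cat', 'dog'),       # old-style argument, merged into metainfo
)

print(dataset.CLASSES)       # ('cat', 'dog')
print(dataset.class_to_idx)  # {'cat': 0, 'dog': 1}
print(repr(dataset))         # summary produced by the new __repr__
```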

File diff suppressed because it is too large

View File

@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from mmcv import FileClient
import mmengine
from mmengine import FileClient
from mmcls.registry import DATASETS
from mmcls.utils import get_root_logger
from .base_dataset import BaseDataset
@ -104,7 +103,8 @@ class CustomDataset(BaseDataset):
folder_2/nsdf3.png 3
...
Please specify the name of categories by the argument ``classes``.
Please specify the names of categories by the argument ``classes``
or ``metainfo``.
2. The samples are arranged in the specific way: ::
@ -124,58 +124,61 @@ class CustomDataset(BaseDataset):
first way, otherwise, try the second way.
Args:
data_prefix (str): The path of data directory.
pipeline (Sequence[dict]): A list of dict, where each element
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
classes (str | Sequence[str], optional): Specify names of classes.
- If is string, it should be a file path, and the every line of
the file is a name of a class.
- If is a sequence of string, every item is a name of class.
- If is None, use ``cls.CLASSES`` or the names of sub folders
(If use the second way to arrange samples).
Defaults to None.
ann_file (str, optional): The annotation file. If is string, read
samples paths from the ann_file. If is None, find samples in
``data_prefix``. Defaults to None.
ann_file (str, optional): Annotation file path. Defaults to None.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (str | dict, optional): Prefix for training data. Defaults
to None.
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
test_mode (bool): In train mode or test mode. It's only a mark and
won't be used in this class. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
If None, automatically inference from the specified path.
Defaults to None.
lazy_init (bool, optional): Whether to load the annotation file lazily,
i.e. skip loading it during instantiation. In some cases, such as
visualization, only the meta information of the dataset is needed,
so it is not necessary to load the annotation file. ``BaseDataset``
can skip loading annotations to save time by setting
``lazy_init=True``. Defaults to False.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_prefix: str,
pipeline: Sequence = (),
classes: Union[str, Sequence[str], None] = None,
ann_file: Optional[str] = None,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: Union[str, dict, None] = None,
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
'.bmp', '.pgm', '.tif'),
test_mode: bool = False,
file_client_args: Optional[dict] = None):
lazy_init: bool = False,
**kwargs):
assert (ann_file is not None or data_prefix is not None
or data_root is not None), \
'One of `ann_file`, `data_root` and `data_prefix` must '\
'be specified.'
self.extensions = tuple(set([i.lower() for i in extensions]))
self.file_client_args = file_client_args
super().__init__(
# The base class requires string ann_file but this class doesn't
ann_file=ann_file if ann_file is not None else '',
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
pipeline=pipeline,
classes=classes,
ann_file=ann_file,
test_mode=test_mode)
# Force lazy_init=True to allow some modification before loading data.
lazy_init=True,
**kwargs)
def _find_samples(self):
if ann_file is None:
self.ann_file = None
# Fully initialize the dataset.
if not lazy_init:
self.full_init()
def _find_samples(self, file_client):
"""find samples from ``data_prefix``."""
file_client = FileClient.infer_client(self.file_client_args,
self.data_prefix)
classes, folder_to_idx = find_folders(self.data_prefix, file_client)
classes, folder_to_idx = find_folders(self.img_prefix, file_client)
samples, empty_classes = get_samples(
self.data_prefix,
self.img_prefix,
folder_to_idx,
is_valid_file=self.is_valid_file,
file_client=file_client,
@ -192,37 +195,42 @@ class CustomDataset(BaseDataset):
f'the number of specified classes ({len(self.CLASSES)}). ' \
'Please check the data folder.'
else:
self.CLASSES = classes
self._metainfo['CLASSES'] = tuple(classes)
if empty_classes:
warnings.warn(
logger = get_root_logger()
logger.warning(
'Found no valid file in the folder '
f'{", ".join(empty_classes)}. '
f"Supported extensions are: {', '.join(self.extensions)}",
UserWarning)
f"Supported extensions are: {', '.join(self.extensions)}")
self.folder_to_idx = folder_to_idx
return samples
def load_annotations(self):
def load_data_list(self):
"""Load image paths and gt_labels."""
if self.ann_file is None:
samples = self._find_samples()
elif isinstance(self.ann_file, str):
lines = mmcv.list_from_file(
self.ann_file, file_client_args=self.file_client_args)
samples = [x.strip().rsplit(' ', 1) for x in lines]
else:
raise TypeError('ann_file must be a str or None')
if self.img_prefix is not None:
file_client = FileClient.infer_client(uri=self.img_prefix)
data_infos = []
if self.ann_file is None:
samples = self._find_samples(file_client)
else:
lines = mmengine.list_from_file(self.ann_file)
samples = [x.strip().rsplit(' ', 1) for x in lines]
def add_prefix(filename, prefix=None):
if prefix is None:
return filename
else:
return file_client.join_path(prefix, filename)
data_list = []
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
img_path = add_prefix(filename, self.img_prefix)
info = {'img_path': img_path, 'gt_label': int(gt_label)}
data_list.append(info)
return data_list
def is_valid_file(self, filename: str) -> bool:
"""Check if a file is a valid sample."""

File diff suppressed because it is too large

View File

@ -1,174 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import gc
import pickle
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from mmcls.registry import DATASETS
from .custom import CustomDataset
@DATASETS.register_module()
class ImageNet21k(CustomDataset):
"""ImageNet21k Dataset.
Since the dataset ImageNet21k is extremely big, cantains 21k+ classes
and 1.4B files. This class has improved the following points on the
basis of the class ``ImageNet``, in order to save memory, we enable the
``serialize_data`` optional by default. With this option, the annotation
won't be stored in the list ``data_infos``, but be serialized as an
array.
Args:
data_prefix (str): The path of data directory.
pipeline (Sequence[dict]): A list of dict, where each element
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
Defaults to an empty tuple.
classes (str | Sequence[str], optional): Specify names of classes.
- If is string, it should be a file path, and the every line of
the file is a name of a class.
- If is a sequence of string, every item is a name of class.
- If is None, the object won't have category information.
(Not recommended)
Defaults to None.
ann_file (str, optional): The annotation file. If is string, read
samples paths from the ann_file. If is None, find samples in
``data_prefix``. Defaults to None.
serialize_data (bool): Whether to hold memory using serialized objects,
when enabled, data loader workers can use shared RAM from master
process instead of making a copy. Defaults to True.
multi_label (bool): Not implement by now. Use multi label or not.
Defaults to False.
recursion_subdir(bool): Deprecated, and the dataset will recursively
get all images now.
test_mode (bool): In train mode or test mode. It's only a mark and
won't be used in this class. Defaults to False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
If None, automatically inference from the specified path.
Defaults to None.
"""
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
CLASSES = None
def __init__(self,
data_prefix: str,
pipeline: Sequence = (),
classes: Union[str, Sequence[str], None] = None,
ann_file: Optional[str] = None,
serialize_data: bool = True,
multi_label: bool = False,
recursion_subdir: bool = True,
test_mode=False,
file_client_args: Optional[dict] = None):
assert recursion_subdir, 'The `recursion_subdir` option is ' \
'deprecated. Now the dataset will recursively get all images.'
if multi_label:
raise NotImplementedError(
'The `multi_label` option is not supported by now.')
self.multi_label = multi_label
self.serialize_data = serialize_data
if ann_file is None:
warnings.warn(
'The ImageNet21k dataset is large, and scanning directory may '
'consume long time. Considering to specify the `ann_file` to '
'accelerate the initialization.', UserWarning)
if classes is None:
warnings.warn(
'The CLASSES is not stored in the `ImageNet21k` class. '
'Considering to specify the `classes` argument if you need '
'do inference on the ImageNet-21k dataset', UserWarning)
super().__init__(
data_prefix=data_prefix,
pipeline=pipeline,
classes=classes,
ann_file=ann_file,
extensions=self.IMG_EXTENSIONS,
test_mode=test_mode,
file_client_args=file_client_args)
if self.serialize_data:
self.data_infos_bytes, self.data_address = self._serialize_data()
# Empty cache for preventing making multiple copies of
# `self.data_infos` when loading data multi-processes.
self.data_infos.clear()
gc.collect()
def get_cat_ids(self, idx: int) -> List[int]:
"""Get category id by index.
Args:
idx (int): Index of data.
Returns:
cat_ids (List[int]): Image category of specified index.
"""
return [int(self.get_data_info(idx)['gt_label'])]
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): The index of data.
Returns:
dict: The idx-th annotation of the dataset.
"""
if self.serialize_data:
start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
end_addr = self.data_address[idx].item()
bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
data_info = pickle.loads(bytes)
else:
data_info = self.data_infos[idx]
return data_info
def prepare_data(self, idx):
data_info = self.get_data_info(idx)
return self.pipeline(data_info)
def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
"""Serialize ``self.data_infos`` to save memory when launching multiple
workers in data loading. This function will be called in ``full_init``.
Hold memory using serialized objects, and data loader workers can use
shared RAM from master process instead of making a copy.
Returns:
Tuple[np.ndarray, np.ndarray]: serialize result and corresponding
address.
"""
def _serialize(data):
buffer = pickle.dumps(data, protocol=4)
return np.frombuffer(buffer, dtype=np.uint8)
serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
address_list = np.asarray([len(x) for x in serialized_data_infos_list],
dtype=np.int64)
data_address: np.ndarray = np.cumsum(address_list)
serialized_data_infos = np.concatenate(serialized_data_infos_list)
return serialized_data_infos, data_address
def __len__(self) -> int:
"""Get the length of filtered dataset and automatically call
``full_init`` if the dataset has not been fully init.
Returns:
int: The length of filtered dataset.
"""
if self.serialize_data:
return len(self.data_address)
else:
return len(self.data_infos)
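
The deleted `ImageNet21k` implementation above saved memory by pickling every record into one shared `uint8` buffer indexed by cumulative byte offsets; the new class inherits this behaviour from `mmengine.BaseDataset` through the `serialize_data` option. A standalone sketch of the technique (illustration only, not code from this commit):

```python
# Pickle each record into one shared uint8 buffer and index records by
# cumulative byte offsets, so data loader workers share RAM instead of
# copying a Python list.
import pickle

import numpy as np

data_infos = [{'img_path': f'{i}.jpg', 'gt_label': i % 3} for i in range(5)]

chunks = [np.frombuffer(pickle.dumps(d, protocol=4), dtype=np.uint8)
          for d in data_infos]
addresses = np.cumsum([len(c) for c in chunks])  # end offset of each record
buffer = np.concatenate(chunks)                  # shared serialized storage

def get_record(idx: int) -> dict:
    start = 0 if idx == 0 else addresses[idx - 1].item()
    end = addresses[idx].item()
    return pickle.loads(memoryview(buffer[start:end]))

assert get_record(3) == data_infos[3]
```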

View File

@ -0,0 +1,28 @@
{
"metainfo": {
"categories": [
{
"category_name": "first",
"id": 0
},
{
"category_name": "second",
"id": 1
}
]
},
"data_list": [
{
"img_path": "a/1.JPG",
"gt_label": 0
},
{
"img_path": "b/2.jpeg",
"gt_label": 1
},
{
"img_path": "b/subb/2.jpeg",
"gt_label": 1
}
]
}
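
This `ann.json` asset follows the OpenMMLab 2.0 annotation format that the refactored `BaseDataset` consumes; during `full_init` its `categories` block is turned into the `CLASSES` meta info. A small sketch of that conversion, mirroring the logic added in `BaseDataset.full_init`:

```python
# Convert the standard `categories` meta info into the internal CLASSES
# tuple, sorted by category id (mirrors BaseDataset.full_init above).
metainfo = {
    'categories': [
        {'category_name': 'first', 'id': 0},
        {'category_name': 'second', 'id': 1},
    ]
}

categories = sorted(metainfo['categories'], key=lambda x: x['id'])
classes = tuple(cat['category_name'] for cat in categories)
print(classes)  # ('first', 'second')
```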

View File

@ -1,3 +1,3 @@
a/1.JPG 0
b/2.jpeg 1
b/subb/2.jpeg 1
b/subb/3.jpg 1
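
This `ann.txt` asset uses the plain `<path> <label>` format parsed by `CustomDataset.load_data_list`. A simplified parsing sketch (the real method also joins paths through a `FileClient`):

```python
# Simplified parse of "<relative path> <label>" annotation lines; the
# image prefix is a hypothetical example.
import os.path as osp

lines = ['a/1.JPG 0', 'b/2.jpeg 1', 'b/subb/3.jpg 1']
img_prefix = 'data/dataset'

data_list = []
for line in lines:
    filename, gt_label = line.strip().rsplit(' ', 1)
    data_list.append({
        'img_path': osp.join(img_prefix, filename),
        'gt_label': int(gt_label),
    })

print(data_list[0])  # {'img_path': 'data/dataset/a/1.JPG', 'gt_label': 0}
```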

View File

@ -1,256 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
# import os
import os.path as osp
import pickle
import tempfile
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import MagicMock
import numpy as np
import torch
# import pickle
# import tempfile
from mmengine.registry import TRANSFORMS
from mmcls.datasets import DATASETS
from mmcls.datasets import BaseDataset as _BaseDataset
from mmcls.datasets import MultiLabelDataset as _MultiLabelDataset
from mmcls.utils import get_root_logger
# import torch
mmcls_logger = get_root_logger()
ASSETS_ROOT = osp.abspath(
osp.join(osp.dirname(__file__), '../../data/dataset'))
class BaseDataset(_BaseDataset):
def load_annotations(self):
pass
class MultiLabelDataset(_MultiLabelDataset):
def load_annotations(self):
pass
DATASETS.module_dict['BaseDataset'] = BaseDataset
DATASETS.module_dict['MultiLabelDataset'] = MultiLabelDataset
class TestBaseDataset(TestCase):
DATASET_TYPE = 'BaseDataset'
DEFAULT_ARGS = dict(data_prefix='', pipeline=[])
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.json')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
with patch.object(dataset_class, 'load_annotations'):
# Test default behavior
cfg = {**self.DEFAULT_ARGS, 'classes': None, 'ann_file': None}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
self.assertFalse(dataset.test_mode)
self.assertIsNone(dataset.ann_file)
# Test loading metainfo from ann_file
cfg = {**self.DEFAULT_ARGS, 'metainfo': None, 'classes': None}
dataset = dataset_class(**cfg)
self.assertEqual(
dataset.CLASSES,
dataset_class.METAINFO.get('CLASSES', ('first', 'second')))
self.assertFalse(dataset.test_mode)
# Test setting classes as a tuple
cfg = {**self.DEFAULT_ARGS, 'classes': ('bus', 'car')}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'CLASSES': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test setting classes as a tuple
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ['bus', 'car'])
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test setting classes through a file
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ['bus', 'car'])
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
fake_ann = [
dict(
img_prefix='',
img_info=dict(),
gt_label=np.array(0, dtype=np.int64))
]
with patch.object(dataset_class, 'load_annotations') as mock_load:
mock_load.return_value = fake_ann
dataset = dataset_class(**self.DEFAULT_ARGS)
dataset = dataset_class(**self.DEFAULT_ARGS)
cat_ids = dataset.get_cat_ids(0)
self.assertIsInstance(cat_ids, list)
self.assertEqual(len(cat_ids), 1)
self.assertIsInstance(cat_ids[0], int)
def test_evaluate(self):
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
fake_ann = [
dict(gt_label=np.array(0, dtype=np.int64)),
dict(gt_label=np.array(0, dtype=np.int64)),
dict(gt_label=np.array(1, dtype=np.int64)),
dict(gt_label=np.array(2, dtype=np.int64)),
dict(gt_label=np.array(1, dtype=np.int64)),
dict(gt_label=np.array(0, dtype=np.int64)),
]
head = 'Dataset ' + dataset.__class__.__name__
self.assertIn(head, repr(dataset))
with patch.object(dataset_class, 'load_annotations') as mock_load:
mock_load.return_value = fake_ann
dataset = dataset_class(**self.DEFAULT_ARGS)
if dataset.CLASSES is not None:
num_classes = len(dataset.CLASSES)
self.assertIn(f'Number of categories: \t{num_classes}',
repr(dataset))
else:
self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))
fake_results = np.array([
[0.7, 0.0, 0.3],
[0.5, 0.2, 0.3],
[0.4, 0.5, 0.1],
[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
])
self.assertIn("Haven't been initialized", repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
eval_results = dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
metric_options={'topk': 1})
self.assertIn(f'Annotation file: \t{dataset.ann_file}', repr(dataset))
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
repr(dataset))
# Test results
self.assertAlmostEqual(
eval_results['precision'], (1 + 1 + 1 / 3) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results['recall'], (2 / 3 + 1 / 2 + 1) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results['f1_score'], (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0,
places=4)
self.assertEqual(eval_results['support'], 6)
self.assertAlmostEqual(eval_results['accuracy'], 4 / 6 * 100, places=4)
# test indices
eval_results_ = dataset.evaluate(
fake_results[:5],
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
metric_options={'topk': 1},
indices=range(5))
self.assertAlmostEqual(
eval_results_['precision'], (1 + 1 + 1 / 2) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results_['recall'], (1 + 1 / 2 + 1) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results_['f1_score'], (1 + 2 / 3 + 2 / 3) / 3 * 100.0,
places=4)
self.assertEqual(eval_results_['support'], 5)
self.assertAlmostEqual(
eval_results_['accuracy'], 4 / 5 * 100, places=4)
# test input as tensor
fake_results_tensor = torch.from_numpy(fake_results)
eval_results_ = dataset.evaluate(
fake_results_tensor,
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
metric_options={'topk': 1})
assert eval_results_ == eval_results
# test thr
eval_results = dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'accuracy'],
metric_options={
'thrs': 0.6,
'topk': 1
})
self.assertAlmostEqual(
eval_results['precision'], (1 + 0 + 1 / 3) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results['recall'], (1 / 3 + 0 + 1) / 3 * 100.0, places=4)
self.assertAlmostEqual(
eval_results['f1_score'], (1 / 2 + 0 + 1 / 2) / 3 * 100.0,
places=4)
self.assertAlmostEqual(eval_results['accuracy'], 2 / 6 * 100, places=4)
# thrs must be a number or tuple
with self.assertRaises(TypeError):
dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'accuracy'],
metric_options={
'thrs': 'thr',
'topk': 1
})
# test topk and thr as tuple
eval_results = dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'accuracy'],
metric_options={
'thrs': (0.5, 0.6),
'topk': (1, 2)
})
self.assertEqual(
{
'precision_thr_0.50', 'precision_thr_0.60', 'recall_thr_0.50',
'recall_thr_0.60', 'f1_score_thr_0.50', 'f1_score_thr_0.60',
'accuracy_top-1_thr_0.50', 'accuracy_top-1_thr_0.60',
'accuracy_top-2_thr_0.50', 'accuracy_top-2_thr_0.60'
}, eval_results.keys())
self.assertIsInstance(eval_results['precision_thr_0.50'], float)
self.assertIsInstance(eval_results['recall_thr_0.50'], float)
self.assertIsInstance(eval_results['f1_score_thr_0.50'], float)
self.assertIsInstance(eval_results['accuracy_top-1_thr_0.50'], float)
# test topk is tuple while thrs is number
eval_results = dataset.evaluate(
fake_results,
metric='accuracy',
metric_options={
'thrs': 0.5,
'topk': (1, 2)
})
self.assertEqual({'accuracy_top-1', 'accuracy_top-2'},
eval_results.keys())
self.assertIsInstance(eval_results['accuracy_top-1'], float)
# test topk is number while thrs is tuple
eval_results = dataset.evaluate(
fake_results,
metric='accuracy',
metric_options={
'thrs': (0.5, 0.6),
'topk': 1
})
self.assertEqual({'accuracy_thr_0.50', 'accuracy_thr_0.60'},
eval_results.keys())
self.assertIsInstance(eval_results['accuracy_thr_0.50'], float)
# test evaluation results for classes
eval_results = dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'support'],
metric_options={'average_mode': 'none'})
self.assertEqual(eval_results['precision'].shape, (3, ))
self.assertEqual(eval_results['recall'].shape, (3, ))
self.assertEqual(eval_results['f1_score'].shape, (3, ))
self.assertEqual(eval_results['support'].shape, (3, ))
# the average_mode method must be valid
with self.assertRaises(ValueError):
dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'support'],
metric_options={'average_mode': 'micro'})
# the metric must be valid for the dataset
with self.assertRaisesRegex(ValueError,
"{'unknown'} is not supported"):
dataset.evaluate(fake_results, metric='unknown')
TRANSFORMS.register_module(name='test_mock', module=MagicMock)
cfg = {**self.DEFAULT_ARGS, 'pipeline': [dict(type='test_mock')]}
dataset = dataset_class(**cfg)
self.assertIn('With transforms', repr(dataset))
del TRANSFORMS.module_dict['test_mock']
"""Temporarily disabled.
class TestMultiLabelDataset(TestBaseDataset):
DATASET_TYPE = 'MultiLabelDataset'
@ -313,12 +158,39 @@ class TestMultiLabelDataset(TestBaseDataset):
self.assertAlmostEqual(eval_results['mAP'], 67.50, places=2)
self.assertAlmostEqual(eval_results['CR'], 43.75, places=2)
self.assertAlmostEqual(eval_results['OF1'], 42.86, places=2)
"""
class TestCustomDataset(TestBaseDataset):
DATASET_TYPE = 'CustomDataset'
def test_load_annotations(self):
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'CLASSES': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test load without ann_file
@ -329,23 +201,17 @@ class TestCustomDataset(TestBaseDataset):
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ['a', 'b']) # auto infer classes
self.assertEqual(
dataset.data_infos[0], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'a/1.JPG'
},
'gt_label': np.array(0)
})
self.assertEqual(
dataset.data_infos[2], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'b/subb/3.jpg'
},
'gt_label': np.array(1)
})
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
'gt_label': 0
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
'gt_label': 1
}.items())
# test ann_file assertion
cfg = {
@ -353,39 +219,56 @@ class TestCustomDataset(TestBaseDataset):
'data_prefix': ASSETS_ROOT,
'ann_file': ['ann_file.txt'],
}
with self.assertRaisesRegex(TypeError, 'must be a str'):
with self.assertRaisesRegex(TypeError, 'expected str'):
dataset_class(**cfg)
# test load with ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'data_root': ASSETS_ROOT,
'ann_file': 'ann.txt',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
self.assertIsNone(dataset.CLASSES, None)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
'gt_label': 1
}.items())
np.testing.assert_equal(dataset.get_gt_labels(), np.array([0, 1, 1]))
# test load with absolute ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_root': None,
'data_prefix': None,
'ann_file': osp.join(ASSETS_ROOT, 'ann.txt'),
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
self.assertEqual(
dataset.data_infos[0], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'a/1.JPG'
},
'gt_label': np.array(0)
})
self.assertEqual(
dataset.data_infos[2], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'b/subb/2.jpeg'
},
'gt_label': np.array(1)
})
self.assertIsNone(dataset.CLASSES, None)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': 'a/1.JPG',
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': 'b/subb/3.jpg',
'gt_label': 1
}.items())
# test extensions filter
cfg = {
**self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT,
**self.DEFAULT_ARGS, 'data_prefix': dict(img_path=ASSETS_ROOT),
'ann_file': None,
'extensions': ('.txt', )
}
@ -398,28 +281,25 @@ class TestCustomDataset(TestBaseDataset):
'ann_file': None,
'extensions': ('.jpeg', )
}
with self.assertWarnsRegex(UserWarning,
'Supported extensions are: .jpeg'):
with self.assertLogs(mmcls_logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertIn('Supported extensions are: .jpeg', log.output[0])
self.assertEqual(len(dataset), 1)
self.assertEqual(
dataset.data_infos[0], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'b/2.jpeg'
},
'gt_label': np.array(1)
})
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b/2.jpeg'),
'gt_label': 1
}.items())
# test classes check
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'classes': ['apple', 'banana'],
'classes': ('apple', 'banana'),
'ann_file': None,
}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ['apple', 'banana'])
self.assertEqual(dataset.CLASSES, ('apple', 'banana'))
cfg['classes'] = ['apple', 'banana', 'dog']
with self.assertRaisesRegex(AssertionError,
@ -427,10 +307,12 @@ class TestCustomDataset(TestBaseDataset):
dataset_class(**cfg)
class TestImageNet(TestBaseDataset):
class TestImageNet(TestCustomDataset):
DATASET_TYPE = 'ImageNet'
def test_load_annotations(self):
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test classes number
@ -452,20 +334,16 @@ class TestImageNet(TestBaseDataset):
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ['cat', 'dog'])
self.assertEqual(dataset.CLASSES, ('cat', 'dog'))
class TestImageNet21k(TestBaseDataset):
class TestImageNet21k(TestCustomDataset):
DATASET_TYPE = 'ImageNet21k'
DEFAULT_ARGS = dict(
data_prefix=ASSETS_ROOT,
pipeline=[],
classes=['cat', 'dog'],
ann_file=osp.join(ASSETS_ROOT, 'ann.txt'),
serialize_data=False)
data_root=ASSETS_ROOT, classes=['cat', 'dog'], ann_file='ann.txt')
def test_initialize(self):
def test_load_data_list(self):
super().test_initialize()
dataset_class = DATASETS.get(self.DATASET_TYPE)
@ -476,61 +354,18 @@ class TestImageNet21k(TestBaseDataset):
# Warn about ann_file
cfg = {**self.DEFAULT_ARGS, 'ann_file': None}
with self.assertWarnsRegex(UserWarning, 'specify the `ann_file`'):
with self.assertLogs(mmcls_logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `ann_file`', log.output[0])
# Warn about classes
cfg = {**self.DEFAULT_ARGS, 'classes': None}
with self.assertWarnsRegex(UserWarning, 'specify the `classes`'):
with self.assertLogs(mmcls_logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `classes`', log.output[0])
def test_load_annotations(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with serialize_data=False
cfg = {**self.DEFAULT_ARGS, 'serialize_data': False}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset.data_infos), 3)
self.assertEqual(len(dataset), 3)
self.assertEqual(
dataset[0], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'a/1.JPG'
},
'gt_label': np.array(0)
})
self.assertEqual(
dataset[2], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'b/subb/2.jpeg'
},
'gt_label': np.array(1)
})
# Test with serialize_data=True
cfg = {**self.DEFAULT_ARGS, 'serialize_data': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset.data_infos), 0) # data_infos is clear.
self.assertEqual(len(dataset), 3)
self.assertEqual(
dataset[0], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'a/1.JPG'
},
'gt_label': np.array(0)
})
self.assertEqual(
dataset[2], {
'img_prefix': ASSETS_ROOT,
'img_info': {
'filename': 'b/subb/2.jpeg'
},
'gt_label': np.array(1)
})
"""Temporarily disabled.
class TestMNIST(TestBaseDataset):
DATASET_TYPE = 'MNIST'
@ -594,7 +429,6 @@ class TestMNIST(TestBaseDataset):
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestCIFAR10(TestBaseDataset):
DATASET_TYPE = 'CIFAR10'
@ -668,17 +502,14 @@ class TestCIFAR10(TestBaseDataset):
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestCIFAR100(TestCIFAR10):
DATASET_TYPE = 'CIFAR100'
class TestVOC(TestMultiLabelDataset):
DATASET_TYPE = 'VOC'
DEFAULT_ARGS = dict(data_prefix='VOC2007', pipeline=[])
class TestCUB(TestBaseDataset):
DATASET_TYPE = 'CUB'
@ -761,3 +592,4 @@ class TestCUB(TestBaseDataset):
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
"""