[Feature] Add `BaseDataset`, `CustomDataset`, `ImageNet` and `ImageNet21k`
parent 98377df512
commit 27e685fe10
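For orientation before the rendered diff: a minimal sketch of how the refactored ``CustomDataset`` might be configured after this change. The field names follow the new ``__init__`` signatures in the diff below, while the paths, class names and pipeline transforms are placeholders, not part of this PR.

```python
# Hypothetical config snippet; paths and transform names are illustrative only.
train_dataset = dict(
    type='CustomDataset',
    data_root='data/my_dataset',
    ann_file='meta/train.txt',   # lines of "relative/path.jpg label"; omit to scan folders
    data_prefix='train',         # becomes data_prefix=dict(img_path=...) internally
    classes=['cat', 'dog'],      # old-style argument, merged into metainfo['CLASSES']
    pipeline=[dict(type='LoadImageFromFile')],
)
```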
|
@@ -7,8 +7,7 @@ from .cub import CUB
|
|||
from .custom import CustomDataset
|
||||
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
|
||||
KFoldDataset, RepeatDataset)
|
||||
from .imagenet import ImageNet
|
||||
from .imagenet21k import ImageNet21k
|
||||
from .imagenet import ImageNet, ImageNet21k
|
||||
from .mnist import MNIST, FashionMNIST
|
||||
from .multi_label import MultiLabelDataset
|
||||
from .samplers import DistributedSampler, RepeatAugSampler
|
||||
|
|
|
@@ -1,58 +1,120 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from os import PathLike
|
||||
from typing import List
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from mmengine.dataset import BaseDataset as _BaseDataset
|
||||
|
||||
from mmcls.core.evaluation import precision_recall_f1, support
|
||||
from mmcls.models.losses import accuracy
|
||||
from .pipelines import Compose
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
"""Expand ~ and ~user constructions.
|
||||
|
||||
If user or $HOME is unknown, do nothing.
|
||||
"""
|
||||
if isinstance(path, (str, PathLike)):
|
||||
return osp.expanduser(path)
|
||||
else:
|
||||
return path
|
||||
|
||||
|
||||
class BaseDataset(Dataset, metaclass=ABCMeta):
|
||||
"""Base dataset.
|
||||
@DATASETS.register_module()
|
||||
class BaseDataset(_BaseDataset):
|
||||
"""Base dataset for image classification task.
|
||||
|
||||
This dataset supports annotation files in the `OpenMMLab 2.0 style annotation
format`_.
|
||||
|
||||
.. _OpenMMLab 2.0 style annotation format:
|
||||
https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
|
||||
|
||||
Compared with :class:`mmengine.BaseDataset`, this class implements several
useful methods.
|
||||
|
||||
Args:
|
||||
data_prefix (str): the prefix of data path
|
||||
pipeline (list): a list of dict, where each element represents
|
||||
a operation defined in `mmcls.datasets.pipelines`
|
||||
ann_file (str | None): the annotation file. When ann_file is str,
|
||||
the subclass is expected to read from the ann_file. When ann_file
|
||||
is None, the subclass is expected to read according to data_prefix
|
||||
test_mode (bool): in train mode or test mode
|
||||
"""
|
||||
ann_file (str): Annotation file path.
|
||||
metainfo (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (str | dict, optional): Prefix for training data. Defaults
|
||||
to None.
|
||||
filter_cfg (dict, optional): Config for filtering data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using only the first
few data samples in the annotation file to facilitate training/testing
on a smaller dataset. Defaults to None, which means using all
``data_infos``.
|
||||
serialize_data (bool, optional): Whether to hold memory using
serialized objects; when enabled, data loader workers can use
shared RAM from the master process instead of making a copy. Defaults
to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
lazy_init (bool, optional): Whether to load annotations during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed and it is not necessary to
load the annotation file. ``BaseDataset`` can skip loading
annotations to save time by setting ``lazy_init=True``. Defaults to False.
|
||||
max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
None image, the maximum number of extra cycles to try to get a valid
image. Defaults to 1000.
|
||||
classes (str | Sequence[str], optional): Specify names of classes.
|
||||
|
||||
CLASSES = None
|
||||
- If it is a string, it should be a file path, and every line of
  the file is the name of a class.
- If it is a sequence of strings, every item is the name of a class.
- If it is None, use the categories information in the ``metainfo``
  argument, the annotation file, or the class attribute ``METAINFO``.
|
||||
|
||||
Defaults to None.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
data_prefix,
|
||||
pipeline,
|
||||
classes=None,
|
||||
ann_file=None,
|
||||
test_mode=False):
|
||||
super(BaseDataset, self).__init__()
|
||||
self.data_prefix = expanduser(data_prefix)
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.CLASSES = self.get_classes(classes)
|
||||
self.ann_file = expanduser(ann_file)
|
||||
self.test_mode = test_mode
|
||||
self.data_infos = self.load_annotations()
|
||||
ann_file: str,
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: Union[str, dict, None] = None,
|
||||
filter_cfg: Optional[dict] = None,
|
||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||
serialize_data: bool = True,
|
||||
pipeline: Sequence = (),
|
||||
test_mode: bool = False,
|
||||
lazy_init: bool = False,
|
||||
max_refetch: int = 1000,
|
||||
classes: Union[str, Sequence[str], None] = None):
|
||||
if isinstance(data_prefix, str):
|
||||
data_prefix = dict(img_path=expanduser(data_prefix))
|
||||
elif data_prefix is None:
|
||||
data_prefix = dict(img_path=None)
|
||||
|
||||
@abstractmethod
|
||||
def load_annotations(self):
|
||||
pass
|
||||
ann_file = expanduser(ann_file)
|
||||
metainfo = self._compat_classes(metainfo, classes)
|
||||
|
||||
super().__init__(
|
||||
ann_file=ann_file,
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
filter_cfg=filter_cfg,
|
||||
indices=indices,
|
||||
serialize_data=serialize_data,
|
||||
pipeline=pipeline,
|
||||
test_mode=test_mode,
|
||||
lazy_init=lazy_init,
|
||||
max_refetch=max_refetch)
|
||||
|
||||
@property
|
||||
def img_prefix(self):
|
||||
"""The prefix of images."""
|
||||
return self.data_prefix['img_path']
|
||||
|
||||
@property
|
||||
def CLASSES(self):
|
||||
"""Return all categories names."""
|
||||
return self._metainfo.get('CLASSES', None)
|
||||
|
||||
@property
|
||||
def class_to_idx(self):
|
||||
|
@@ -62,7 +124,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
dict: mapping from class name to class index.
|
||||
"""
|
||||
|
||||
return {_class: i for i, _class in enumerate(self.CLASSES)}
|
||||
return {cat: i for i, cat in enumerate(self.CLASSES)}
|
||||
|
||||
def get_gt_labels(self):
|
||||
"""Get all ground-truth labels (categories).
|
||||
|
@@ -71,7 +133,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
np.ndarray: categories for all images.
|
||||
"""
|
||||
|
||||
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
|
||||
gt_labels = np.array(
|
||||
[self.get_data_info(i)['gt_label'] for i in range(len(self))])
|
||||
return gt_labels
|
||||
|
||||
def get_cat_ids(self, idx: int) -> List[int]:
|
||||
|
@@ -84,138 +147,64 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
cat_ids (List[int]): Image category of specified index.
|
||||
"""
|
||||
|
||||
return [int(self.data_infos[idx]['gt_label'])]
|
||||
|
||||
def prepare_data(self, idx):
|
||||
results = copy.deepcopy(self.data_infos[idx])
|
||||
return self.pipeline(results)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_infos)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.prepare_data(idx)
|
||||
|
||||
@classmethod
|
||||
def get_classes(cls, classes=None):
|
||||
"""Get class names of current dataset.
|
||||
|
||||
Args:
|
||||
classes (Sequence[str] | str | None): If classes is None, use
|
||||
default CLASSES defined by builtin dataset. If classes is a
|
||||
string, take it as a file name. The file contains the name of
|
||||
classes where each line contains one class name. If classes is
|
||||
a tuple or list, override the CLASSES defined by the dataset.
|
||||
|
||||
Returns:
|
||||
tuple[str] or list[str]: Names of categories of the dataset.
|
||||
"""
|
||||
if classes is None:
|
||||
return cls.CLASSES
|
||||
return [int(self.get_data_info(idx)['gt_label'])]
|
||||
|
||||
def _compat_classes(self, metainfo, classes):
|
||||
"""Merge the old style ``classes`` arguments to ``metainfo``."""
|
||||
if isinstance(classes, str):
|
||||
# take it as a file path
|
||||
class_names = mmcv.list_from_file(expanduser(classes))
|
||||
class_names = mmengine.list_from_file(expanduser(classes))
|
||||
elif isinstance(classes, (tuple, list)):
|
||||
class_names = classes
|
||||
else:
|
||||
elif classes is not None:
|
||||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||
|
||||
return class_names
|
||||
if metainfo is None:
|
||||
metainfo = {}
|
||||
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='accuracy',
|
||||
metric_options=None,
|
||||
indices=None,
|
||||
logger=None):
|
||||
"""Evaluate the dataset.
|
||||
if classes is not None:
|
||||
metainfo = {'CLASSES': tuple(class_names), **metainfo}
|
||||
|
||||
return metainfo
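A standalone sketch of the merge rule implemented above (plain dicts, no mmcls imports needed):

```python
# classes provided, no metainfo: CLASSES is derived from classes.
classes = ['bus', 'car']
merged = {'CLASSES': tuple(classes), **{}}
assert merged == {'CLASSES': ('bus', 'car')}

# An explicit metainfo entry takes precedence, because metainfo is unpacked last.
merged = {'CLASSES': tuple(classes), **{'CLASSES': ('a', 'b')}}
assert merged == {'CLASSES': ('a', 'b')}
```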
|
||||
|
||||
def full_init(self):
|
||||
"""Load annotation file and set ``BaseDataset._fully_initialized`` to
|
||||
True."""
|
||||
super().full_init()
|
||||
|
||||
# To support the standard OpenMMLab 2.0 annotation format. Generate
|
||||
# metainfo in internal format from standard metainfo format.
|
||||
if 'categories' in self._metainfo and 'CLASSES' not in self._metainfo:
|
||||
categories = sorted(
|
||||
self._metainfo['categories'], key=lambda x: x['id'])
|
||||
self._metainfo['CLASSES'] = tuple(
|
||||
[cat['category_name'] for cat in categories])
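The conversion above reduces to sorting the standard ``categories`` list by ``id`` and keeping the names; a self-contained sketch using the same data as the ``ann.json`` asset added later in this diff:

```python
categories = [
    {'category_name': 'second', 'id': 1},
    {'category_name': 'first', 'id': 0},
]
classes = tuple(c['category_name'] for c in sorted(categories, key=lambda c: c['id']))
assert classes == ('first', 'second')
```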
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the basic information of the dataset.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
metric (str | list[str]): Metrics to be evaluated.
|
||||
Default value is `accuracy`.
|
||||
metric_options (dict, optional): Options for calculating metrics.
|
||||
Allowed keys are 'topk', 'thrs' and 'average_mode'.
|
||||
Defaults to None.
|
||||
indices (list, optional): The indices of samples corresponding to
|
||||
the results. Defaults to None.
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
Returns:
|
||||
dict: evaluation results
|
||||
str: Formatted string.
|
||||
"""
|
||||
if metric_options is None:
|
||||
metric_options = {'topk': (1, 5)}
|
||||
if isinstance(metric, str):
|
||||
metrics = [metric]
|
||||
head = 'Dataset ' + self.__class__.__name__
|
||||
body = []
|
||||
if self._fully_initialized:
|
||||
body.append(f'Number of samples: \t{self.__len__()}')
|
||||
else:
|
||||
metrics = metric
|
||||
allowed_metrics = [
|
||||
'accuracy', 'precision', 'recall', 'f1_score', 'support'
|
||||
]
|
||||
eval_results = {}
|
||||
results = np.vstack(results)
|
||||
gt_labels = self.get_gt_labels()
|
||||
if indices is not None:
|
||||
gt_labels = gt_labels[indices]
|
||||
num_imgs = len(results)
|
||||
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
|
||||
'be of the same length as gt_labels.'
|
||||
body.append("Haven't been initialized")
|
||||
|
||||
invalid_metrics = set(metrics) - set(allowed_metrics)
|
||||
if len(invalid_metrics) != 0:
|
||||
raise ValueError(f'metric {invalid_metrics} is not supported.')
|
||||
if self.CLASSES is not None:
|
||||
body.append(f'Number of categories: \t{len(self.CLASSES)}')
|
||||
else:
|
||||
body.append('The `CLASSES` meta info is not set.')
|
||||
|
||||
topk = metric_options.get('topk', (1, 5))
|
||||
thrs = metric_options.get('thrs')
|
||||
average_mode = metric_options.get('average_mode', 'macro')
|
||||
body.append(f'Annotation file: \t{self.ann_file}')
|
||||
body.append(f'Prefix of images: \t{self.img_prefix}')
|
||||
|
||||
if 'accuracy' in metrics:
|
||||
if thrs is not None:
|
||||
acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
|
||||
else:
|
||||
acc = accuracy(results, gt_labels, topk=topk)
|
||||
if isinstance(topk, tuple):
|
||||
eval_results_ = {
|
||||
f'accuracy_top-{k}': a
|
||||
for k, a in zip(topk, acc)
|
||||
}
|
||||
else:
|
||||
eval_results_ = {'accuracy': acc}
|
||||
if isinstance(thrs, tuple):
|
||||
for key, values in eval_results_.items():
|
||||
eval_results.update({
|
||||
f'{key}_thr_{thr:.2f}': value.item()
|
||||
for thr, value in zip(thrs, values)
|
||||
})
|
||||
else:
|
||||
eval_results.update(
|
||||
{k: v.item()
|
||||
for k, v in eval_results_.items()})
|
||||
if len(self.pipeline.transforms) > 0:
|
||||
body.append('With transforms:')
|
||||
for t in self.pipeline.transforms:
|
||||
body.append(f' {t}')
|
||||
|
||||
if 'support' in metrics:
|
||||
support_value = support(
|
||||
results, gt_labels, average_mode=average_mode)
|
||||
eval_results['support'] = support_value
|
||||
|
||||
precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
|
||||
if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
|
||||
if thrs is not None:
|
||||
precision_recall_f1_values = precision_recall_f1(
|
||||
results, gt_labels, average_mode=average_mode, thrs=thrs)
|
||||
else:
|
||||
precision_recall_f1_values = precision_recall_f1(
|
||||
results, gt_labels, average_mode=average_mode)
|
||||
for key, values in zip(precision_recall_f1_keys,
|
||||
precision_recall_f1_values):
|
||||
if key in metrics:
|
||||
if isinstance(thrs, tuple):
|
||||
eval_results.update({
|
||||
f'{key}_thr_{thr:.2f}': value
|
||||
for thr, value in zip(thrs, values)
|
||||
})
|
||||
else:
|
||||
eval_results[key] = values
|
||||
|
||||
return eval_results
|
||||
lines = [head] + [' ' * 4 + line for line in body]
|
||||
return '\n'.join(lines)
|
||||
|
|
File diff suppressed because it is too large
@@ -1,12 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv import FileClient
|
||||
import mmengine
|
||||
from mmengine import FileClient
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from mmcls.utils import get_root_logger
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
|
@@ -104,7 +103,8 @@ class CustomDataset(BaseDataset):
|
|||
folder_2/nsdf3.png 3
|
||||
...
|
||||
|
||||
Please specify the name of categories by the argument ``classes``.
|
||||
Please specify the names of categories by the argument ``classes``
or ``metainfo``.
|
||||
|
||||
2. The samples are arranged in the specific way: ::
|
||||
|
||||
|
@@ -124,58 +124,61 @@ class CustomDataset(BaseDataset):
|
|||
first way, otherwise, try the second way.
|
||||
|
||||
Args:
|
||||
data_prefix (str): The path of data directory.
|
||||
pipeline (Sequence[dict]): A list of dict, where each element
|
||||
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
|
||||
Defaults to an empty tuple.
|
||||
classes (str | Sequence[str], optional): Specify names of classes.
|
||||
|
||||
- If it is a string, it should be a file path, and every line of
  the file is the name of a class.
- If it is a sequence of strings, every item is the name of a class.
- If it is None, use ``cls.CLASSES`` or the names of sub folders
  (if using the second way to arrange samples).
|
||||
|
||||
Defaults to None.
|
||||
ann_file (str, optional): The annotation file. If is string, read
|
||||
samples paths from the ann_file. If is None, find samples in
|
||||
``data_prefix``. Defaults to None.
|
||||
ann_file (str, optional): Annotation file path. Defaults to None.
|
||||
metainfo (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (str | dict, optional): Prefix for training data. Defaults
|
||||
to None.
|
||||
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
|
||||
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
|
||||
test_mode (bool): In train mode or test mode. It's only a mark and
|
||||
won't be used in this class. Defaults to False.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
If None, automatically inference from the specified path.
|
||||
Defaults to None.
|
||||
lazy_init (bool, optional): Whether to load annotations during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed and it is not necessary to
load the annotation file. ``BaseDataset`` can skip loading
annotations to save time by setting ``lazy_init=True``. Defaults to False.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_prefix: str,
|
||||
pipeline: Sequence = (),
|
||||
classes: Union[str, Sequence[str], None] = None,
|
||||
ann_file: Optional[str] = None,
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: Union[str, dict, None] = None,
|
||||
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
|
||||
'.bmp', '.pgm', '.tif'),
|
||||
test_mode: bool = False,
|
||||
file_client_args: Optional[dict] = None):
|
||||
lazy_init: bool = False,
|
||||
**kwargs):
|
||||
assert (ann_file is not None or data_prefix is not None
|
||||
or data_root is not None), \
|
||||
'One of `ann_file`, `data_root` and `data_prefix` must '\
|
||||
'be specified.'
|
||||
|
||||
self.extensions = tuple(set([i.lower() for i in extensions]))
|
||||
self.file_client_args = file_client_args
|
||||
|
||||
super().__init__(
|
||||
# The base class requires string ann_file but this class doesn't
|
||||
ann_file=ann_file if ann_file is not None else '',
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
pipeline=pipeline,
|
||||
classes=classes,
|
||||
ann_file=ann_file,
|
||||
test_mode=test_mode)
|
||||
# Force to lazy_init for some modification before loading data.
|
||||
lazy_init=True,
|
||||
**kwargs)
|
||||
|
||||
def _find_samples(self):
|
||||
if ann_file is None:
|
||||
self.ann_file = None
|
||||
|
||||
# Full initialize the dataset.
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
def _find_samples(self, file_client):
|
||||
"""find samples from ``data_prefix``."""
|
||||
file_client = FileClient.infer_client(self.file_client_args,
|
||||
self.data_prefix)
|
||||
classes, folder_to_idx = find_folders(self.data_prefix, file_client)
|
||||
classes, folder_to_idx = find_folders(self.img_prefix, file_client)
|
||||
samples, empty_classes = get_samples(
|
||||
self.data_prefix,
|
||||
self.img_prefix,
|
||||
folder_to_idx,
|
||||
is_valid_file=self.is_valid_file,
|
||||
file_client=file_client,
|
||||
|
@@ -192,37 +195,42 @@ class CustomDataset(BaseDataset):
|
|||
f'the number of specified classes ({len(self.CLASSES)}). ' \
|
||||
'Please check the data folder.'
|
||||
else:
|
||||
self.CLASSES = classes
|
||||
self._metainfo['CLASSES'] = tuple(classes)
|
||||
|
||||
if empty_classes:
|
||||
warnings.warn(
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
'Found no valid file in the folder '
|
||||
f'{", ".join(empty_classes)}. '
|
||||
f"Supported extensions are: {', '.join(self.extensions)}",
|
||||
UserWarning)
|
||||
f"Supported extensions are: {', '.join(self.extensions)}")
|
||||
|
||||
self.folder_to_idx = folder_to_idx
|
||||
|
||||
return samples
|
||||
|
||||
def load_annotations(self):
|
||||
def load_data_list(self):
|
||||
"""Load image paths and gt_labels."""
|
||||
if self.ann_file is None:
|
||||
samples = self._find_samples()
|
||||
elif isinstance(self.ann_file, str):
|
||||
lines = mmcv.list_from_file(
|
||||
self.ann_file, file_client_args=self.file_client_args)
|
||||
samples = [x.strip().rsplit(' ', 1) for x in lines]
|
||||
else:
|
||||
raise TypeError('ann_file must be a str or None')
|
||||
if self.img_prefix is not None:
|
||||
file_client = FileClient.infer_client(uri=self.img_prefix)
|
||||
|
||||
data_infos = []
|
||||
if self.ann_file is None:
|
||||
samples = self._find_samples(file_client)
|
||||
else:
|
||||
lines = mmengine.list_from_file(self.ann_file)
|
||||
samples = [x.strip().rsplit(' ', 1) for x in lines]
|
||||
|
||||
def add_prefix(filename, prefix=None):
|
||||
if prefix is None:
|
||||
return filename
|
||||
else:
|
||||
return file_client.join_path(prefix, filename)
|
||||
|
||||
data_list = []
|
||||
for filename, gt_label in samples:
|
||||
info = {'img_prefix': self.data_prefix}
|
||||
info['img_info'] = {'filename': filename}
|
||||
info['gt_label'] = np.array(gt_label, dtype=np.int64)
|
||||
data_infos.append(info)
|
||||
return data_infos
|
||||
img_path = add_prefix(filename, self.img_prefix)
|
||||
info = {'img_path': img_path, 'gt_label': int(gt_label)}
|
||||
data_list.append(info)
|
||||
return data_list
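For an annotation file in the plain ``filename label`` format (like the ``ann.txt`` asset further down), the loop above produces entries of the following shape; a sketch with a placeholder prefix:

```python
lines = ['a/1.JPG 0', 'b/2.jpeg 1']
samples = [x.strip().rsplit(' ', 1) for x in lines]
data_list = [
    {'img_path': 'data/my_dataset/' + filename, 'gt_label': int(gt_label)}
    for filename, gt_label in samples
]
# [{'img_path': 'data/my_dataset/a/1.JPG', 'gt_label': 0},
#  {'img_path': 'data/my_dataset/b/2.jpeg', 'gt_label': 1}]
```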
|
||||
|
||||
def is_valid_file(self, filename: str) -> bool:
|
||||
"""Check if a file is a valid sample."""
|
||||
|
|
File diff suppressed because it is too large
@@ -1,174 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import gc
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ImageNet21k(CustomDataset):
|
||||
"""ImageNet21k Dataset.
|
||||
|
||||
Since the dataset ImageNet21k is extremely big, cantains 21k+ classes
|
||||
and 1.4B files. This class has improved the following points on the
|
||||
basis of the class ``ImageNet``, in order to save memory, we enable the
|
||||
``serialize_data`` optional by default. With this option, the annotation
|
||||
won't be stored in the list ``data_infos``, but be serialized as an
|
||||
array.
|
||||
|
||||
Args:
|
||||
data_prefix (str): The path of data directory.
|
||||
pipeline (Sequence[dict]): A list of dict, where each element
|
||||
represents a operation defined in :mod:`mmcls.datasets.pipelines`.
|
||||
Defaults to an empty tuple.
|
||||
classes (str | Sequence[str], optional): Specify names of classes.
|
||||
|
||||
- If is string, it should be a file path, and the every line of
|
||||
the file is a name of a class.
|
||||
- If is a sequence of string, every item is a name of class.
|
||||
- If is None, the object won't have category information.
|
||||
(Not recommended)
|
||||
|
||||
Defaults to None.
|
||||
ann_file (str, optional): The annotation file. If is string, read
|
||||
samples paths from the ann_file. If is None, find samples in
|
||||
``data_prefix``. Defaults to None.
|
||||
serialize_data (bool): Whether to hold memory using serialized objects,
|
||||
when enabled, data loader workers can use shared RAM from master
|
||||
process instead of making a copy. Defaults to True.
|
||||
multi_label (bool): Not implement by now. Use multi label or not.
|
||||
Defaults to False.
|
||||
recursion_subdir(bool): Deprecated, and the dataset will recursively
|
||||
get all images now.
|
||||
test_mode (bool): In train mode or test mode. It's only a mark and
|
||||
won't be used in this class. Defaults to False.
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
If None, automatically inference from the specified path.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
|
||||
CLASSES = None
|
||||
|
||||
def __init__(self,
|
||||
data_prefix: str,
|
||||
pipeline: Sequence = (),
|
||||
classes: Union[str, Sequence[str], None] = None,
|
||||
ann_file: Optional[str] = None,
|
||||
serialize_data: bool = True,
|
||||
multi_label: bool = False,
|
||||
recursion_subdir: bool = True,
|
||||
test_mode=False,
|
||||
file_client_args: Optional[dict] = None):
|
||||
assert recursion_subdir, 'The `recursion_subdir` option is ' \
|
||||
'deprecated. Now the dataset will recursively get all images.'
|
||||
if multi_label:
|
||||
raise NotImplementedError(
|
||||
'The `multi_label` option is not supported by now.')
|
||||
self.multi_label = multi_label
|
||||
self.serialize_data = serialize_data
|
||||
|
||||
if ann_file is None:
|
||||
warnings.warn(
|
||||
'The ImageNet21k dataset is large, and scanning directory may '
|
||||
'consume long time. Considering to specify the `ann_file` to '
|
||||
'accelerate the initialization.', UserWarning)
|
||||
|
||||
if classes is None:
|
||||
warnings.warn(
|
||||
'The CLASSES is not stored in the `ImageNet21k` class. '
|
||||
'Considering to specify the `classes` argument if you need '
|
||||
'do inference on the ImageNet-21k dataset', UserWarning)
|
||||
|
||||
super().__init__(
|
||||
data_prefix=data_prefix,
|
||||
pipeline=pipeline,
|
||||
classes=classes,
|
||||
ann_file=ann_file,
|
||||
extensions=self.IMG_EXTENSIONS,
|
||||
test_mode=test_mode,
|
||||
file_client_args=file_client_args)
|
||||
|
||||
if self.serialize_data:
|
||||
self.data_infos_bytes, self.data_address = self._serialize_data()
|
||||
# Empty cache for preventing making multiple copies of
|
||||
# `self.data_infos` when loading data multi-processes.
|
||||
self.data_infos.clear()
|
||||
gc.collect()
|
||||
|
||||
def get_cat_ids(self, idx: int) -> List[int]:
|
||||
"""Get category id by index.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
cat_ids (List[int]): Image category of specified index.
|
||||
"""
|
||||
|
||||
return [int(self.get_data_info(idx)['gt_label'])]
|
||||
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
idx (int): The index of data.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the dataset.
|
||||
"""
|
||||
if self.serialize_data:
|
||||
start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
|
||||
end_addr = self.data_address[idx].item()
|
||||
bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
|
||||
data_info = pickle.loads(bytes)
|
||||
else:
|
||||
data_info = self.data_infos[idx]
|
||||
|
||||
return data_info
|
||||
|
||||
def prepare_data(self, idx):
|
||||
data_info = self.get_data_info(idx)
|
||||
return self.pipeline(data_info)
|
||||
|
||||
def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Serialize ``self.data_infos`` to save memory when launching multiple
|
||||
workers in data loading. This function will be called in ``full_init``.
|
||||
|
||||
Hold memory using serialized objects, and data loader workers can use
|
||||
shared RAM from master process instead of making a copy.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: serialize result and corresponding
|
||||
address.
|
||||
"""
|
||||
|
||||
def _serialize(data):
|
||||
buffer = pickle.dumps(data, protocol=4)
|
||||
return np.frombuffer(buffer, dtype=np.uint8)
|
||||
|
||||
serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
|
||||
address_list = np.asarray([len(x) for x in serialized_data_infos_list],
|
||||
dtype=np.int64)
|
||||
data_address: np.ndarray = np.cumsum(address_list)
|
||||
serialized_data_infos = np.concatenate(serialized_data_infos_list)
|
||||
|
||||
return serialized_data_infos, data_address
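A self-contained sketch of the serialization trick above: records are pickled into one flat ``uint8`` buffer and recovered by slicing with the cumulative addresses. This is the same mechanism ``serialize_data=True`` relies on in the ``mmengine`` base class that replaces this file:

```python
import pickle

import numpy as np

records = [{'img_path': 'a/1.JPG', 'gt_label': 0}, {'img_path': 'b/2.jpeg', 'gt_label': 1}]
buffers = [np.frombuffer(pickle.dumps(r, protocol=4), dtype=np.uint8) for r in records]
data_address = np.cumsum([len(b) for b in buffers])
data_bytes = np.concatenate(buffers)

idx = 1
start = 0 if idx == 0 else data_address[idx - 1].item()
end = data_address[idx].item()
assert pickle.loads(memoryview(data_bytes[start:end])) == records[idx]
```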
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get the length of filtered dataset and automatically call
|
||||
``full_init`` if the dataset has not been fully init.
|
||||
|
||||
Returns:
|
||||
int: The length of filtered dataset.
|
||||
"""
|
||||
if self.serialize_data:
|
||||
return len(self.data_address)
|
||||
else:
|
||||
return len(self.data_infos)
|
|
@@ -0,0 +1,28 @@
|
|||
{
|
||||
"metainfo": {
|
||||
"categories": [
|
||||
{
|
||||
"category_name": "first",
|
||||
"id": 0
|
||||
},
|
||||
{
|
||||
"category_name": "second",
|
||||
"id": 1
|
||||
}
|
||||
]
|
||||
},
|
||||
"data_list": [
|
||||
{
|
||||
"img_path": "a/1.JPG",
|
||||
"gt_label": 0
|
||||
},
|
||||
{
|
||||
"img_path": "b/2.jpeg",
|
||||
"gt_label": 1
|
||||
},
|
||||
{
|
||||
"img_path": "b/subb/2.jpeg",
|
||||
"gt_label": 1
|
||||
}
|
||||
]
|
||||
}
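This is the OpenMMLab 2.0 style annotation file the new tests rely on. A hedged sketch of consuming it directly, mirroring the ``DEFAULT_ARGS`` used in ``TestBaseDataset`` below (the asset path is the one used by the tests; printed values are illustrative):

```python
from mmcls.datasets import BaseDataset  # registered by this PR

dataset = BaseDataset(data_root='tests/data/dataset', ann_file='ann.json')
print(dataset.CLASSES)           # ('first', 'second'), derived from "categories"
print(dataset.get_data_info(0))  # contains 'img_path' ending in 'a/1.JPG' and 'gt_label' 0
print(repr(dataset))             # "Dataset BaseDataset\n    Number of samples: ..." etc.
```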
|
|
@@ -1,3 +1,3 @@
|
|||
a/1.JPG 0
|
||||
b/2.jpeg 1
|
||||
b/subb/2.jpeg 1
|
||||
b/subb/3.jpg 1
|
||||
|
|
|
@@ -1,256 +1,101 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
# import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
# import pickle
|
||||
# import tempfile
|
||||
from mmengine.registry import TRANSFORMS
|
||||
|
||||
from mmcls.datasets import DATASETS
|
||||
from mmcls.datasets import BaseDataset as _BaseDataset
|
||||
from mmcls.datasets import MultiLabelDataset as _MultiLabelDataset
|
||||
from mmcls.utils import get_root_logger
|
||||
|
||||
# import torch
|
||||
|
||||
mmcls_logger = get_root_logger()
|
||||
ASSETS_ROOT = osp.abspath(
|
||||
osp.join(osp.dirname(__file__), '../../data/dataset'))
|
||||
|
||||
|
||||
class BaseDataset(_BaseDataset):
|
||||
|
||||
def load_annotations(self):
|
||||
pass
|
||||
|
||||
|
||||
class MultiLabelDataset(_MultiLabelDataset):
|
||||
|
||||
def load_annotations(self):
|
||||
pass
|
||||
|
||||
|
||||
DATASETS.module_dict['BaseDataset'] = BaseDataset
|
||||
DATASETS.module_dict['MultiLabelDataset'] = MultiLabelDataset
|
||||
|
||||
|
||||
class TestBaseDataset(TestCase):
|
||||
DATASET_TYPE = 'BaseDataset'
|
||||
|
||||
DEFAULT_ARGS = dict(data_prefix='', pipeline=[])
|
||||
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.json')
|
||||
|
||||
def test_initialize(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations'):
|
||||
# Test default behavior
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': None, 'ann_file': None}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertFalse(dataset.test_mode)
|
||||
self.assertIsNone(dataset.ann_file)
|
||||
# Test loading metainfo from ann_file
|
||||
cfg = {**self.DEFAULT_ARGS, 'metainfo': None, 'classes': None}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(
|
||||
dataset.CLASSES,
|
||||
dataset_class.METAINFO.get('CLASSES', ('first', 'second')))
|
||||
self.assertFalse(dataset.test_mode)
|
||||
|
||||
# Test setting classes as a tuple
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ('bus', 'car')}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
# Test overriding metainfo by `metainfo` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'CLASSES': ('bus', 'car')}}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
# Test setting classes as a tuple
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ['bus', 'car'])
|
||||
# Test overriding metainfo by `classes` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
# Test setting classes through a file
|
||||
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ['bus', 'car'])
|
||||
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
|
||||
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
|
||||
|
||||
# Test invalid classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
|
||||
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
|
||||
dataset_class(**cfg)
|
||||
# Test invalid classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
|
||||
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
|
||||
dataset_class(**cfg)
|
||||
|
||||
def test_get_cat_ids(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
fake_ann = [
|
||||
dict(
|
||||
img_prefix='',
|
||||
img_info=dict(),
|
||||
gt_label=np.array(0, dtype=np.int64))
|
||||
]
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations') as mock_load:
|
||||
mock_load.return_value = fake_ann
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
cat_ids = dataset.get_cat_ids(0)
|
||||
self.assertIsInstance(cat_ids, list)
|
||||
self.assertEqual(len(cat_ids), 1)
|
||||
self.assertIsInstance(cat_ids[0], int)
|
||||
|
||||
def test_evaluate(self):
|
||||
def test_repr(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
fake_ann = [
|
||||
dict(gt_label=np.array(0, dtype=np.int64)),
|
||||
dict(gt_label=np.array(0, dtype=np.int64)),
|
||||
dict(gt_label=np.array(1, dtype=np.int64)),
|
||||
dict(gt_label=np.array(2, dtype=np.int64)),
|
||||
dict(gt_label=np.array(1, dtype=np.int64)),
|
||||
dict(gt_label=np.array(0, dtype=np.int64)),
|
||||
]
|
||||
head = 'Dataset ' + dataset.__class__.__name__
|
||||
self.assertIn(head, repr(dataset))
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations') as mock_load:
|
||||
mock_load.return_value = fake_ann
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
if dataset.CLASSES is not None:
|
||||
num_classes = len(dataset.CLASSES)
|
||||
self.assertIn(f'Number of categories: \t{num_classes}',
|
||||
repr(dataset))
|
||||
else:
|
||||
self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))
|
||||
|
||||
fake_results = np.array([
|
||||
[0.7, 0.0, 0.3],
|
||||
[0.5, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.1],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
])
|
||||
self.assertIn("Haven't been initialized", repr(dataset))
|
||||
dataset.full_init()
|
||||
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
|
||||
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
|
||||
metric_options={'topk': 1})
|
||||
self.assertIn(f'Annotation file: \t{dataset.ann_file}', repr(dataset))
|
||||
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
|
||||
repr(dataset))
|
||||
|
||||
# Test results
|
||||
self.assertAlmostEqual(
|
||||
eval_results['precision'], (1 + 1 + 1 / 3) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results['recall'], (2 / 3 + 1 / 2 + 1) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results['f1_score'], (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0,
|
||||
places=4)
|
||||
self.assertEqual(eval_results['support'], 6)
|
||||
self.assertAlmostEqual(eval_results['accuracy'], 4 / 6 * 100, places=4)
|
||||
|
||||
# test indices
|
||||
eval_results_ = dataset.evaluate(
|
||||
fake_results[:5],
|
||||
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
|
||||
metric_options={'topk': 1},
|
||||
indices=range(5))
|
||||
self.assertAlmostEqual(
|
||||
eval_results_['precision'], (1 + 1 + 1 / 2) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results_['recall'], (1 + 1 / 2 + 1) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results_['f1_score'], (1 + 2 / 3 + 2 / 3) / 3 * 100.0,
|
||||
places=4)
|
||||
self.assertEqual(eval_results_['support'], 5)
|
||||
self.assertAlmostEqual(
|
||||
eval_results_['accuracy'], 4 / 5 * 100, places=4)
|
||||
|
||||
# test input as tensor
|
||||
fake_results_tensor = torch.from_numpy(fake_results)
|
||||
eval_results_ = dataset.evaluate(
|
||||
fake_results_tensor,
|
||||
metric=['precision', 'recall', 'f1_score', 'support', 'accuracy'],
|
||||
metric_options={'topk': 1})
|
||||
assert eval_results_ == eval_results
|
||||
|
||||
# test thr
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'accuracy'],
|
||||
metric_options={
|
||||
'thrs': 0.6,
|
||||
'topk': 1
|
||||
})
|
||||
|
||||
self.assertAlmostEqual(
|
||||
eval_results['precision'], (1 + 0 + 1 / 3) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results['recall'], (1 / 3 + 0 + 1) / 3 * 100.0, places=4)
|
||||
self.assertAlmostEqual(
|
||||
eval_results['f1_score'], (1 / 2 + 0 + 1 / 2) / 3 * 100.0,
|
||||
places=4)
|
||||
self.assertAlmostEqual(eval_results['accuracy'], 2 / 6 * 100, places=4)
|
||||
|
||||
# thrs must be a number or tuple
|
||||
with self.assertRaises(TypeError):
|
||||
dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'accuracy'],
|
||||
metric_options={
|
||||
'thrs': 'thr',
|
||||
'topk': 1
|
||||
})
|
||||
|
||||
# test topk and thr as tuple
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'accuracy'],
|
||||
metric_options={
|
||||
'thrs': (0.5, 0.6),
|
||||
'topk': (1, 2)
|
||||
})
|
||||
self.assertEqual(
|
||||
{
|
||||
'precision_thr_0.50', 'precision_thr_0.60', 'recall_thr_0.50',
|
||||
'recall_thr_0.60', 'f1_score_thr_0.50', 'f1_score_thr_0.60',
|
||||
'accuracy_top-1_thr_0.50', 'accuracy_top-1_thr_0.60',
|
||||
'accuracy_top-2_thr_0.50', 'accuracy_top-2_thr_0.60'
|
||||
}, eval_results.keys())
|
||||
|
||||
self.assertIsInstance(eval_results['precision_thr_0.50'], float)
|
||||
self.assertIsInstance(eval_results['recall_thr_0.50'], float)
|
||||
self.assertIsInstance(eval_results['f1_score_thr_0.50'], float)
|
||||
self.assertIsInstance(eval_results['accuracy_top-1_thr_0.50'], float)
|
||||
|
||||
# test topk is tuple while thrs is number
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric='accuracy',
|
||||
metric_options={
|
||||
'thrs': 0.5,
|
||||
'topk': (1, 2)
|
||||
})
|
||||
self.assertEqual({'accuracy_top-1', 'accuracy_top-2'},
|
||||
eval_results.keys())
|
||||
self.assertIsInstance(eval_results['accuracy_top-1'], float)
|
||||
|
||||
# test topk is number while thrs is tuple
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric='accuracy',
|
||||
metric_options={
|
||||
'thrs': (0.5, 0.6),
|
||||
'topk': 1
|
||||
})
|
||||
self.assertEqual({'accuracy_thr_0.50', 'accuracy_thr_0.60'},
|
||||
eval_results.keys())
|
||||
self.assertIsInstance(eval_results['accuracy_thr_0.50'], float)
|
||||
|
||||
# test evaluation results for classes
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'support'],
|
||||
metric_options={'average_mode': 'none'})
|
||||
self.assertEqual(eval_results['precision'].shape, (3, ))
|
||||
self.assertEqual(eval_results['recall'].shape, (3, ))
|
||||
self.assertEqual(eval_results['f1_score'].shape, (3, ))
|
||||
self.assertEqual(eval_results['support'].shape, (3, ))
|
||||
|
||||
# the average_mode method must be valid
|
||||
with self.assertRaises(ValueError):
|
||||
dataset.evaluate(
|
||||
fake_results,
|
||||
metric=['precision', 'recall', 'f1_score', 'support'],
|
||||
metric_options={'average_mode': 'micro'})
|
||||
|
||||
# the metric must be valid for the dataset
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"{'unknown'} is not supported"):
|
||||
dataset.evaluate(fake_results, metric='unknown')
|
||||
TRANSFORMS.register_module(name='test_mock', module=MagicMock)
|
||||
cfg = {**self.DEFAULT_ARGS, 'pipeline': [dict(type='test_mock')]}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertIn('With transforms', repr(dataset))
|
||||
del TRANSFORMS.module_dict['test_mock']
|
||||
|
||||
|
||||
"""Temporarily disabled.
|
||||
class TestMultiLabelDataset(TestBaseDataset):
|
||||
DATASET_TYPE = 'MultiLabelDataset'
|
||||
|
||||
|
@@ -313,12 +158,39 @@ class TestMultiLabelDataset(TestBaseDataset):
|
|||
self.assertAlmostEqual(eval_results['mAP'], 67.50, places=2)
|
||||
self.assertAlmostEqual(eval_results['CR'], 43.75, places=2)
|
||||
self.assertAlmostEqual(eval_results['OF1'], 42.86, places=2)
|
||||
"""
|
||||
|
||||
|
||||
class TestCustomDataset(TestBaseDataset):
|
||||
DATASET_TYPE = 'CustomDataset'
|
||||
|
||||
def test_load_annotations(self):
|
||||
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
|
||||
|
||||
def test_initialize(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test overriding metainfo by `metainfo` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'CLASSES': ('bus', 'car')}}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
# Test overriding metainfo by `classes` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
|
||||
|
||||
# Test invalid classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
|
||||
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
|
||||
dataset_class(**cfg)
|
||||
|
||||
def test_load_data_list(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# test load without ann_file
|
||||
|
@@ -329,23 +201,17 @@ class TestCustomDataset(TestBaseDataset):
|
|||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
self.assertEqual(dataset.CLASSES, ['a', 'b']) # auto infer classes
|
||||
self.assertEqual(
|
||||
dataset.data_infos[0], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'a/1.JPG'
|
||||
},
|
||||
'gt_label': np.array(0)
|
||||
})
|
||||
self.assertEqual(
|
||||
dataset.data_infos[2], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'b/subb/3.jpg'
|
||||
},
|
||||
'gt_label': np.array(1)
|
||||
})
|
||||
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(0).items(), {
|
||||
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
|
||||
'gt_label': 0
|
||||
}.items())
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(2).items(), {
|
||||
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
|
||||
'gt_label': 1
|
||||
}.items())
|
||||
|
||||
# test ann_file assertion
|
||||
cfg = {
|
||||
|
@@ -353,39 +219,56 @@ class TestCustomDataset(TestBaseDataset):
|
|||
'data_prefix': ASSETS_ROOT,
|
||||
'ann_file': ['ann_file.txt'],
|
||||
}
|
||||
with self.assertRaisesRegex(TypeError, 'must be a str'):
|
||||
with self.assertRaisesRegex(TypeError, 'expected str'):
|
||||
dataset_class(**cfg)
|
||||
|
||||
# test load with ann_file
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS,
|
||||
'data_prefix': ASSETS_ROOT,
|
||||
'data_root': ASSETS_ROOT,
|
||||
'ann_file': 'ann.txt',
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
# custom dataset won't infer CLASSES from ann_file
|
||||
self.assertIsNone(dataset.CLASSES, None)
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(0).items(), {
|
||||
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
|
||||
'gt_label': 0,
|
||||
}.items())
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(2).items(), {
|
||||
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
|
||||
'gt_label': 1
|
||||
}.items())
|
||||
np.testing.assert_equal(dataset.get_gt_labels(), np.array([0, 1, 1]))
|
||||
|
||||
# test load with absolute ann_file
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS,
|
||||
'data_root': None,
|
||||
'data_prefix': None,
|
||||
'ann_file': osp.join(ASSETS_ROOT, 'ann.txt'),
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
# custom dataset won't infer CLASSES from ann_file
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertEqual(
|
||||
dataset.data_infos[0], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'a/1.JPG'
|
||||
},
|
||||
'gt_label': np.array(0)
|
||||
})
|
||||
self.assertEqual(
|
||||
dataset.data_infos[2], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'b/subb/2.jpeg'
|
||||
},
|
||||
'gt_label': np.array(1)
|
||||
})
|
||||
self.assertIsNone(dataset.CLASSES, None)
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(0).items(), {
|
||||
'img_path': 'a/1.JPG',
|
||||
'gt_label': 0,
|
||||
}.items())
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(2).items(), {
|
||||
'img_path': 'b/subb/3.jpg',
|
||||
'gt_label': 1
|
||||
}.items())
|
||||
|
||||
# test extensions filter
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT,
|
||||
**self.DEFAULT_ARGS, 'data_prefix': dict(img_path=ASSETS_ROOT),
|
||||
'ann_file': None,
|
||||
'extensions': ('.txt', )
|
||||
}
|
||||
|
@@ -398,28 +281,25 @@ class TestCustomDataset(TestBaseDataset):
|
|||
'ann_file': None,
|
||||
'extensions': ('.jpeg', )
|
||||
}
|
||||
with self.assertWarnsRegex(UserWarning,
|
||||
'Supported extensions are: .jpeg'):
|
||||
with self.assertLogs(mmcls_logger, 'WARN') as log:
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertIn('Supported extensions are: .jpeg', log.output[0])
|
||||
self.assertEqual(len(dataset), 1)
|
||||
self.assertEqual(
|
||||
dataset.data_infos[0], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'b/2.jpeg'
|
||||
},
|
||||
'gt_label': np.array(1)
|
||||
})
|
||||
self.assertGreaterEqual(
|
||||
dataset.get_data_info(0).items(), {
|
||||
'img_path': osp.join(ASSETS_ROOT, 'b/2.jpeg'),
|
||||
'gt_label': 1
|
||||
}.items())
|
||||
|
||||
# test classes check
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS,
|
||||
'data_prefix': ASSETS_ROOT,
|
||||
'classes': ['apple', 'banana'],
|
||||
'classes': ('apple', 'banana'),
|
||||
'ann_file': None,
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ['apple', 'banana'])
|
||||
self.assertEqual(dataset.CLASSES, ('apple', 'banana'))
|
||||
|
||||
cfg['classes'] = ['apple', 'banana', 'dog']
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
|
@@ -427,10 +307,12 @@ class TestCustomDataset(TestBaseDataset):
|
|||
dataset_class(**cfg)
|
||||
|
||||
|
||||
class TestImageNet(TestBaseDataset):
|
||||
class TestImageNet(TestCustomDataset):
|
||||
DATASET_TYPE = 'ImageNet'
|
||||
|
||||
def test_load_annotations(self):
|
||||
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
|
||||
|
||||
def test_load_data_list(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# test classes number
|
||||
|
@@ -452,20 +334,16 @@ class TestImageNet(TestBaseDataset):
|
|||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
self.assertEqual(dataset.CLASSES, ['cat', 'dog'])
|
||||
self.assertEqual(dataset.CLASSES, ('cat', 'dog'))
|
||||
|
||||
|
||||
class TestImageNet21k(TestBaseDataset):
|
||||
class TestImageNet21k(TestCustomDataset):
|
||||
DATASET_TYPE = 'ImageNet21k'
|
||||
|
||||
DEFAULT_ARGS = dict(
|
||||
data_prefix=ASSETS_ROOT,
|
||||
pipeline=[],
|
||||
classes=['cat', 'dog'],
|
||||
ann_file=osp.join(ASSETS_ROOT, 'ann.txt'),
|
||||
serialize_data=False)
|
||||
data_root=ASSETS_ROOT, classes=['cat', 'dog'], ann_file='ann.txt')
|
||||
|
||||
def test_initialize(self):
|
||||
def test_load_data_list(self):
|
||||
super().test_initialize()
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
|
@@ -476,61 +354,18 @@ class TestImageNet21k(TestBaseDataset):
|
|||
|
||||
# Warn about ann_file
|
||||
cfg = {**self.DEFAULT_ARGS, 'ann_file': None}
|
||||
with self.assertWarnsRegex(UserWarning, 'specify the `ann_file`'):
|
||||
with self.assertLogs(mmcls_logger, 'WARN') as log:
|
||||
dataset_class(**cfg)
|
||||
self.assertIn('specify the `ann_file`', log.output[0])
|
||||
|
||||
# Warn about classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': None}
|
||||
with self.assertWarnsRegex(UserWarning, 'specify the `classes`'):
|
||||
with self.assertLogs(mmcls_logger, 'WARN') as log:
|
||||
dataset_class(**cfg)
|
||||
self.assertIn('specify the `classes`', log.output[0])
|
||||
|
||||
def test_load_annotations(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test with serialize_data=False
|
||||
cfg = {**self.DEFAULT_ARGS, 'serialize_data': False}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset.data_infos), 3)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
self.assertEqual(
|
||||
dataset[0], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'a/1.JPG'
|
||||
},
|
||||
'gt_label': np.array(0)
|
||||
})
|
||||
self.assertEqual(
|
||||
dataset[2], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'b/subb/2.jpeg'
|
||||
},
|
||||
'gt_label': np.array(1)
|
||||
})
|
||||
|
||||
# Test with serialize_data=True
|
||||
cfg = {**self.DEFAULT_ARGS, 'serialize_data': True}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset.data_infos), 0) # data_infos is clear.
|
||||
self.assertEqual(len(dataset), 3)
|
||||
self.assertEqual(
|
||||
dataset[0], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'a/1.JPG'
|
||||
},
|
||||
'gt_label': np.array(0)
|
||||
})
|
||||
self.assertEqual(
|
||||
dataset[2], {
|
||||
'img_prefix': ASSETS_ROOT,
|
||||
'img_info': {
|
||||
'filename': 'b/subb/2.jpeg'
|
||||
},
|
||||
'gt_label': np.array(1)
|
||||
})
|
||||
|
||||
"""Temporarily disabled.
|
||||
|
||||
class TestMNIST(TestBaseDataset):
|
||||
DATASET_TYPE = 'MNIST'
|
||||
|
@@ -594,7 +429,6 @@ class TestMNIST(TestBaseDataset):
|
|||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
class TestCIFAR10(TestBaseDataset):
|
||||
DATASET_TYPE = 'CIFAR10'
|
||||
|
||||
|
@@ -668,17 +502,14 @@ class TestCIFAR10(TestBaseDataset):
|
|||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
class TestCIFAR100(TestCIFAR10):
|
||||
DATASET_TYPE = 'CIFAR100'
|
||||
|
||||
|
||||
class TestVOC(TestMultiLabelDataset):
|
||||
DATASET_TYPE = 'VOC'
|
||||
|
||||
DEFAULT_ARGS = dict(data_prefix='VOC2007', pipeline=[])
|
||||
|
||||
|
||||
class TestCUB(TestBaseDataset):
|
||||
DATASET_TYPE = 'CUB'
|
||||
|
||||
|
@@ -761,3 +592,4 @@ class TestCUB(TestBaseDataset):
|
|||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
"""
|
||||
|
|