213 lines
7.5 KiB
Python
213 lines
7.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
from os import PathLike
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
import mmengine
|
|
import numpy as np
|
|
from mmengine.dataset import BaseDataset as _BaseDataset
|
|
|
|
from .builder import DATASETS
|
|
|
|
|
|
def expanduser(path):
    """Expand a leading ``~`` or ``~user`` in ``path`` if it is path-like.

    Strings and :class:`os.PathLike` objects are passed through
    :func:`os.path.expanduser`, which leaves the value unchanged when the
    user or ``$HOME`` cannot be determined. Any other value (for example
    ``None`` or a ``dict``) is handed back untouched.

    Args:
        path: The value that may contain a home-directory shortcut.

    Returns:
        The expanded path for path-like input, otherwise ``path`` itself.
    """
    path_like = isinstance(path, (str, PathLike))
    return osp.expanduser(path) if path_like else path
|
|
|
|
|
|
@DATASETS.register_module()
class BaseDataset(_BaseDataset):
    """Base dataset for image classification task.

    This dataset supports annotation files in `OpenMMLab 2.0 style annotation
    format`.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Compared with :class:`mmengine.BaseDataset`, this class implements
    several useful methods.

    Args:
        ann_file (str): Annotation file path. A leading ``~`` is expanded.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. A string is
            treated as the image prefix (with ``~`` expanded) and stored as
            ``dict(img_path=data_prefix)``. Defaults to ''.
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Defaults to None, which means using all ``data_infos``.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy. Defaults to True.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        test_mode (bool): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool): Whether to defer loading the annotation file. In
            some cases, such as visualization, only the meta information of
            the dataset is needed and loading the annotation file is
            unnecessary; ``BaseDataset`` can skip loading annotations during
            instantiation to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int): If ``BaseDataset.prepare_data`` gets a None img,
            the maximum number of extra cycles to get a valid image.
            Defaults to 1000.
        classes (str | Sequence[str], optional): Specify names of classes.

            - If it is a string, it is taken as a file path, and every line
              of the file is the name of a class.
            - If it is a (non-string) sequence, every item is the name of a
              class.
            - If it is None, use categories information in ``metainfo``
              argument, annotation file or the class attribute ``METAINFO``.

            A ``classes`` key already present in ``metainfo`` takes
            precedence over this argument. Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        # A plain string prefix is the image directory; wrap it in a dict so
        # that ``img_prefix`` can read ``data_prefix['img_path']``.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        ann_file = expanduser(ann_file)
        # Fold the legacy ``classes`` argument into ``metainfo`` before the
        # parent class consumes it.
        metainfo = self._compat_classes(metainfo, classes)

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=pipeline,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    @property
    def img_prefix(self):
        """The prefix of images, i.e. ``data_prefix['img_path']``."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """Return all category names, or None if ``classes`` is not set."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map class name to class index.

        The index of a class is its position in :attr:`CLASSES`, so
        ``CLASSES`` must be set before this property is used.

        Returns:
            dict: mapping from class name to class index.
        """
        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).

        Returns:
            np.ndarray: categories for all images, collected from the
            ``gt_label`` field of every data info in index order.
        """
        gt_labels = np.array(
            [self.get_data_info(i)['gt_label'] for i in range(len(self))])
        return gt_labels

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image category of specified index, as a
            single-element list.
        """
        return [int(self.get_data_info(idx)['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old style ``classes`` argument into ``metainfo``.

        Args:
            metainfo (dict, optional): Meta information given by the user.
            classes (str | Sequence[str], optional): A path to a file with one
                class name per line, or a non-string sequence of class names.

        Returns:
            dict: ``metainfo`` (or an empty dict if it is None) when
            ``classes`` is None; otherwise a new dict holding
            ``classes`` as a tuple, where any key already present in
            ``metainfo`` (including ``classes``) wins.

        Raises:
            ValueError: If ``classes`` is not None, a string, or a sequence.
        """
        if metainfo is None:
            metainfo = {}

        if classes is None:
            return metainfo

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, Sequence) and not isinstance(
                classes, (bytes, bytearray)):
            # Any sequence of names (list, tuple, ...) is accepted, matching
            # the ``Sequence[str]`` annotation of the ``classes`` argument.
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        # Build a new dict so the caller's ``metainfo`` is never mutated.
        return {'classes': tuple(class_names), **metainfo}

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        Additionally, when the metainfo provides only the standard OpenMMLab
        2.0 ``categories`` list (items with ``id`` and ``category_name``),
        the internal ``classes`` tuple is derived from it, ordered by ``id``.
        """
        super().full_init()

        # To support the standard OpenMMLab 2.0 annotation format. Generate
        # metainfo in internal format from standard metainfo format.
        if 'categories' in self._metainfo and 'classes' not in self._metainfo:
            categories = sorted(
                self._metainfo['categories'], key=lambda x: x['id'])
            self._metainfo['classes'] = tuple(
                cat['category_name'] for cat in categories)

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        # The sample count is only known once annotations are loaded.
        if self._fully_initialized:
            body.append(f'Number of samples: \t{len(self)}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.extend(self.extra_repr())

        if len(self.pipeline.transforms) > 0:
            body.append('With transforms:')
            for t in self.pipeline.transforms:
                body.append(f' {t}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset.

        Returns:
            List[str]: Lines describing the annotation file and image prefix.
        """
        body = []
        body.append(f'Annotation file: \t{self.ann_file}')
        body.append(f'Prefix of images: \t{self.img_prefix}')
        return body
|