mmclassification/mmcls/datasets/custom.py

233 lines
8.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmengine
from mmengine import FileClient
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
def find_folders(root: str,
                 file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by folders under a root.

    Only the direct sub directories of ``root`` are considered (the listing
    is not recursive, and plain files are ignored).

    Args:
        root (string): root directory of folders
        file_client (FileClient): The file client used to list the
            sub directories of ``root``.

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The names of the sub folders under the root, sorted.
        - folder_to_idx: The map from folder name to class idx, where the
          idx is the folder's position in the sorted ``folders`` list.
    """
    folders = sorted(
        file_client.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folder_to_idx = {folder: idx for idx, folder in enumerate(folders)}
    return folders, folder_to_idx
def get_samples(
    root: str, folder_to_idx: Dict[str, int], is_valid_file: Callable,
    file_client: FileClient
) -> Tuple[List[Tuple[str, int]], set]:
    """Make dataset by walking all images under a root.

    Every folder named in ``folder_to_idx`` is listed recursively, and each
    file accepted by ``is_valid_file`` becomes one sample. Both folders and
    files are visited in sorted order, so the output order is deterministic.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        is_valid_file (Callable): A function that takes path of a file
            and check if the file is a valid sample file.
        file_client (FileClient): The file client used to join paths and
            list the files under each class folder.

    Returns:
        Tuple[list, set]:

        - samples: a list of tuple where each element is (image, class_idx).
          ``image`` is relative to ``root``, i.e. ``folder_name`` joined
          with the file path found inside that folder.
        - empty_folders: The folders don't have any valid files.
    """
    samples = []
    available_classes = set()
    for folder_name in sorted(folder_to_idx):
        class_idx = folder_to_idx[folder_name]
        folder_dir = file_client.join_path(root, folder_name)
        files = file_client.list_dir_or_file(
            folder_dir,
            list_dir=False,
            list_file=True,
            recursive=True,
        )
        for filename in sorted(files):
            if not is_valid_file(filename):
                continue
            # Keep the path relative to ``root``; the caller adds the prefix.
            path = file_client.join_path(folder_name, filename)
            samples.append((path, class_idx))
            available_classes.add(folder_name)
    empty_folders = set(folder_to_idx) - available_classes
    return samples, empty_folders
@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """Custom dataset for classification.

    The dataset supports two kinds of annotation format.

    1. An annotation file is provided, and each line indicates a sample:

       The sample files: ::

           data_prefix/
           ├── folder_1
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           └── folder_2
               ├── 123.png
               ├── nsdf3.png
               └── ...

       The annotation file (the first column is the image path and the second
       column is the index of category): ::

            folder_1/xxx.png 0
            folder_1/xxy.png 1
            folder_2/123.png 5
            folder_2/nsdf3.png 3
            ...

       The label is taken after the last space of each line, and blank lines
       are ignored. Please specify the name of categories by the argument
       ``classes`` or ``metainfo``.

    2. The samples are arranged in the specific way: ::

           data_prefix/
           ├── class_x
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           │       └── xxz.png
           └── class_y
               ├── 123.png
               ├── nsdf3.png
               ├── ...
               └── asd932_.png

    If the ``ann_file`` is specified, the dataset will be generated by the
    first way, otherwise, try the second way.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for the data. Defaults to ''.
        extensions (Sequence[str]): A sequence of allowed extensions. Matching
            is case-insensitive. Defaults to
            ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        lazy_init (bool): Whether to defer loading the annotations until
            :meth:`full_init` is called. In some cases, such as
            visualization, only the meta information of the dataset is
            needed, so there is no need to load the annotation file; set
            ``lazy_init=True`` to skip it and save time. Defaults to False.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 lazy_init: bool = False,
                 **kwargs):
        assert (ann_file or data_prefix or data_root), \
            'One of `ann_file`, `data_root` and `data_prefix` must '\
            'be specified.'

        # Lower-case once so ``is_valid_file`` can do a case-insensitive
        # suffix match; de-duplicated and sorted so that messages listing
        # the extensions are deterministic across runs.
        self.extensions = tuple(sorted({ext.lower() for ext in extensions}))

        super().__init__(
            # The base class requires string ann_file but this class doesn't
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            # Force to lazy_init for some modification before loading data.
            lazy_init=True,
            **kwargs)

        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

    def _find_samples(self, file_client):
        """Find samples by scanning the class sub folders of ``img_prefix``.

        Each direct sub folder of ``img_prefix`` is treated as one class. If
        no classes were given, the sorted folder names become
        ``metainfo['classes']``. The folder-to-index map is stored in
        ``self.folder_to_idx``.

        Args:
            file_client (FileClient): The file client used to list
                ``img_prefix``.

        Returns:
            list[tuple]: ``(path, class_idx)`` pairs, where ``path`` is
            relative to ``img_prefix``.

        Raises:
            RuntimeError: If no valid file is found in any sub folder.
            AssertionError: If the number of sub folders doesn't match the
                number of specified classes.
        """
        classes, folder_to_idx = find_folders(self.img_prefix, file_client)
        samples, empty_classes = get_samples(
            self.img_prefix,
            folder_to_idx,
            is_valid_file=self.is_valid_file,
            file_client=file_client,
        )

        if len(samples) == 0:
            # Report the directory that was actually scanned.
            raise RuntimeError(
                f'Found 0 files in subfolders of: {self.img_prefix}. '
                f'Supported extensions are: {", ".join(self.extensions)}')

        if self.CLASSES is not None:
            assert len(self.CLASSES) == len(classes), \
                f"The number of subfolders ({len(classes)}) doesn't match " \
                f'the number of specified classes ({len(self.CLASSES)}). ' \
                'Please check the data folder.'
        else:
            self._metainfo['classes'] = tuple(classes)

        if empty_classes:
            logger = MMLogger.get_current_instance()
            logger.warning(
                'Found no valid file in the folder '
                f'{", ".join(sorted(empty_classes))}. '
                f"Supported extensions are: {', '.join(self.extensions)}")

        self.folder_to_idx = folder_to_idx

        return samples

    def load_data_list(self):
        """Load image paths and gt_labels.

        Samples come from ``ann_file`` when it is set, otherwise from scanning
        the class folders under ``img_prefix``.

        Returns:
            list[dict]: One dict per sample with keys ``img_path`` (the
            sample path, joined with ``img_prefix`` when it is non-empty)
            and ``gt_label`` (an ``int``).

        Raises:
            RuntimeError: If ``ann_file`` is empty and ``img_prefix`` is
                empty too, so there is no folder to scan.
        """
        img_prefix = self.img_prefix
        # Only build a file client when there is a prefix to infer the
        # storage backend from; it is needed for scanning and path joining.
        file_client = (
            FileClient.infer_client(uri=img_prefix) if img_prefix else None)

        if not self.ann_file:
            if file_client is None:
                raise RuntimeError(
                    'Cannot scan for samples: `ann_file` is not specified '
                    'and the image prefix is empty.')
            samples = self._find_samples(file_client)
        else:
            lines = mmengine.list_from_file(self.ann_file)
            # Skip blank lines, which would otherwise fail to unpack into
            # (filename, label). The label follows the last space.
            samples = [
                line.strip().rsplit(' ', 1) for line in lines if line.strip()
            ]

        data_list = []
        for filename, gt_label in samples:
            if img_prefix:
                img_path = file_client.join_path(img_prefix, filename)
            else:
                img_path = filename
            data_list.append({'img_path': img_path, 'gt_label': int(gt_label)})
        return data_list

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample."""
        return filename.lower().endswith(self.extensions)