# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import mmengine
from mmengine import FileClient
from mmengine.logging import MMLogger

from mmcls.registry import DATASETS
from .base_dataset import BaseDataset


def find_folders(root: str,
                 file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by the sub-folders directly under a root.

    Args:
        root (str): Root directory of the class folders.
        file_client (FileClient): File client used to list ``root``.

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The sorted names of the sub-folders under ``root``.
        - folder_to_idx: The map from folder name to class index. Indices
          are assigned in sorted folder-name order.
    """
    folders = sorted(
        file_client.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folder_to_idx = {folder: idx for idx, folder in enumerate(folders)}
    return folders, folder_to_idx


def get_samples(root: str, folder_to_idx: Dict[str, int],
                is_valid_file: Callable, file_client: FileClient
                ) -> Tuple[List[Tuple[str, int]], Set[str]]:
    """Make dataset by walking all images under a root.

    Args:
        root (str): Root directory of the class folders.
        folder_to_idx (dict): The map from class folder name to class index.
        is_valid_file (Callable): A function that takes the path of a file
            and checks whether the file is a valid sample file.
        file_client (FileClient): File client used to list and join paths.

    Returns:
        Tuple[list, set]:

        - samples: A list of ``(path, class_idx)`` tuples, where ``path`` is
          the file path joined onto its class folder name (i.e. relative to
          ``root``). Folders and files are visited in sorted order.
        - empty_folders: The folders that don't contain any valid file.
    """
    samples = []
    available_classes = set()
    for folder_name in sorted(folder_to_idx):
        folder_dir = file_client.join_path(root, folder_name)
        # Files are listed recursively, so samples in nested sub-directories
        # of a class folder are collected as well.
        files = file_client.list_dir_or_file(
            folder_dir,
            list_dir=False,
            list_file=True,
            recursive=True,
        )
        for file in sorted(files):
            if is_valid_file(file):
                path = file_client.join_path(folder_name, file)
                samples.append((path, folder_to_idx[folder_name]))
                available_classes.add(folder_name)

    empty_folders = set(folder_to_idx) - available_classes
    return samples, empty_folders


@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """Custom dataset for classification.

    The dataset supports two kinds of annotation format.

    1. An annotation file is provided, and each line indicates a sample:

       The sample files: ::

           data_prefix/
           ├── folder_1
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           └── folder_2
               ├── 123.png
               ├── nsdf3.png
               └── ...

       The annotation file (the first column is the image path and the second
       column is the index of category; the last space on a line separates
       the two, and blank lines are ignored): ::

           folder_1/xxx.png 0
           folder_1/xxy.png 1
           folder_2/123.png 5
           folder_2/nsdf3.png 3
           ...

       Please specify the name of categories by the argument ``classes``
       or ``metainfo``.

    2. The samples are arranged in the specific way: ::

           data_prefix/
           ├── class_x
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           │       └── xxz.png
           └── class_y
               ├── 123.png
               ├── nsdf3.png
               ├── ...
               └── asd932_.png

       Each sub-folder is one class (indexed in sorted folder-name order),
       and files in nested sub-directories of a class folder are included.

    If the ``ann_file`` is specified, the dataset will be generated by the
    first way, otherwise, try the second way.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for the data. Defaults to ''.
        extensions (Sequence[str]): A sequence of allowed extensions, matched
            case-insensitively. Defaults to
            ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        lazy_init (bool): Whether to skip loading annotations during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, so loading the annotation
            file is unnecessary. Set ``lazy_init=True`` to skip loading and
            save time; call ``full_init()`` later to load the data.
            Defaults to False.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 lazy_init: bool = False,
                 **kwargs):
        assert (ann_file or data_prefix or data_root), \
            'One of `ann_file`, `data_root` and `data_prefix` must '\
            'be specified.'

        # Lower-case for case-insensitive matching in ``is_valid_file``.
        # De-duplicate while keeping the given order: a ``set`` has no stable
        # iteration order, which made messages listing extensions vary.
        self.extensions = tuple(
            dict.fromkeys(ext.lower() for ext in extensions))

        super().__init__(
            # The base class requires string ann_file but this class doesn't
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            # Force to lazy_init for some modification before loading data.
            lazy_init=True,
            **kwargs)

        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

    def _find_samples(self, file_client):
        """Find samples by scanning the class sub-folders of ``img_prefix``.

        As a side effect, fills ``classes`` in the meta information when no
        classes were specified, and stores ``self.folder_to_idx``.

        Args:
            file_client (FileClient): File client for ``img_prefix``.

        Returns:
            list[tuple[str, int]]: ``(path relative to img_prefix, label)``.

        Raises:
            RuntimeError: If no valid file is found in any sub-folder.
            AssertionError: If the number of specified classes does not match
                the number of sub-folders.
        """
        img_prefix = self.img_prefix
        classes, folder_to_idx = find_folders(img_prefix, file_client)
        samples, empty_classes = get_samples(
            img_prefix,
            folder_to_idx,
            is_valid_file=self.is_valid_file,
            file_client=file_client,
        )

        if not samples:
            # Report the scanned directory rather than ``data_prefix``, which
            # may be a dict at this point.
            raise RuntimeError(
                f'Found 0 files in subfolders of: {img_prefix}. '
                f'Supported extensions are: {",".join(self.extensions)}')

        if self.CLASSES is not None:
            assert len(self.CLASSES) == len(classes), \
                f"The number of subfolders ({len(classes)}) doesn't match " \
                f'the number of specified classes ({len(self.CLASSES)}). ' \
                'Please check the data folder.'
        else:
            self._metainfo['classes'] = tuple(classes)

        if empty_classes:
            logger = MMLogger.get_current_instance()
            # Sorted so the warning text is deterministic across runs.
            logger.warning(
                'Found no valid file in the folder '
                f'{", ".join(sorted(empty_classes))}. '
                f"Supported extensions are: {', '.join(self.extensions)}")

        self.folder_to_idx = folder_to_idx

        return samples

    def load_data_list(self):
        """Load image paths and gt_labels.

        Returns:
            list[dict]: One ``{'img_path': str, 'gt_label': int}`` dict per
            sample. ``img_path`` is prefixed with ``img_prefix`` when it is
            non-empty, otherwise the path is used as-is.

        Raises:
            ValueError: If no ``ann_file`` is given and ``img_prefix`` is
                empty, so there is no folder to scan.
        """
        img_prefix = self.img_prefix
        # A file client is only needed (and only built) for a non-empty
        # prefix; with an empty prefix, annotation paths are used verbatim.
        file_client = None
        if img_prefix:
            file_client = FileClient.infer_client(uri=img_prefix)

        if not self.ann_file:
            if file_client is None:
                raise ValueError(
                    'Cannot scan sample folders because the image prefix is '
                    'empty. Please specify `data_prefix` or `data_root`, or '
                    'provide `ann_file`.')
            samples = self._find_samples(file_client)
        else:
            lines = mmengine.list_from_file(self.ann_file)
            # Split on the last space so file names may contain spaces; skip
            # blank lines (e.g. a trailing empty line), which have no label.
            samples = [
                line.strip().rsplit(' ', 1) for line in lines if line.strip()
            ]

        data_list = []
        for filename, gt_label in samples:
            if file_client is not None:
                img_path = file_client.join_path(img_prefix, filename)
            else:
                img_path = filename
            data_list.append({'img_path': img_path, 'gt_label': int(gt_label)})
        return data_list

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample (by case-insensitive extension)."""
        return filename.lower().endswith(self.extensions)