import os.path as osp
from functools import reduce

import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data import Dataset

from mmseg.core import mean_iou
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for semantic segmentation.

    An example of file structure is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of CustomDataset should have the same
    basename and differ only in suffix. A valid img/gt_semantic_seg filename
    pair should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}``
    (the extension is also included in the suffix). If split is given, then
    ``xxx`` is specified in the txt file. Otherwise, all files in ``img_dir/``
    and ``ann_dir`` will be loaded. Please refer to
    ``docs/tutorials/new_dataset.md`` for more details.

    Args:
        pipeline (list[dict]): Processing pipeline.
        img_dir (str): Path to image directory.
        img_suffix (str): Suffix of images. Default: '.jpg'
        ann_dir (str, optional): Path to annotation directory. Default: None
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        split (str, optional): Split txt file. If split is specified, only
            the files whose names are listed in it will be loaded. Otherwise,
            all images in img_dir/ann_dir will be loaded. Default: None
        data_root (str, optional): Data root for img_dir/ann_dir.
            Default: None.
        test_mode (bool): If test_mode=True, gt won't be loaded.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default: False
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',
                 ann_dir=None,
                 seg_map_suffix='.png',
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False):
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)
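    # Illustrative sketch (file names below are hypothetical, not shipped with
    # this module): a split txt file simply lists one sample name per line,
    #
    #     xxx
    #     yyy
    #     zzz
    #
    # and ``load_annotations`` below turns each name into an entry such as
    #
    #     dict(filename='data/my_dataset/img_dir/train/xxx.jpg',
    #          ann=dict(seg_map='data/my_dataset/ann_dir/train/xxx.png'))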
""" img_infos = [] if split is not None: with open(split) as f: for line in f: img_name = line.strip() img_file = osp.join(img_dir, img_name + img_suffix) img_info = dict(filename=img_file) if ann_dir is not None: seg_map = osp.join(ann_dir, img_name + seg_map_suffix) img_info['ann'] = dict(seg_map=seg_map) img_infos.append(img_info) else: for img in mmcv.scandir(img_dir, img_suffix, recursive=True): img_file = osp.join(img_dir, img) img_info = dict(filename=img_file) if ann_dir is not None: seg_map = osp.join(ann_dir, img.replace(img_suffix, seg_map_suffix)) img_info['ann'] = dict(seg_map=seg_map) img_infos.append(img_info) print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) return img_infos def get_ann_info(self, idx): """Get annotation by index. Args: idx (int): Index of data. Returns: dict: Annotation info of specified index. """ return self.img_infos[idx]['ann'] def pre_pipeline(self, results): """Prepare results dict for pipeline.""" results['seg_fields'] = [] def __getitem__(self, idx): """Get training/test data after pipeline. Args: idx (int): Index of data. Returns: dict: Training/test data (with annotation if `test_mode` is set False). """ if self.test_mode: return self.prepare_test_img(idx) else: return self.prepare_train_img(idx) def prepare_train_img(self, idx): """Get training data and annotations after pipeline. Args: idx (int): Index of data. Returns: dict: Training data and annotation after pipeline with new keys introduced by pipeline. """ img_info = self.img_infos[idx] ann_info = self.get_ann_info(idx) results = dict(img_info=img_info, ann_info=ann_info) self.pre_pipeline(results) return self.pipeline(results) def prepare_test_img(self, idx): """Get testing data after pipeline. Args: idx (int): Index of data. Returns: dict: Testing data after pipeline with new keys intorduced by piepline. """ img_info = self.img_infos[idx] results = dict(img_info=img_info) self.pre_pipeline(results) return self.pipeline(results) def format_results(self, results, **kwargs): """Place holder to format result to dataset specific output.""" pass def get_gt_seg_maps(self): """Get ground truth segmentation maps for evaluation.""" gt_seg_maps = [] for img_info in self.img_infos: gt_seg_map = mmcv.imread( img_info['ann']['seg_map'], flag='unchanged', backend='pillow') if self.reduce_zero_label: # avoid using underflow conversion gt_seg_map[gt_seg_map == 0] = 255 gt_seg_map = gt_seg_map - 1 gt_seg_map[gt_seg_map == 254] = 255 gt_seg_maps.append(gt_seg_map) return gt_seg_maps def evaluate(self, results, metric='mIoU', logger=None, **kwargs): """Evaluate the dataset. Args: results (list): Testing results of the dataset. metric (str | list[str]): Metrics to be evaluated. logger (logging.Logger | None | str): Logger used for printing related information during evaluation. Default: None. Returns: dict[str, float]: Default metrics. 
""" if not isinstance(metric, str): assert len(metric) == 1 metric = metric[0] allowed_metrics = ['mIoU'] if metric not in allowed_metrics: raise KeyError('metric {} is not supported'.format(metric)) eval_results = {} gt_seg_maps = self.get_gt_seg_maps() if self.CLASSES is None: num_classes = len( reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) else: num_classes = len(self.CLASSES) all_acc, acc, iou = mean_iou( results, gt_seg_maps, num_classes, ignore_index=self.ignore_index) summary_str = '' summary_str += 'per class results:\n' line_format = '{:<15} {:>10} {:>10}\n' summary_str += line_format.format('Class', 'IoU', 'Acc') if self.CLASSES is None: class_names = tuple(range(num_classes)) else: class_names = self.CLASSES for i in range(num_classes): iou_str = '{:.2f}'.format(iou[i] * 100) acc_str = '{:.2f}'.format(acc[i] * 100) summary_str += line_format.format(class_names[i], iou_str, acc_str) summary_str += 'Summary:\n' line_format = '{:<15} {:>10} {:>10} {:>10}\n' summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc') iou_str = '{:.2f}'.format(np.nanmean(iou) * 100) acc_str = '{:.2f}'.format(np.nanmean(acc) * 100) all_acc_str = '{:.2f}'.format(all_acc * 100) summary_str += line_format.format('global', iou_str, acc_str, all_acc_str) print_log(summary_str, logger) eval_results['mIoU'] = np.nanmean(iou) eval_results['mAcc'] = np.nanmean(acc) eval_results['aAcc'] = all_acc return eval_results