mmsegmentation/mmseg/datasets/custom.py

import os.path as osp
from functools import reduce
import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data import Dataset
from mmseg.core import mean_iou
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose
@DATASETS.register_module()
class CustomDataset(Dataset):
"""Custom dataset for semantic segmentation.
An example of file structure is as follows.
.. code-block:: none

    ├── data
    │   ├── my_dataset
    │   │   ├── img_dir
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── ann_dir
    │   │   │   ├── train
    │   │   │   │   ├── xxx{seg_map_suffix}
    │   │   │   │   ├── yyy{seg_map_suffix}
    │   │   │   │   ├── zzz{seg_map_suffix}
    │   │   │   ├── val
The img/gt_semantic_seg pair of CustomDataset should share the same
name except for the suffix. A valid img/gt_semantic_seg filename pair
should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the
extension is included in the suffix). If split is given, then ``xxx`` is
specified in the txt file. Otherwise, all files in ``img_dir/`` and
``ann_dir`` will be loaded.
Please refer to ``docs/tutorials/new_dataset.md`` for more details.
Args:
pipeline (list[dict]): Processing pipeline
img_dir (str): Path to image directory
img_suffix (str): Suffix of images. Default: '.jpg'
ann_dir (str, optional): Path to annotation directory. Default: None
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
split (str, optional): Split txt file. If split is specified, only the
files whose names are listed in it will be loaded. Otherwise, all
images in img_dir/ann_dir will be loaded. Default: None
data_root (str, optional): Data root for img_dir/ann_dir. Default:
None.
test_mode (bool): If test_mode=True, ground-truth annotations will not
be loaded.
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
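
Example (a minimal construction sketch; the paths and pipeline steps
below are illustrative placeholders, not values shipped with this file):

.. code-block:: python

    dataset = CustomDataset(
        pipeline=[dict(type='LoadImageFromFile'),
                  dict(type='LoadAnnotations')],
        data_root='data/my_dataset',
        img_dir='img_dir/train',
        ann_dir='ann_dir/train',
        img_suffix='.jpg',
        seg_map_suffix='.png')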
"""
CLASSES = None
PALETTE = None
def __init__(self,
pipeline,
img_dir,
img_suffix='.jpg',
ann_dir=None,
seg_map_suffix='.png',
split=None,
data_root=None,
test_mode=False,
ignore_index=255,
reduce_zero_label=False):
self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
self.ann_dir = ann_dir
self.seg_map_suffix = seg_map_suffix
self.split = split
self.data_root = data_root
self.test_mode = test_mode
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
# join paths if data_root is specified
if self.data_root is not None:
if not osp.isabs(self.img_dir):
self.img_dir = osp.join(self.data_root, self.img_dir)
if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
self.ann_dir = osp.join(self.data_root, self.ann_dir)
if not (self.split is None or osp.isabs(self.split)):
self.split = osp.join(self.data_root, self.split)
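# For example (hypothetical values): data_root='data/my_dataset' with
# img_dir='img_dir/train' resolves to 'data/my_dataset/img_dir/train';
# absolute paths are kept as given.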
# load annotations
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
self.ann_dir,
self.seg_map_suffix, self.split)
def __len__(self):
"""Total number of samples of data."""
return len(self.img_infos)
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
split):
"""Load annotation from directory.
Args:
img_dir (str): Path to image directory
img_suffix (str): Suffix of images.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only the
files whose names are listed in it will be loaded. Otherwise, all
images in img_dir/ann_dir will be loaded. Default: None
Returns:
list[dict]: All image info of dataset.
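
Example (hypothetical names): a split file containing the lines
``xxx`` and ``yyy`` produces

.. code-block:: python

    [dict(filename=osp.join(img_dir, 'xxx' + img_suffix),
          ann=dict(seg_map=osp.join(ann_dir, 'xxx' + seg_map_suffix))),
     dict(filename=osp.join(img_dir, 'yyy' + img_suffix),
          ann=dict(seg_map=osp.join(ann_dir, 'yyy' + seg_map_suffix)))]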
"""
img_infos = []
if split is not None:
with open(split) as f:
for line in f:
img_name = line.strip()
img_file = osp.join(img_dir, img_name + img_suffix)
img_info = dict(filename=img_file)
if ann_dir is not None:
seg_map = osp.join(ann_dir, img_name + seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
img_file = osp.join(img_dir, img)
img_info = dict(filename=img_file)
if ann_dir is not None:
seg_map = osp.join(ann_dir,
img.replace(img_suffix, seg_map_suffix))
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos
def get_ann_info(self, idx):
"""Get annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return self.img_infos[idx]['ann']
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['seg_fields'] = []
def __getitem__(self, idx):
"""Get training/test data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training/test data (with annotation if `test_mode` is set to
False).
"""
if self.test_mode:
return self.prepare_test_img(idx)
else:
return self.prepare_train_img(idx)
def prepare_train_img(self, idx):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_info = self.img_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
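# At this point results holds img_info (with 'filename' and 'ann'),
# ann_info (with 'seg_map') and seg_fields (an empty list that loading
# transforms are expected to append to). The pipeline call below loads
# the image/annotation and applies the configured transforms.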
return self.pipeline(results)
def prepare_test_img(self, idx):
"""Get testing data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by
pipeline.
"""
img_info = self.img_infos[idx]
results = dict(img_info=img_info)
self.pre_pipeline(results)
return self.pipeline(results)
def format_results(self, results, **kwargs):
"""Place holder to format result to dataset specific output."""
pass
def get_gt_seg_maps(self):
"""Get ground truth segmentation maps for evaluation."""
gt_seg_maps = []
for img_info in self.img_infos:
gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
if self.reduce_zero_label:
# avoid using underflow conversion
gt_seg_map[gt_seg_map == 0] = 255
gt_seg_map = gt_seg_map - 1
gt_seg_map[gt_seg_map == 254] = 255
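# Net effect (worked example): labels [0, 1, 2, 255] become
# [255, 0, 1, 255], i.e. label 0 joins the ignored value 255 and all
# remaining labels shift down by one.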
gt_seg_maps.append(gt_seg_map)
return gt_seg_maps
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: Default metrics.
"""
if not isinstance(metric, str):
assert len(metric) == 1
metric = metric[0]
allowed_metrics = ['mIoU']
if metric not in allowed_metrics:
raise KeyError('metric {} is not supported'.format(metric))
eval_results = {}
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
all_acc, acc, iou = mean_iou(
results, gt_seg_maps, num_classes, ignore_index=self.ignore_index)
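# mean_iou returns a scalar overall accuracy plus per-class accuracy and
# IoU arrays of length num_classes; classes that never appear in the
# ground truth can come back as NaN, hence the nanmean calls below.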
summary_str = ''
summary_str += 'per class results:\n'
line_format = '{:<15} {:>10} {:>10}\n'
summary_str += line_format.format('Class', 'IoU', 'Acc')
if self.CLASSES is None:
class_names = tuple(range(num_classes))
else:
class_names = self.CLASSES
for i in range(num_classes):
iou_str = '{:.2f}'.format(iou[i] * 100)
acc_str = '{:.2f}'.format(acc[i] * 100)
summary_str += line_format.format(class_names[i], iou_str, acc_str)
summary_str += 'Summary:\n'
line_format = '{:<15} {:>10} {:>10} {:>10}\n'
summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc')
iou_str = '{:.2f}'.format(np.nanmean(iou) * 100)
acc_str = '{:.2f}'.format(np.nanmean(acc) * 100)
all_acc_str = '{:.2f}'.format(all_acc * 100)
summary_str += line_format.format('global', iou_str, acc_str,
all_acc_str)
print_log(summary_str, logger)
eval_results['mIoU'] = np.nanmean(iou)
eval_results['mAcc'] = np.nanmean(acc)
eval_results['aAcc'] = all_acc
return eval_results
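# A minimal evaluation sketch (assumes `dataset` is an instantiated
# CustomDataset and `seg_results` is a list of per-image label maps in
# dataset order, e.g. collected by a test script):
#
#   metrics = dataset.evaluate(seg_results, metric='mIoU')
#   print(metrics['mIoU'], metrics['mAcc'], metrics['aAcc'])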