[Feature] Support evaluating concatenated datasets and add a tool to show the dataset (#833)

* [Feature] Add a tool to show original or augmented training data

* [Feature] Support evaluating concatenated datasets

* Add docstrings and modify evaluate of the concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* format concat dataset results in a subfolder of imgfile_prefix

Signed-off-by: FreyWang <wangwxyz@qq.com>

* add unit tests for the concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* update unit tests for evaluating a dataset whose CLASSES is None

Signed-off-by: FreyWang <wangwxyz@qq.com>

* [Fix] bug of the gt_seg_maps generator, which led metrics to NaN when pre_eval=False

Signed-off-by: FreyWang <wangwxyz@qq.com>

* format code

Signed-off-by: FreyWang <wangwxyz@qq.com>

* add more unit tests

* add more unit tests

* optimize the concat dataset builder
FreyWang 2021-09-09 13:00:23 +08:00 committed by GitHub
parent eb0baee414
commit 872e54497e
8 changed files with 645 additions and 45 deletions


@@ -112,8 +112,6 @@ def total_intersect_and_union(results,
        ndarray: The prediction histogram on all classes.
        ndarray: The ground truth histogram on all classes.
    """
-    num_imgs = len(results)
-    assert len(list(gt_seg_maps)) == num_imgs
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
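The two removed lines are what the "bug of generator" commit refers to: `results` and `gt_seg_maps` may now be generators, and `len(list(gt_seg_maps))` exhausts the generator before the per-image accumulation loop runs, which is what drove the metrics to NaN when `pre_eval=False`. A minimal standalone sketch of that failure mode (illustration only, not mmseg code):

def fake_gt_maps():
    # stands in for CustomDataset.get_gt_seg_maps(), which yields one map per image
    for i in range(3):
        yield f'gt_{i}'

gen = fake_gt_maps()
assert len(list(gen)) == 3  # this length check consumes the generator ...
print(list(gen))            # ... so nothing is left afterwards: prints []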


@@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
    img_dir = cfg['img_dir']
    ann_dir = cfg.get('ann_dir', None)
    split = cfg.get('split', None)
+    # pop 'separate_eval' since it is not a valid key for common datasets.
+    separate_eval = cfg.pop('separate_eval', True)
    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
    if ann_dir is not None:
        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
            data_cfg['split'] = split[i]
        datasets.append(build_dataset(data_cfg, default_args))

-    return ConcatDataset(datasets)
+    return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
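For context, `_concat_dataset` is the branch of `build_dataset` that fires when `img_dir` (and optionally `ann_dir`/`split`) is a list, so the new key simply rides along in the dataset config. A hedged sketch of such a config (the paths, class names and empty pipeline are placeholders, not from this PR):

val = dict(
    type='CustomDataset',
    pipeline=[],                                    # placeholder pipeline
    data_root='data/my_dataset',                    # placeholder path
    img_dir=['images/part1', 'images/part2'],
    ann_dir=['annotations/part1', 'annotations/part2'],
    separate_eval=False,   # popped here and forwarded to ConcatDataset
    classes=('foreground', 'background'))

Because `img_dir` is a list, `build_dataset(val)` returns a `ConcatDataset` wrapping one `CustomDataset` per directory pair.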


@@ -2,7 +2,6 @@
import os.path as osp
import warnings
from collections import OrderedDict
-from functools import reduce

import mmcv
import numpy as np
@@ -99,6 +98,9 @@ class CustomDataset(Dataset):
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
+        if test_mode:
+            assert self.CLASSES is not None, \
+                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
@@ -339,7 +341,12 @@ class CustomDataset(Dataset):

        return palette

-    def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
+    def evaluate(self,
+                 results,
+                 metric='mIoU',
+                 logger=None,
+                 gt_seg_maps=None,
+                 **kwargs):
        """Evaluate the dataset.

        Args:
@@ -350,6 +357,8 @@ class CustomDataset(Dataset):
                'mDice' and 'mFscore' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
+            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
+                used in ConcatDataset.

        Returns:
            dict[str, float]: Default metrics.
@@ -364,14 +373,9 @@ class CustomDataset(Dataset):
        # test a list of files
        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                results, str):
-            gt_seg_maps = self.get_gt_seg_maps()
-            if self.CLASSES is None:
-                num_classes = len(
-                    reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
-            else:
-                num_classes = len(self.CLASSES)
-            # reset generator
-            gt_seg_maps = self.get_gt_seg_maps()
+            if gt_seg_maps is None:
+                gt_seg_maps = self.get_gt_seg_maps()
+            num_classes = len(self.CLASSES)
            ret_metrics = eval_metrics(
                results,
                gt_seg_maps,
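The new `gt_seg_maps` argument lets the caller supply ground-truth maps from somewhere other than the dataset itself; `ConcatDataset.evaluate` below uses it by chaining the sub-datasets' generators. A standalone illustration of that chaining (the string values stand in for ndarray seg maps):

from itertools import chain

gen_a = (f'gt_a_{i}' for i in range(2))   # stand-in for ds_a.get_gt_seg_maps()
gen_b = (f'gt_b_{i}' for i in range(3))   # stand-in for ds_b.get_gt_seg_maps()

merged = chain(gen_a, gen_b)              # yields ds_a's maps, then ds_b's
print(list(merged))  # ['gt_a_0', 'gt_a_1', 'gt_b_0', 'gt_b_1', 'gt_b_2']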


@@ -1,7 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+from itertools import chain
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
+from .cityscapes import CityscapesDataset


@DATASETS.register_module()
@@ -9,16 +16,148 @@ class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
-    concat the group flag for image aspect ratio.
+    support evaluation and formatting results

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
+        separate_eval (bool): Whether to evaluate the concatenated
+            dataset results separately, Defaults to True.
    """

-    def __init__(self, datasets):
+    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False,' \
f'but get {separate_eval}'
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating ConcatDataset containing CityscapesDataset'
'is not supported!')
def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
            results (list[tuple[torch.Tensor]] | list[str]): per image
pre_eval results or predict segmentation map for
computing evaluation metric.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
            dict[str, float]: evaluation results of each separate dataset
                if `self.separate_eval=True`, otherwise results of the
                concatenated dataset as a whole.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'
if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]
results_per_dataset = results[start_idx:end_idx]
print_log(
                    f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have same types when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results
def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.
Args:
indice (int): indice of sample in ConcatDataset
Returns:
            int: the index of the sub-dataset the sample belongs to
            int: the index of the sample in its corresponding sub-dataset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])
def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])
@DATASETS.register_module()
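Two details of the wrapper above are easy to miss. First, with `separate_eval=True` the per-dataset metrics are re-keyed with the sub-dataset index, so two sub-datasets produce keys like `0_mIoU` and `1_mIoU`, while `separate_eval=False` runs one merged evaluation and returns the usual keys. A standalone sketch of the re-keying step (metric values are made up):

per_dataset = [{'mIoU': 0.5, 'aAcc': 0.9}, {'mIoU': 0.6, 'aAcc': 0.8}]
total_eval_results = {}
for dataset_idx, res in enumerate(per_dataset):
    total_eval_results.update({f'{dataset_idx}_{k}': v for k, v in res.items()})
print(total_eval_results)
# {'0_mIoU': 0.5, '0_aAcc': 0.9, '1_mIoU': 0.6, '1_aAcc': 0.8}

Second, `get_dataset_idx_and_sample_idx` maps a global index to a (sub-dataset, local sample) pair with `bisect` over `cumulative_sizes`; the same arithmetic in isolation (sizes are made up):

import bisect

cumulative_sizes = [5, 10]   # e.g. two sub-datasets with 5 samples each
indice = 7
dataset_idx = bisect.bisect_right(cumulative_sizes, indice)
sample_idx = indice if dataset_idx == 0 else indice - cumulative_sizes[dataset_idx - 1]
print(dataset_idx, sample_idx)  # 1 2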


@@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
+import torch
from PIL import Image

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                            ConcatDataset, CustomDataset, PascalVOCDataset,
-                            RepeatDataset)
+                            RepeatDataset, build_dataset)


def test_classes():
@@ -143,7 +144,8 @@ def test_custom_dataset():
        test_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        img_suffix='img.jpg',
-        test_mode=True)
+        test_mode=True,
+        classes=('pseudo_class', ))
    assert len(test_dataset) == 5

    # training data get
@@ -164,30 +166,21 @@ def test_custom_dataset():
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')

-    # test past evaluation
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
-    assert isinstance(eval_results, dict)
-    assert 'mIoU' in eval_results
-    assert 'mAcc' in eval_results
-    assert 'aAcc' in eval_results
-
-    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
-    assert isinstance(eval_results, dict)
-    assert 'mDice' in eval_results
-    assert 'mAcc' in eval_results
-    assert 'aAcc' in eval_results
-
-    eval_results = train_dataset.evaluate(
-        pseudo_results, metric=['mDice', 'mIoU'])
-    assert isinstance(eval_results, dict)
-    assert 'mIoU' in eval_results
-    assert 'mDice' in eval_results
-    assert 'mAcc' in eval_results
-    assert 'aAcc' in eval_results
+
+    # test past evaluation without CLASSES
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
+
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+
+    with pytest.raises(TypeError):
+        eval_results = train_dataset.evaluate(
+            pseudo_results, metric=['mDice', 'mIoU'])

    # test past evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
@@ -221,6 +214,14 @@ def test_custom_dataset():
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

+    assert not np.isnan(eval_results['mIoU'])
+    assert not np.isnan(eval_results['mDice'])
+    assert not np.isnan(eval_results['mAcc'])
+    assert not np.isnan(eval_results['aAcc'])
+    assert not np.isnan(eval_results['mFscore'])
+    assert not np.isnan(eval_results['mPrecision'])
+    assert not np.isnan(eval_results['mRecall'])

    # test evaluation with pre-eval and the dataset.CLASSES is necessary
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
@@ -258,6 +259,223 @@ def test_custom_dataset():
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

+    assert not np.isnan(eval_results['mIoU'])
+    assert not np.isnan(eval_results['mDice'])
+    assert not np.isnan(eval_results['mAcc'])
+    assert not np.isnan(eval_results['aAcc'])
+    assert not np.isnan(eval_results['mFscore'])
+    assert not np.isnan(eval_results['mPrecision'])
+    assert not np.isnan(eval_results['mRecall'])
@pytest.mark.parametrize('separate_eval', [True, False])
def test_eval_concat_custom_dataset(separate_eval):
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(128, 256),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
img_dir = 'imgs/'
ann_dir = 'gts/'
cfg1 = dict(
type='CustomDataset',
pipeline=test_pipeline,
data_root=data_root,
img_dir=img_dir,
ann_dir=ann_dir,
img_suffix='img.jpg',
seg_map_suffix='gt.png',
classes=tuple(['a'] * 7))
dataset1 = build_dataset(cfg1)
assert len(dataset1) == 5
# get gt seg map
gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
assert isinstance(gt_seg_maps, Generator)
gt_seg_maps = list(gt_seg_maps)
assert len(gt_seg_maps) == 5
# test past evaluation
pseudo_results = []
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results1 = dataset1.evaluate(
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
# We use same dir twice for simplicity
# with ann_dir
cfg2 = dict(
type='CustomDataset',
pipeline=test_pipeline,
data_root=data_root,
img_dir=[img_dir, img_dir],
ann_dir=[ann_dir, ann_dir],
img_suffix='img.jpg',
seg_map_suffix='gt.png',
classes=tuple(['a'] * 7),
separate_eval=separate_eval)
dataset2 = build_dataset(cfg2)
assert isinstance(dataset2, ConcatDataset)
assert len(dataset2) == 10
eval_results2 = dataset2.evaluate(
pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
    # test get dataset_idx and sample_idx from ConcatDataset
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
assert dataset_idx == 0
assert sample_idx == 3
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
assert dataset_idx == 1
assert sample_idx == 2
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
assert dataset_idx == 0
assert sample_idx == 3
# test negative indice exceed length of dataset
with pytest.raises(ValueError):
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)
# test negative indice value
indice = -6
dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice)
dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
len(dataset2) + indice)
assert dataset_idx1 == dataset_idx2
assert sample_idx1 == sample_idx2
# test evaluation with pre-eval and the dataset.CLASSES is necessary
pseudo_results = []
eval_results1 = []
for idx in range(len(dataset1)):
h, w = gt_seg_maps[idx].shape
pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
pseudo_results.append(pseudo_result)
eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))
assert len(eval_results1) == len(dataset1)
assert isinstance(eval_results1[0], tuple)
assert len(eval_results1[0]) == 4
assert isinstance(eval_results1[0][0], torch.Tensor)
eval_results1 = dataset1.evaluate(
eval_results1, metric=['mIoU', 'mDice', 'mFscore'])
pseudo_results = pseudo_results * 2
eval_results2 = []
for idx in range(len(dataset2)):
eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))
assert len(eval_results2) == len(dataset2)
assert isinstance(eval_results2[0], tuple)
assert len(eval_results2[0]) == 4
assert isinstance(eval_results2[0][0], torch.Tensor)
eval_results2 = dataset2.evaluate(
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
# test batch_indices for pre eval
eval_results2 = dataset2.pre_eval(pseudo_results,
list(range(len(pseudo_results))))
assert len(eval_results2) == len(dataset2)
assert isinstance(eval_results2[0], tuple)
assert len(eval_results2[0]) == 4
assert isinstance(eval_results2[0][0], torch.Tensor)
eval_results2 = dataset2.evaluate(
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
def test_ade():
    test_dataset = ADE20KDataset(
@@ -279,6 +497,44 @@ def test_ade():
    shutil.rmtree('.format_ade')
@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_ade(separate_eval):
test_dataset = ADE20KDataset(
pipeline=[],
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
assert len(test_dataset) == 5
concat_dataset = ConcatDataset([test_dataset, test_dataset],
separate_eval=separate_eval)
assert len(concat_dataset) == 10
# Test format_results
pseudo_results = []
for _ in range(len(concat_dataset)):
h, w = (2, 2)
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
# test format per image
file_paths = []
for i in range(len(pseudo_results)):
file_paths.extend(
concat_dataset.format_results([pseudo_results[i]],
'.format_ade',
indices=[i]))
assert len(file_paths) == len(concat_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp, pseudo_results[0] + 1)
shutil.rmtree('.format_ade')
# test default argument
file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
assert len(file_paths) == len(concat_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp, pseudo_results[0] + 1)
shutil.rmtree('.format_ade')
def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
@@ -311,6 +567,28 @@ def test_cityscapes():
    shutil.rmtree('.format_city')
@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_cityscapes(separate_eval):
cityscape_dataset = CityscapesDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
assert len(cityscape_dataset) == 1
with pytest.raises(NotImplementedError):
_ = ConcatDataset([cityscape_dataset, cityscape_dataset],
separate_eval=separate_eval)
ade_dataset = ADE20KDataset(
pipeline=[],
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
assert len(ade_dataset) == 5
with pytest.raises(NotImplementedError):
_ = ConcatDataset([cityscape_dataset, ade_dataset],
separate_eval=separate_eval)
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@@ -360,14 +638,23 @@ def test_custom_classes_override_default(dataset, classes):
    assert custom_dataset.CLASSES == [classes[0]]

    # Test default behavior
-    custom_dataset = dataset_class(
-        pipeline=[],
-        img_dir=MagicMock(),
-        split=MagicMock(),
-        classes=None,
-        test_mode=True)
+    if dataset_class is CustomDataset:
+        with pytest.raises(AssertionError):
+            custom_dataset = dataset_class(
+                pipeline=[],
+                img_dir=MagicMock(),
+                split=MagicMock(),
+                classes=None,
+                test_mode=True)
+    else:
+        custom_dataset = dataset_class(
+            pipeline=[],
+            img_dir=MagicMock(),
+            split=MagicMock(),
+            classes=None,
+            test_mode=True)

    assert custom_dataset.CLASSES == original_classes


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)


@@ -78,7 +78,8 @@ def test_build_dataset():
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
-        test_mode=True)
+        test_mode=True,
+        classes=('pseudo_class', ))
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 10
@@ -90,7 +91,8 @@ def test_build_dataset():
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        split=['splits/val.txt', 'splits/val.txt'],
-        test_mode=True)
+        test_mode=True,
+        classes=('pseudo_class', ))
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 2

tools/browse_dataset.py (new file, 167 lines)

@@ -0,0 +1,167 @@
import argparse
import os
import warnings
from pathlib import Path
import mmcv
import numpy as np
from mmcv import Config
from mmseg.datasets.builder import build_dataset
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
        help='if True, omit all augmentations in the pipeline and '
        'show the original image and seg map')
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipelines; if `show-origin` is true, '
'all pipeline except `Load` will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
type=str,
help='If there is no display interface, you can save it')
parser.add_argument('--show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map')
args = parser.parse_args()
return args
def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each classes.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep pipeline of Loading data and ann
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
cfg = Config.fromfile(config_path)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
show_origin)
else:
raise ValueError
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg
def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(
item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
progress_bar.update()
if __name__ == '__main__':
main()
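A typical invocation of the new tool might look like the following (the config path is only a placeholder; any existing training config should work):

python tools/browse_dataset.py configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py \
    --show-origin --output-dir ./output --opacity 0.5

With `--show-origin` only the `Load*` transforms are kept, so the saved images show the raw image/annotation pair; drop the flag to inspect the data after training-time augmentations, minus whatever is listed in `--skip-type`.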


@@ -215,7 +215,8 @@ def main():
            print(f'\nwriting results to {args.out}')
            mmcv.dump(results, args.out)
        if args.eval:
-            metric = dataset.evaluate(results, args.eval, **eval_kwargs)
+            eval_kwargs.update(metric=args.eval)
+            metric = dataset.evaluate(results, **eval_kwargs)
            metric_dict = dict(config=args.config, metric=metric)
            if args.work_dir is not None and rank == 0:
                mmcv.dump(metric_dict, json_file, indent=4)