[Feature] Support evaluating concat dataset and add tool to show dataset (#833)

* [Feature] Add tool to show original or augmented training data

* [Feature] Support evaluating concat dataset

* Add docstrings and modify evaluate of concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Format concat dataset results in subfolders of imgfile_prefix

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Add unit tests for concat dataset

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Update unit test for evaluating a dataset whose CLASSES is None

Signed-off-by: FreyWang <wangwxyz@qq.com>

* [Fix] Bug of generator exhaustion, which led metrics to NaN when pre_eval=False

Signed-off-by: FreyWang <wangwxyz@qq.com>

* format code

Signed-off-by: FreyWang <wangwxyz@qq.com>

* Add more unit tests

* Add more unit tests

* Optimize concat dataset builder
FreyWang 2021-09-09 13:00:23 +08:00 committed by GitHub
parent eb0baee414
commit 872e54497e
8 changed files with 645 additions and 45 deletions


@ -112,8 +112,6 @@ def total_intersect_and_union(results,
ndarray: The prediction histogram on all classes.
ndarray: The ground truth histogram on all classes.
"""
num_imgs = len(results)
assert len(list(gt_seg_maps)) == num_imgs
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
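The two removed lines above called `len(list(gt_seg_maps))`, which exhausts the generator now returned by `get_gt_seg_maps()` before the areas are accumulated, so every histogram stays at zero and the derived metrics become NaN (the bug named in the commit message). A minimal, self-contained sketch (not mmseg code) of that failure mode:

def fake_gt_seg_maps():
    for value in range(3):
        yield value  # stands in for one ground-truth segmentation map

gt_seg_maps = fake_gt_seg_maps()
assert len(list(gt_seg_maps)) == 3  # the old sanity check consumes the generator here
print(list(gt_seg_maps))  # [] -- nothing left for the accumulation loop, metrics turn NaN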


@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
# pop 'separate_eval' since it is not a valid key for common datasets.
separate_eval = cfg.pop('separate_eval', True)
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))
return ConcatDataset(datasets)
return ConcatDataset(datasets, separate_eval)
def build_dataset(cfg, default_args=None):
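With this change, `_concat_dataset` pops `separate_eval` from the config (it is not a valid key for the individual datasets) and forwards it to the `ConcatDataset` wrapper. A minimal config sketch, with hypothetical paths, of how a list-valued `img_dir`/`ann_dir` plus `separate_eval` reaches this code path:

from mmseg.datasets import build_dataset

# hypothetical data_root and directories; the list-valued img_dir/ann_dir
# routes the config through _concat_dataset, and separate_eval is popped
# there and forwarded to ConcatDataset
concat_dataset = build_dataset(
    dict(
        type='CustomDataset',
        pipeline=[],
        data_root='data/my_dataset',
        img_dir=['images/part1', 'images/part2'],
        ann_dir=['annotations/part1', 'annotations/part2'],
        separate_eval=False))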


@ -2,7 +2,6 @@
import os.path as osp
import warnings
from collections import OrderedDict
from functools import reduce
import mmcv
import numpy as np
@ -99,6 +98,9 @@ class CustomDataset(Dataset):
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'
# join paths if data_root is specified
if self.data_root is not None:
@ -339,7 +341,12 @@ class CustomDataset(Dataset):
return palette
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
**kwargs):
"""Evaluate the dataset.
Args:
@ -350,6 +357,8 @@ class CustomDataset(Dataset):
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset. Default: None.
Returns:
dict[str, float]: Default metrics.
@ -364,14 +373,9 @@ class CustomDataset(Dataset):
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
# reset generator
gt_seg_maps = self.get_gt_seg_maps()
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps()
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
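The reworked `evaluate` now requires `self.CLASSES` and accepts an externally supplied `gt_seg_maps` generator, which is what `ConcatDataset` relies on when `separate_eval=False`. A minimal sketch, assuming `dataset_a` and `dataset_b` are already-built sub-datasets with identical CLASSES and `results` lists the predicted segmentation maps for both, in order:

from itertools import chain

# chain the ground-truth generators of both sub-datasets and let the first
# dataset compute joint metrics over all predictions
joint_gt_seg_maps = chain(dataset_a.get_gt_seg_maps(), dataset_b.get_gt_seg_maps())
eval_results = dataset_a.evaluate(
    results, metric='mIoU', gt_seg_maps=joint_gt_seg_maps)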


@ -1,7 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
from itertools import chain
import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS
from .cityscapes import CityscapesDataset
@DATASETS.register_module()
@ -9,16 +16,148 @@ class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio. It also supports
evaluation and formatting of results.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the concatenated
dataset results separately. Defaults to True.
"""
def __init__(self, datasets):
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False, ' \
f'but got {separate_eval}'
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating ConcatDataset containing CityscapesDataset '
'is not supported!')
def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[tuple[torch.Tensor]] | list[str]): per-image
pre_eval results or predicted segmentation maps used to
compute the evaluation metric.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: evaluation results of the whole dataset, or of
each separate dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'
if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]
results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have the same types when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results
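In short, `separate_eval=True` reports each sub-dataset under keys prefixed with its index, while `separate_eval=False` merges the ground truth and reports one joint set of metrics. A usage sketch, assuming `concat_dataset` wraps two sub-datasets and `results` holds one prediction per concatenated image:

eval_results = concat_dataset.evaluate(results, metric='mIoU')
if concat_dataset.separate_eval:
    # per-dataset keys: '0_mIoU' for the first sub-dataset, '1_mIoU' for the second
    print(eval_results['0_mIoU'], eval_results['1_mIoU'])
else:
    # a single joint entry computed over the merged ground truth
    print(eval_results['mIoU'])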
def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.
Args:
indice (int): indice of sample in ConcatDataset
Returns:
int: the index of sub dataset the sample belong to
int: the index of sample in its corresponding subset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])
def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])
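Both helpers accept global indices of the concatenated dataset and resolve them to (dataset_idx, sample_idx) internally; `format_results` additionally writes each sub-dataset's outputs under `imgfile_prefix/<dataset_idx>`. A usage sketch, assuming `concat_dataset` wraps sub-datasets that implement `format_results` (e.g. ADE20KDataset) and `preds` is a list of per-image predictions:

pre_eval_results = []
for global_idx, pred in enumerate(preds):
    # global_idx is mapped to the owning sub-dataset and its local sample index
    pre_eval_results.extend(concat_dataset.pre_eval(pred, global_idx))
metrics = concat_dataset.evaluate(pre_eval_results, metric='mIoU')

# files for the first sub-dataset land in '.results/0', the second in '.results/1'
saved_files = concat_dataset.format_results(preds, '.results')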
@DATASETS.register_module()


@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from PIL import Image
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)
RepeatDataset, build_dataset)
def test_classes():
@ -143,7 +144,8 @@ def test_custom_dataset():
test_pipeline,
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
img_suffix='img.jpg',
test_mode=True)
test_mode=True,
classes=('pseudo_class', ))
assert len(test_dataset) == 5
# training data get
@ -164,30 +166,21 @@ def test_custom_dataset():
with pytest.raises(NotImplementedError):
test_dataset.format_results([], '')
# test past evaluation
pseudo_results = []
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
assert isinstance(eval_results, dict)
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
# test past evaluation without CLASSES
with pytest.raises(TypeError):
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
eval_results = train_dataset.evaluate(
pseudo_results, metric=['mDice', 'mIoU'])
assert isinstance(eval_results, dict)
assert 'mIoU' in eval_results
assert 'mDice' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
with pytest.raises(TypeError):
eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
with pytest.raises(TypeError):
eval_results = train_dataset.evaluate(
pseudo_results, metric=['mDice', 'mIoU'])
# test past evaluation with CLASSES
train_dataset.CLASSES = tuple(['a'] * 7)
@ -221,6 +214,14 @@ def test_custom_dataset():
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results
assert not np.isnan(eval_results['mIoU'])
assert not np.isnan(eval_results['mDice'])
assert not np.isnan(eval_results['mAcc'])
assert not np.isnan(eval_results['aAcc'])
assert not np.isnan(eval_results['mFscore'])
assert not np.isnan(eval_results['mPrecision'])
assert not np.isnan(eval_results['mRecall'])
# test evaluation with pre-eval and the dataset.CLASSES is necessary
train_dataset.CLASSES = tuple(['a'] * 7)
pseudo_results = []
@ -258,6 +259,223 @@ def test_custom_dataset():
assert 'mPrecision' in eval_results
assert 'mRecall' in eval_results
assert not np.isnan(eval_results['mIoU'])
assert not np.isnan(eval_results['mDice'])
assert not np.isnan(eval_results['mAcc'])
assert not np.isnan(eval_results['aAcc'])
assert not np.isnan(eval_results['mFscore'])
assert not np.isnan(eval_results['mPrecision'])
assert not np.isnan(eval_results['mRecall'])
@pytest.mark.parametrize('separate_eval', [True, False])
def test_eval_concat_custom_dataset(separate_eval):
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(128, 256),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
img_dir = 'imgs/'
ann_dir = 'gts/'
cfg1 = dict(
type='CustomDataset',
pipeline=test_pipeline,
data_root=data_root,
img_dir=img_dir,
ann_dir=ann_dir,
img_suffix='img.jpg',
seg_map_suffix='gt.png',
classes=tuple(['a'] * 7))
dataset1 = build_dataset(cfg1)
assert len(dataset1) == 5
# get gt seg map
gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
assert isinstance(gt_seg_maps, Generator)
gt_seg_maps = list(gt_seg_maps)
assert len(gt_seg_maps) == 5
# test past evaluation
pseudo_results = []
for gt_seg_map in gt_seg_maps:
h, w = gt_seg_map.shape
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
eval_results1 = dataset1.evaluate(
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
# We use same dir twice for simplicity
# with ann_dir
cfg2 = dict(
type='CustomDataset',
pipeline=test_pipeline,
data_root=data_root,
img_dir=[img_dir, img_dir],
ann_dir=[ann_dir, ann_dir],
img_suffix='img.jpg',
seg_map_suffix='gt.png',
classes=tuple(['a'] * 7),
separate_eval=separate_eval)
dataset2 = build_dataset(cfg2)
assert isinstance(dataset2, ConcatDataset)
assert len(dataset2) == 10
eval_results2 = dataset2.evaluate(
pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
# test get dataset_idx and sample_idx from ConcatDataset
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
assert dataset_idx == 0
assert sample_idx == 3
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
assert dataset_idx == 1
assert sample_idx == 2
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
assert dataset_idx == 0
assert sample_idx == 3
# test negative index exceeding the length of the dataset
with pytest.raises(ValueError):
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)
# test negative index value
indice = -6
dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice)
dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
len(dataset2) + indice)
assert dataset_idx1 == dataset_idx2
assert sample_idx1 == sample_idx2
# test evaluation with pre-eval and the dataset.CLASSES is necessary
pseudo_results = []
eval_results1 = []
for idx in range(len(dataset1)):
h, w = gt_seg_maps[idx].shape
pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
pseudo_results.append(pseudo_result)
eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))
assert len(eval_results1) == len(dataset1)
assert isinstance(eval_results1[0], tuple)
assert len(eval_results1[0]) == 4
assert isinstance(eval_results1[0][0], torch.Tensor)
eval_results1 = dataset1.evaluate(
eval_results1, metric=['mIoU', 'mDice', 'mFscore'])
pseudo_results = pseudo_results * 2
eval_results2 = []
for idx in range(len(dataset2)):
eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))
assert len(eval_results2) == len(dataset2)
assert isinstance(eval_results2[0], tuple)
assert len(eval_results2[0]) == 4
assert isinstance(eval_results2[0][0], torch.Tensor)
eval_results2 = dataset2.evaluate(
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
# test batch_indices for pre eval
eval_results2 = dataset2.pre_eval(pseudo_results,
list(range(len(pseudo_results))))
assert len(eval_results2) == len(dataset2)
assert isinstance(eval_results2[0], tuple)
assert len(eval_results2[0]) == 4
assert isinstance(eval_results2[0][0], torch.Tensor)
eval_results2 = dataset2.evaluate(
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
if separate_eval:
assert eval_results1['mIoU'] == eval_results2[
'0_mIoU'] == eval_results2['1_mIoU']
assert eval_results1['mDice'] == eval_results2[
'0_mDice'] == eval_results2['1_mDice']
assert eval_results1['mAcc'] == eval_results2[
'0_mAcc'] == eval_results2['1_mAcc']
assert eval_results1['aAcc'] == eval_results2[
'0_aAcc'] == eval_results2['1_aAcc']
assert eval_results1['mFscore'] == eval_results2[
'0_mFscore'] == eval_results2['1_mFscore']
assert eval_results1['mPrecision'] == eval_results2[
'0_mPrecision'] == eval_results2['1_mPrecision']
assert eval_results1['mRecall'] == eval_results2[
'0_mRecall'] == eval_results2['1_mRecall']
else:
assert eval_results1['mIoU'] == eval_results2['mIoU']
assert eval_results1['mDice'] == eval_results2['mDice']
assert eval_results1['mAcc'] == eval_results2['mAcc']
assert eval_results1['aAcc'] == eval_results2['aAcc']
assert eval_results1['mFscore'] == eval_results2['mFscore']
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
assert eval_results1['mRecall'] == eval_results2['mRecall']
def test_ade():
test_dataset = ADE20KDataset(
@ -279,6 +497,44 @@ def test_ade():
shutil.rmtree('.format_ade')
@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_ade(separate_eval):
test_dataset = ADE20KDataset(
pipeline=[],
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
assert len(test_dataset) == 5
concat_dataset = ConcatDataset([test_dataset, test_dataset],
separate_eval=separate_eval)
assert len(concat_dataset) == 10
# Test format_results
pseudo_results = []
for _ in range(len(concat_dataset)):
h, w = (2, 2)
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
# test format per image
file_paths = []
for i in range(len(pseudo_results)):
file_paths.extend(
concat_dataset.format_results([pseudo_results[i]],
'.format_ade',
indices=[i]))
assert len(file_paths) == len(concat_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp, pseudo_results[0] + 1)
shutil.rmtree('.format_ade')
# test default argument
file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
assert len(file_paths) == len(concat_dataset)
temp = np.array(Image.open(file_paths[0]))
assert np.allclose(temp, pseudo_results[0] + 1)
shutil.rmtree('.format_ade')
def test_cityscapes():
test_dataset = CityscapesDataset(
pipeline=[],
@ -311,6 +567,28 @@ def test_cityscapes():
shutil.rmtree('.format_city')
@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_cityscapes(separate_eval):
cityscape_dataset = CityscapesDataset(
pipeline=[],
img_dir=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
ann_dir=osp.join(
osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
assert len(cityscape_dataset) == 1
with pytest.raises(NotImplementedError):
_ = ConcatDataset([cityscape_dataset, cityscape_dataset],
separate_eval=separate_eval)
ade_dataset = ADE20KDataset(
pipeline=[],
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
assert len(ade_dataset) == 5
with pytest.raises(NotImplementedError):
_ = ConcatDataset([cityscape_dataset, ade_dataset],
separate_eval=separate_eval)
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))
@ -360,14 +638,23 @@ def test_custom_classes_override_default(dataset, classes):
assert custom_dataset.CLASSES == [classes[0]]
# Test default behavior
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=None,
test_mode=True)
if dataset_class is CustomDataset:
with pytest.raises(AssertionError):
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=None,
test_mode=True)
else:
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=None,
test_mode=True)
assert custom_dataset.CLASSES == original_classes
assert custom_dataset.CLASSES == original_classes
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)


@ -78,7 +78,8 @@ def test_build_dataset():
pipeline=[],
data_root=data_root,
img_dir=[img_dir, img_dir],
test_mode=True)
test_mode=True,
classes=('pseudo_class', ))
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 10
@ -90,7 +91,8 @@ def test_build_dataset():
data_root=data_root,
img_dir=[img_dir, img_dir],
split=['splits/val.txt', 'splits/val.txt'],
test_mode=True)
test_mode=True,
classes=('pseudo_class', ))
dataset = build_dataset(cfg)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 2


@ -0,0 +1,167 @@
import argparse
import os
import warnings
from pathlib import Path
import mmcv
import numpy as np
from mmcv import Config
from mmseg.datasets.builder import build_dataset
def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
help='if True, omit all augmentations in the pipeline and '
'show the original image and seg map')
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
help='skip some useless pipelines; if `show-origin` is true, '
'all pipelines except `Load` will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
type=str,
help='if there is no display interface, you can save the results to this directory')
parser.add_argument('--show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map')
args = parser.parse_args()
return args
def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each class.
palette (list[list[int]] | np.ndarray | None): The palette of the
segmentation map. If None is given, a random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (ndarray): Only returned when `show` is False and `out_file` is None.
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep the pipelines that load data and annotations
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
cfg = Config.fromfile(config_path)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
show_origin)
else:
raise ValueError
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg
def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(
item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
progress_bar.update()
if __name__ == '__main__':
main()
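The new browsing tool can also be driven directly from Python. A minimal sketch, assuming the script is saved as tools/browse_dataset.py (the file path is not shown in this view) and using a hypothetical config path:

from mmseg.datasets import build_dataset
from tools.browse_dataset import imshow_semantic, retrieve_data_cfg  # assumed module path

cfg = retrieve_data_cfg(
    'configs/_base_/datasets/ade20k.py',  # hypothetical config file
    skip_type=['DefaultFormatBundle', 'Normalize', 'Collect'],
    show_origin=True)  # keep only the Load* transforms, i.e. show the original data
dataset = build_dataset(cfg.data.train)
item = dataset[0]
imshow_semantic(
    item['img'],
    item['gt_semantic_seg'],
    dataset.CLASSES,
    dataset.PALETTE,
    out_file='./output/sample_0.png')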


@ -215,7 +215,8 @@ def main():
print(f'\nwriting results to {args.out}')
mmcv.dump(results, args.out)
if args.eval:
metric = dataset.evaluate(results, args.eval, **eval_kwargs)
eval_kwargs.update(metric=args.eval)
metric = dataset.evaluate(results, **eval_kwargs)
metric_dict = dict(config=args.config, metric=metric)
if args.work_dir is not None and rank == 0:
mmcv.dump(metric_dict, json_file, indent=4)
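Since `ConcatDataset.evaluate` takes `(results, logger=None, **kwargs)`, passing the metric positionally would bind it to `logger`; routing it through `eval_kwargs` keeps the call keyword-only, so the same line works whether `dataset` is a plain `CustomDataset` or the new `ConcatDataset` wrapper, which forwards extra keyword arguments unchanged. A minimal sketch with hypothetical values:

eval_kwargs = dict()                 # whatever remains of cfg.evaluation (hypothetical here)
eval_kwargs.update(metric=['mIoU'])  # metric travels as a keyword argument
# identical call for CustomDataset and ConcatDataset
metric = dataset.evaluate(results, **eval_kwargs)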