[Feature] Support eval concate dataset and add tool to show dataset (#833)
* [Feature] Add tool to show origin or augmented train data * [Feature] Support eval concate dataset * Add docstring and modify evaluate of concate dataset Signed-off-by: FreyWang <wangwxyz@qq.com> * format concat dataset in subfolder of imgfile_prefix Signed-off-by: FreyWang <wangwxyz@qq.com> * add unittest of concate dataset Signed-off-by: FreyWang <wangwxyz@qq.com> * update unittest for eval dataset with CLASSES is None Signed-off-by: FreyWang <wangwxyz@qq.com> * [FIX] bug of generator, which lead metric to nan when pre_eval=False Signed-off-by: FreyWang <wangwxyz@qq.com> * format code Signed-off-by: FreyWang <wangwxyz@qq.com> * add more unittest * add more unittest * optim concat dataset builderpull/866/head
parent
eb0baee414
commit
872e54497e
|
@ -112,8 +112,6 @@ def total_intersect_and_union(results,
|
|||
ndarray: The prediction histogram on all classes.
|
||||
ndarray: The ground truth histogram on all classes.
|
||||
"""
|
||||
num_imgs = len(results)
|
||||
assert len(list(gt_seg_maps)) == num_imgs
|
||||
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
|
||||
|
|
|
@ -30,6 +30,8 @@ def _concat_dataset(cfg, default_args=None):
|
|||
img_dir = cfg['img_dir']
|
||||
ann_dir = cfg.get('ann_dir', None)
|
||||
split = cfg.get('split', None)
|
||||
# pop 'separate_eval' since it is not a valid key for common datasets.
|
||||
separate_eval = cfg.pop('separate_eval', True)
|
||||
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
|
||||
if ann_dir is not None:
|
||||
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
|
||||
|
@ -57,7 +59,7 @@ def _concat_dataset(cfg, default_args=None):
|
|||
data_cfg['split'] = split[i]
|
||||
datasets.append(build_dataset(data_cfg, default_args))
|
||||
|
||||
return ConcatDataset(datasets)
|
||||
return ConcatDataset(datasets, separate_eval)
|
||||
|
||||
|
||||
def build_dataset(cfg, default_args=None):
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -99,6 +98,9 @@ class CustomDataset(Dataset):
|
|||
self.label_map = None
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||
classes, palette)
|
||||
if test_mode:
|
||||
assert self.CLASSES is not None, \
|
||||
'`cls.CLASSES` or `classes` should be specified when testing'
|
||||
|
||||
# join paths if data_root is specified
|
||||
if self.data_root is not None:
|
||||
|
@ -339,7 +341,12 @@ class CustomDataset(Dataset):
|
|||
|
||||
return palette
|
||||
|
||||
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='mIoU',
|
||||
logger=None,
|
||||
gt_seg_maps=None,
|
||||
**kwargs):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -350,6 +357,8 @@ class CustomDataset(Dataset):
|
|||
'mDice' and 'mFscore' are supported.
|
||||
logger (logging.Logger | None | str): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
|
||||
used in ConcatDataset
|
||||
|
||||
Returns:
|
||||
dict[str, float]: Default metrics.
|
||||
|
@ -364,14 +373,9 @@ class CustomDataset(Dataset):
|
|||
# test a list of files
|
||||
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
|
||||
results, str):
|
||||
gt_seg_maps = self.get_gt_seg_maps()
|
||||
if self.CLASSES is None:
|
||||
num_classes = len(
|
||||
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
|
||||
else:
|
||||
num_classes = len(self.CLASSES)
|
||||
# reset generator
|
||||
gt_seg_maps = self.get_gt_seg_maps()
|
||||
if gt_seg_maps is None:
|
||||
gt_seg_maps = self.get_gt_seg_maps()
|
||||
num_classes = len(self.CLASSES)
|
||||
ret_metrics = eval_metrics(
|
||||
results,
|
||||
gt_seg_maps,
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import bisect
|
||||
from itertools import chain
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from .builder import DATASETS
|
||||
from .cityscapes import CityscapesDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -9,16 +16,148 @@ class ConcatDataset(_ConcatDataset):
|
|||
"""A wrapper of concatenated dataset.
|
||||
|
||||
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
|
||||
concat the group flag for image aspect ratio.
|
||||
support evaluation and formatting results
|
||||
|
||||
Args:
|
||||
datasets (list[:obj:`Dataset`]): A list of datasets.
|
||||
separate_eval (bool): Whether to evaluate the concatenated
|
||||
dataset results separately, Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets):
|
||||
def __init__(self, datasets, separate_eval=True):
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
self.CLASSES = datasets[0].CLASSES
|
||||
self.PALETTE = datasets[0].PALETTE
|
||||
self.separate_eval = separate_eval
|
||||
assert separate_eval in [True, False], \
|
||||
f'separate_eval can only be True or False,' \
|
||||
f'but get {separate_eval}'
|
||||
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
|
||||
raise NotImplementedError(
|
||||
'Evaluating ConcatDataset containing CityscapesDataset'
|
||||
'is not supported!')
|
||||
|
||||
def evaluate(self, results, logger=None, **kwargs):
|
||||
"""Evaluate the results.
|
||||
|
||||
Args:
|
||||
results (list[tuple[torch.Tensor]] | list[str]]): per image
|
||||
pre_eval results or predict segmentation map for
|
||||
computing evaluation metric.
|
||||
logger (logging.Logger | str | None): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
|
||||
Returns:
|
||||
dict[str: float]: evaluate results of the total dataset
|
||||
or each separate
|
||||
dataset if `self.separate_eval=True`.
|
||||
"""
|
||||
assert len(results) == self.cumulative_sizes[-1], \
|
||||
('Dataset and results have different sizes: '
|
||||
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
|
||||
|
||||
# Check whether all the datasets support evaluation
|
||||
for dataset in self.datasets:
|
||||
assert hasattr(dataset, 'evaluate'), \
|
||||
f'{type(dataset)} does not implement evaluate function'
|
||||
|
||||
if self.separate_eval:
|
||||
dataset_idx = -1
|
||||
total_eval_results = dict()
|
||||
for size, dataset in zip(self.cumulative_sizes, self.datasets):
|
||||
start_idx = 0 if dataset_idx == -1 else \
|
||||
self.cumulative_sizes[dataset_idx]
|
||||
end_idx = self.cumulative_sizes[dataset_idx + 1]
|
||||
|
||||
results_per_dataset = results[start_idx:end_idx]
|
||||
print_log(
|
||||
f'\nEvaluateing {dataset.img_dir} with '
|
||||
f'{len(results_per_dataset)} images now',
|
||||
logger=logger)
|
||||
|
||||
eval_results_per_dataset = dataset.evaluate(
|
||||
results_per_dataset, logger=logger, **kwargs)
|
||||
dataset_idx += 1
|
||||
for k, v in eval_results_per_dataset.items():
|
||||
total_eval_results.update({f'{dataset_idx}_{k}': v})
|
||||
|
||||
return total_eval_results
|
||||
|
||||
if len(set([type(ds) for ds in self.datasets])) != 1:
|
||||
raise NotImplementedError(
|
||||
'All the datasets should have same types when '
|
||||
'self.separate_eval=False')
|
||||
else:
|
||||
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
|
||||
results, str):
|
||||
# merge the generators of gt_seg_maps
|
||||
gt_seg_maps = chain(
|
||||
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
|
||||
else:
|
||||
# if the results are `pre_eval` results,
|
||||
# we do not need gt_seg_maps to evaluate
|
||||
gt_seg_maps = None
|
||||
eval_results = self.datasets[0].evaluate(
|
||||
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
|
||||
return eval_results
|
||||
|
||||
def get_dataset_idx_and_sample_idx(self, indice):
|
||||
"""Return dataset and sample index when given an indice of
|
||||
ConcatDataset.
|
||||
|
||||
Args:
|
||||
indice (int): indice of sample in ConcatDataset
|
||||
|
||||
Returns:
|
||||
int: the index of sub dataset the sample belong to
|
||||
int: the index of sample in its corresponding subset
|
||||
"""
|
||||
if indice < 0:
|
||||
if -indice > len(self):
|
||||
raise ValueError(
|
||||
'absolute value of index should not exceed dataset length')
|
||||
indice = len(self) + indice
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
|
||||
if dataset_idx == 0:
|
||||
sample_idx = indice
|
||||
else:
|
||||
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
|
||||
return dataset_idx, sample_idx
|
||||
|
||||
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
|
||||
"""format result for every sample of ConcatDataset."""
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
assert isinstance(results, list), 'results must be a list.'
|
||||
assert isinstance(indices, list), 'indices must be a list.'
|
||||
|
||||
ret_res = []
|
||||
for i, indice in enumerate(indices):
|
||||
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
|
||||
indice)
|
||||
res = self.datasets[dataset_idx].format_results(
|
||||
[results[i]],
|
||||
imgfile_prefix + f'/{dataset_idx}',
|
||||
indices=[sample_idx],
|
||||
**kwargs)
|
||||
ret_res.append(res)
|
||||
return sum(ret_res, [])
|
||||
|
||||
def pre_eval(self, preds, indices):
|
||||
"""do pre eval for every sample of ConcatDataset."""
|
||||
# In order to compat with batch inference
|
||||
if not isinstance(indices, list):
|
||||
indices = [indices]
|
||||
if not isinstance(preds, list):
|
||||
preds = [preds]
|
||||
ret_res = []
|
||||
for i, indice in enumerate(indices):
|
||||
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
|
||||
indice)
|
||||
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
|
||||
ret_res.append(res)
|
||||
return sum(ret_res, [])
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.core.evaluation import get_classes, get_palette
|
||||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
||||
ConcatDataset, CustomDataset, PascalVOCDataset,
|
||||
RepeatDataset)
|
||||
RepeatDataset, build_dataset)
|
||||
|
||||
|
||||
def test_classes():
|
||||
|
@ -143,7 +144,8 @@ def test_custom_dataset():
|
|||
test_pipeline,
|
||||
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
|
||||
img_suffix='img.jpg',
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
classes=('pseudo_class', ))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
# training data get
|
||||
|
@ -164,30 +166,21 @@ def test_custom_dataset():
|
|||
with pytest.raises(NotImplementedError):
|
||||
test_dataset.format_results([], '')
|
||||
|
||||
# test past evaluation
|
||||
pseudo_results = []
|
||||
for gt_seg_map in gt_seg_maps:
|
||||
h, w = gt_seg_map.shape
|
||||
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
|
||||
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
|
||||
assert isinstance(eval_results, dict)
|
||||
assert 'mIoU' in eval_results
|
||||
assert 'mAcc' in eval_results
|
||||
assert 'aAcc' in eval_results
|
||||
|
||||
eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
|
||||
assert isinstance(eval_results, dict)
|
||||
assert 'mDice' in eval_results
|
||||
assert 'mAcc' in eval_results
|
||||
assert 'aAcc' in eval_results
|
||||
# test past evaluation without CLASSES
|
||||
with pytest.raises(TypeError):
|
||||
eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
|
||||
|
||||
eval_results = train_dataset.evaluate(
|
||||
pseudo_results, metric=['mDice', 'mIoU'])
|
||||
assert isinstance(eval_results, dict)
|
||||
assert 'mIoU' in eval_results
|
||||
assert 'mDice' in eval_results
|
||||
assert 'mAcc' in eval_results
|
||||
assert 'aAcc' in eval_results
|
||||
with pytest.raises(TypeError):
|
||||
eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
eval_results = train_dataset.evaluate(
|
||||
pseudo_results, metric=['mDice', 'mIoU'])
|
||||
|
||||
# test past evaluation with CLASSES
|
||||
train_dataset.CLASSES = tuple(['a'] * 7)
|
||||
|
@ -221,6 +214,14 @@ def test_custom_dataset():
|
|||
assert 'mPrecision' in eval_results
|
||||
assert 'mRecall' in eval_results
|
||||
|
||||
assert not np.isnan(eval_results['mIoU'])
|
||||
assert not np.isnan(eval_results['mDice'])
|
||||
assert not np.isnan(eval_results['mAcc'])
|
||||
assert not np.isnan(eval_results['aAcc'])
|
||||
assert not np.isnan(eval_results['mFscore'])
|
||||
assert not np.isnan(eval_results['mPrecision'])
|
||||
assert not np.isnan(eval_results['mRecall'])
|
||||
|
||||
# test evaluation with pre-eval and the dataset.CLASSES is necessary
|
||||
train_dataset.CLASSES = tuple(['a'] * 7)
|
||||
pseudo_results = []
|
||||
|
@ -258,6 +259,223 @@ def test_custom_dataset():
|
|||
assert 'mPrecision' in eval_results
|
||||
assert 'mRecall' in eval_results
|
||||
|
||||
assert not np.isnan(eval_results['mIoU'])
|
||||
assert not np.isnan(eval_results['mDice'])
|
||||
assert not np.isnan(eval_results['mAcc'])
|
||||
assert not np.isnan(eval_results['aAcc'])
|
||||
assert not np.isnan(eval_results['mFscore'])
|
||||
assert not np.isnan(eval_results['mPrecision'])
|
||||
assert not np.isnan(eval_results['mRecall'])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('separate_eval', [True, False])
|
||||
def test_eval_concat_custom_dataset(separate_eval):
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(128, 256),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
|
||||
img_dir = 'imgs/'
|
||||
ann_dir = 'gts/'
|
||||
|
||||
cfg1 = dict(
|
||||
type='CustomDataset',
|
||||
pipeline=test_pipeline,
|
||||
data_root=data_root,
|
||||
img_dir=img_dir,
|
||||
ann_dir=ann_dir,
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png',
|
||||
classes=tuple(['a'] * 7))
|
||||
dataset1 = build_dataset(cfg1)
|
||||
assert len(dataset1) == 5
|
||||
# get gt seg map
|
||||
gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
|
||||
assert isinstance(gt_seg_maps, Generator)
|
||||
gt_seg_maps = list(gt_seg_maps)
|
||||
assert len(gt_seg_maps) == 5
|
||||
|
||||
# test past evaluation
|
||||
pseudo_results = []
|
||||
for gt_seg_map in gt_seg_maps:
|
||||
h, w = gt_seg_map.shape
|
||||
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
|
||||
eval_results1 = dataset1.evaluate(
|
||||
pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
|
||||
|
||||
# We use same dir twice for simplicity
|
||||
# with ann_dir
|
||||
cfg2 = dict(
|
||||
type='CustomDataset',
|
||||
pipeline=test_pipeline,
|
||||
data_root=data_root,
|
||||
img_dir=[img_dir, img_dir],
|
||||
ann_dir=[ann_dir, ann_dir],
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png',
|
||||
classes=tuple(['a'] * 7),
|
||||
separate_eval=separate_eval)
|
||||
dataset2 = build_dataset(cfg2)
|
||||
assert isinstance(dataset2, ConcatDataset)
|
||||
assert len(dataset2) == 10
|
||||
|
||||
eval_results2 = dataset2.evaluate(
|
||||
pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])
|
||||
|
||||
if separate_eval:
|
||||
assert eval_results1['mIoU'] == eval_results2[
|
||||
'0_mIoU'] == eval_results2['1_mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2[
|
||||
'0_mDice'] == eval_results2['1_mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2[
|
||||
'0_mAcc'] == eval_results2['1_mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2[
|
||||
'0_aAcc'] == eval_results2['1_aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2[
|
||||
'0_mFscore'] == eval_results2['1_mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2[
|
||||
'0_mPrecision'] == eval_results2['1_mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2[
|
||||
'0_mRecall'] == eval_results2['1_mRecall']
|
||||
else:
|
||||
assert eval_results1['mIoU'] == eval_results2['mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2['mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2['mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2['aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2['mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2['mRecall']
|
||||
|
||||
# test get dataset_idx and sample_idx from ConcateDataset
|
||||
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
|
||||
assert dataset_idx == 0
|
||||
assert sample_idx == 3
|
||||
|
||||
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
|
||||
assert dataset_idx == 1
|
||||
assert sample_idx == 2
|
||||
|
||||
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
|
||||
assert dataset_idx == 0
|
||||
assert sample_idx == 3
|
||||
|
||||
# test negative indice exceed length of dataset
|
||||
with pytest.raises(ValueError):
|
||||
dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)
|
||||
|
||||
# test negative indice value
|
||||
indice = -6
|
||||
dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice)
|
||||
dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
|
||||
len(dataset2) + indice)
|
||||
assert dataset_idx1 == dataset_idx2
|
||||
assert sample_idx1 == sample_idx2
|
||||
|
||||
# test evaluation with pre-eval and the dataset.CLASSES is necessary
|
||||
pseudo_results = []
|
||||
eval_results1 = []
|
||||
for idx in range(len(dataset1)):
|
||||
h, w = gt_seg_maps[idx].shape
|
||||
pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
|
||||
pseudo_results.append(pseudo_result)
|
||||
eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))
|
||||
|
||||
assert len(eval_results1) == len(dataset1)
|
||||
assert isinstance(eval_results1[0], tuple)
|
||||
assert len(eval_results1[0]) == 4
|
||||
assert isinstance(eval_results1[0][0], torch.Tensor)
|
||||
|
||||
eval_results1 = dataset1.evaluate(
|
||||
eval_results1, metric=['mIoU', 'mDice', 'mFscore'])
|
||||
|
||||
pseudo_results = pseudo_results * 2
|
||||
eval_results2 = []
|
||||
for idx in range(len(dataset2)):
|
||||
eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))
|
||||
|
||||
assert len(eval_results2) == len(dataset2)
|
||||
assert isinstance(eval_results2[0], tuple)
|
||||
assert len(eval_results2[0]) == 4
|
||||
assert isinstance(eval_results2[0][0], torch.Tensor)
|
||||
|
||||
eval_results2 = dataset2.evaluate(
|
||||
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
|
||||
|
||||
if separate_eval:
|
||||
assert eval_results1['mIoU'] == eval_results2[
|
||||
'0_mIoU'] == eval_results2['1_mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2[
|
||||
'0_mDice'] == eval_results2['1_mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2[
|
||||
'0_mAcc'] == eval_results2['1_mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2[
|
||||
'0_aAcc'] == eval_results2['1_aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2[
|
||||
'0_mFscore'] == eval_results2['1_mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2[
|
||||
'0_mPrecision'] == eval_results2['1_mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2[
|
||||
'0_mRecall'] == eval_results2['1_mRecall']
|
||||
else:
|
||||
assert eval_results1['mIoU'] == eval_results2['mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2['mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2['mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2['aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2['mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2['mRecall']
|
||||
|
||||
# test batch_indices for pre eval
|
||||
eval_results2 = dataset2.pre_eval(pseudo_results,
|
||||
list(range(len(pseudo_results))))
|
||||
|
||||
assert len(eval_results2) == len(dataset2)
|
||||
assert isinstance(eval_results2[0], tuple)
|
||||
assert len(eval_results2[0]) == 4
|
||||
assert isinstance(eval_results2[0][0], torch.Tensor)
|
||||
|
||||
eval_results2 = dataset2.evaluate(
|
||||
eval_results2, metric=['mIoU', 'mDice', 'mFscore'])
|
||||
|
||||
if separate_eval:
|
||||
assert eval_results1['mIoU'] == eval_results2[
|
||||
'0_mIoU'] == eval_results2['1_mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2[
|
||||
'0_mDice'] == eval_results2['1_mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2[
|
||||
'0_mAcc'] == eval_results2['1_mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2[
|
||||
'0_aAcc'] == eval_results2['1_aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2[
|
||||
'0_mFscore'] == eval_results2['1_mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2[
|
||||
'0_mPrecision'] == eval_results2['1_mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2[
|
||||
'0_mRecall'] == eval_results2['1_mRecall']
|
||||
else:
|
||||
assert eval_results1['mIoU'] == eval_results2['mIoU']
|
||||
assert eval_results1['mDice'] == eval_results2['mDice']
|
||||
assert eval_results1['mAcc'] == eval_results2['mAcc']
|
||||
assert eval_results1['aAcc'] == eval_results2['aAcc']
|
||||
assert eval_results1['mFscore'] == eval_results2['mFscore']
|
||||
assert eval_results1['mPrecision'] == eval_results2['mPrecision']
|
||||
assert eval_results1['mRecall'] == eval_results2['mRecall']
|
||||
|
||||
|
||||
def test_ade():
|
||||
test_dataset = ADE20KDataset(
|
||||
|
@ -279,6 +497,44 @@ def test_ade():
|
|||
shutil.rmtree('.format_ade')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('separate_eval', [True, False])
|
||||
def test_concat_ade(separate_eval):
|
||||
test_dataset = ADE20KDataset(
|
||||
pipeline=[],
|
||||
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
concat_dataset = ConcatDataset([test_dataset, test_dataset],
|
||||
separate_eval=separate_eval)
|
||||
assert len(concat_dataset) == 10
|
||||
# Test format_results
|
||||
pseudo_results = []
|
||||
for _ in range(len(concat_dataset)):
|
||||
h, w = (2, 2)
|
||||
pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
|
||||
|
||||
# test format per image
|
||||
file_paths = []
|
||||
for i in range(len(pseudo_results)):
|
||||
file_paths.extend(
|
||||
concat_dataset.format_results([pseudo_results[i]],
|
||||
'.format_ade',
|
||||
indices=[i]))
|
||||
assert len(file_paths) == len(concat_dataset)
|
||||
temp = np.array(Image.open(file_paths[0]))
|
||||
assert np.allclose(temp, pseudo_results[0] + 1)
|
||||
|
||||
shutil.rmtree('.format_ade')
|
||||
|
||||
# test default argument
|
||||
file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
|
||||
assert len(file_paths) == len(concat_dataset)
|
||||
temp = np.array(Image.open(file_paths[0]))
|
||||
assert np.allclose(temp, pseudo_results[0] + 1)
|
||||
|
||||
shutil.rmtree('.format_ade')
|
||||
|
||||
|
||||
def test_cityscapes():
|
||||
test_dataset = CityscapesDataset(
|
||||
pipeline=[],
|
||||
|
@ -311,6 +567,28 @@ def test_cityscapes():
|
|||
shutil.rmtree('.format_city')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('separate_eval', [True, False])
|
||||
def test_concat_cityscapes(separate_eval):
|
||||
cityscape_dataset = CityscapesDataset(
|
||||
pipeline=[],
|
||||
img_dir=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
|
||||
ann_dir=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
|
||||
assert len(cityscape_dataset) == 1
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = ConcatDataset([cityscape_dataset, cityscape_dataset],
|
||||
separate_eval=separate_eval)
|
||||
ade_dataset = ADE20KDataset(
|
||||
pipeline=[],
|
||||
img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
|
||||
assert len(ade_dataset) == 5
|
||||
with pytest.raises(NotImplementedError):
|
||||
_ = ConcatDataset([cityscape_dataset, ade_dataset],
|
||||
separate_eval=separate_eval)
|
||||
|
||||
|
||||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||
MagicMock(side_effect=lambda idx: idx))
|
||||
|
@ -360,14 +638,23 @@ def test_custom_classes_override_default(dataset, classes):
|
|||
assert custom_dataset.CLASSES == [classes[0]]
|
||||
|
||||
# Test default behavior
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=None,
|
||||
test_mode=True)
|
||||
if dataset_class is CustomDataset:
|
||||
with pytest.raises(AssertionError):
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=None,
|
||||
test_mode=True)
|
||||
else:
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=None,
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES == original_classes
|
||||
assert custom_dataset.CLASSES == original_classes
|
||||
|
||||
|
||||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||
|
|
|
@ -78,7 +78,8 @@ def test_build_dataset():
|
|||
pipeline=[],
|
||||
data_root=data_root,
|
||||
img_dir=[img_dir, img_dir],
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
classes=('pseudo_class', ))
|
||||
dataset = build_dataset(cfg)
|
||||
assert isinstance(dataset, ConcatDataset)
|
||||
assert len(dataset) == 10
|
||||
|
@ -90,7 +91,8 @@ def test_build_dataset():
|
|||
data_root=data_root,
|
||||
img_dir=[img_dir, img_dir],
|
||||
split=['splits/val.txt', 'splits/val.txt'],
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
classes=('pseudo_class', ))
|
||||
dataset = build_dataset(cfg)
|
||||
assert isinstance(dataset, ConcatDataset)
|
||||
assert len(dataset) == 2
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv import Config
|
||||
|
||||
from mmseg.datasets.builder import build_dataset
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Browse a dataset')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument(
|
||||
'--show-origin',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='if True, omit all augmentation in pipeline,'
|
||||
' show origin image and seg map')
|
||||
parser.add_argument(
|
||||
'--skip-type',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
|
||||
help='skip some useless pipeline,if `show-origin` is true, '
|
||||
'all pipeline except `Load` will be skipped')
|
||||
parser.add_argument(
|
||||
'--output-dir',
|
||||
default='./output',
|
||||
type=str,
|
||||
help='If there is no display interface, you can save it')
|
||||
parser.add_argument('--show', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--show-interval',
|
||||
type=int,
|
||||
default=999,
|
||||
help='the interval of show (ms)')
|
||||
parser.add_argument(
|
||||
'--opacity',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='the opacity of semantic map')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def imshow_semantic(img,
|
||||
seg,
|
||||
class_names,
|
||||
palette=None,
|
||||
win_name='',
|
||||
show=False,
|
||||
wait_time=0,
|
||||
out_file=None,
|
||||
opacity=0.5):
|
||||
"""Draw `result` over `img`.
|
||||
|
||||
Args:
|
||||
img (str or Tensor): The image to be displayed.
|
||||
seg (Tensor): The semantic segmentation results to draw over
|
||||
`img`.
|
||||
class_names (list[str]): Names of each classes.
|
||||
palette (list[list[int]]] | np.ndarray | None): The palette of
|
||||
segmentation map. If None is given, random palette will be
|
||||
generated. Default: None
|
||||
win_name (str): The window name.
|
||||
wait_time (int): Value of waitKey param.
|
||||
Default: 0.
|
||||
show (bool): Whether to show the image.
|
||||
Default: False.
|
||||
out_file (str or None): The filename to write the image.
|
||||
Default: None.
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5.
|
||||
Must be in (0, 1] range.
|
||||
Returns:
|
||||
img (Tensor): Only if not `show` or `out_file`
|
||||
"""
|
||||
img = mmcv.imread(img)
|
||||
img = img.copy()
|
||||
if palette is None:
|
||||
palette = np.random.randint(0, 255, size=(len(class_names), 3))
|
||||
palette = np.array(palette)
|
||||
assert palette.shape[0] == len(class_names)
|
||||
assert palette.shape[1] == 3
|
||||
assert len(palette.shape) == 2
|
||||
assert 0 < opacity <= 1.0
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
# convert to BGR
|
||||
color_seg = color_seg[..., ::-1]
|
||||
|
||||
img = img * (1 - opacity) + color_seg * opacity
|
||||
img = img.astype(np.uint8)
|
||||
# if out_file specified, do not show image in window
|
||||
if out_file is not None:
|
||||
show = False
|
||||
|
||||
if show:
|
||||
mmcv.imshow(img, win_name, wait_time)
|
||||
if out_file is not None:
|
||||
mmcv.imwrite(img, out_file)
|
||||
|
||||
if not (show or out_file):
|
||||
warnings.warn('show==False and out_file is not specified, only '
|
||||
'result image will be returned')
|
||||
return img
|
||||
|
||||
|
||||
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
|
||||
if show_origin is True:
|
||||
# only keep pipeline of Loading data and ann
|
||||
_data_cfg['pipeline'] = [
|
||||
x for x in _data_cfg.pipeline if 'Load' in x['type']
|
||||
]
|
||||
else:
|
||||
_data_cfg['pipeline'] = [
|
||||
x for x in _data_cfg.pipeline if x['type'] not in skip_type
|
||||
]
|
||||
|
||||
|
||||
def retrieve_data_cfg(config_path, skip_type, show_origin=False):
|
||||
cfg = Config.fromfile(config_path)
|
||||
train_data_cfg = cfg.data.train
|
||||
if isinstance(train_data_cfg, list):
|
||||
for _data_cfg in train_data_cfg:
|
||||
if 'pipeline' in _data_cfg:
|
||||
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
|
||||
elif 'dataset' in _data_cfg:
|
||||
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
|
||||
show_origin)
|
||||
else:
|
||||
raise ValueError
|
||||
elif 'dataset' in train_data_cfg:
|
||||
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
|
||||
else:
|
||||
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
|
||||
return cfg
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
|
||||
dataset = build_dataset(cfg.data.train)
|
||||
progress_bar = mmcv.ProgressBar(len(dataset))
|
||||
for item in dataset:
|
||||
filename = os.path.join(args.output_dir,
|
||||
Path(item['filename']).name
|
||||
) if args.output_dir is not None else None
|
||||
imshow_semantic(
|
||||
item['img'],
|
||||
item['gt_semantic_seg'],
|
||||
dataset.CLASSES,
|
||||
dataset.PALETTE,
|
||||
show=args.show,
|
||||
wait_time=args.show_interval,
|
||||
out_file=filename,
|
||||
opacity=args.opacity,
|
||||
)
|
||||
progress_bar.update()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -215,7 +215,8 @@ def main():
|
|||
print(f'\nwriting results to {args.out}')
|
||||
mmcv.dump(results, args.out)
|
||||
if args.eval:
|
||||
metric = dataset.evaluate(results, args.eval, **eval_kwargs)
|
||||
eval_kwargs.update(metric=args.eval)
|
||||
metric = dataset.evaluate(results, **eval_kwargs)
|
||||
metric_dict = dict(config=args.config, metric=metric)
|
||||
if args.work_dir is not None and rank == 0:
|
||||
mmcv.dump(metric_dict, json_file, indent=4)
|
||||
|
|
Loading…
Reference in New Issue