264 lines
8.7 KiB
Python
264 lines
8.7 KiB
Python
import os.path as osp
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from mmseg.core.evaluation import get_classes, get_palette
|
|
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
|
ConcatDataset, CustomDataset, PascalVOCDataset,
|
|
RepeatDataset)
|
|
|
|
|
|
def test_classes():
    """Dataset CLASSES must agree with ``get_classes`` for every alias.

    Each dataset's class list is compared against every registry name that
    refers to it ('voc'/'pascal_voc', 'ade'/'ade20k'), and an unknown name
    must raise ``ValueError``.
    """
    alias_table = (
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
    )
    for dataset_cls, aliases in alias_table:
        expected = list(dataset_cls.CLASSES)
        for alias in aliases:
            assert get_classes(alias) == expected

    with pytest.raises(ValueError):
        get_classes('unsupported')
|
|
|
|
|
|
def test_palette():
    """Dataset PALETTE must agree with ``get_palette`` for every alias.

    Every registry name for a dataset must yield the same palette as the
    class attribute, and an unknown name must raise ``ValueError``.
    """
    alias_table = (
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
    )
    for dataset_cls, aliases in alias_table:
        for alias in aliases:
            assert dataset_cls.PALETTE == get_palette(alias)

    with pytest.raises(ValueError):
        get_palette('unsupported')
|
|
|
|
|
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    """Check index mapping and length of ConcatDataset and RepeatDataset.

    ``load_annotations`` is patched out so no files are touched, and
    ``__getitem__`` is patched to echo back the index it receives, so each
    assertion observes which local index a wrapper forwards to.
    """
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    # The assertions below rely on len(dataset) reflecting len(img_infos).
    dataset_a.img_infos.__len__.return_value = len_a

    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    # Pin the faked lengths so the wrapper length checks cannot pass
    # vacuously (e.g. 0 == 0 + 0) if the mock did not take effect.
    assert len(dataset_a) == len_a
    assert len(dataset_b) == len_b

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    # Index 25 falls in dataset_b at local index 25 - len_a = 15.
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    # Indices wrap modulo len(dataset_a).
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)
|
|
|
|
|
|
def test_custom_dataset():
    """Exercise CustomDataset loading, indexing and evaluation end to end.

    Uses the on-disk fixture under ``../data/pseudo_dataset`` (relative to
    this file). Covers:

    - relative ``img_dir``/``ann_dir`` under ``data_root``;
    - an additional ``split`` file;
    - no ``data_root``;
    - absolute dirs together with ``data_root``;
    - ``test_mode=True`` without annotations;
    - fetching a sample through the train and test pipelines;
    - ``get_gt_seg_maps``;
    - ``evaluate`` with and without ``CLASSES`` set.
    """
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    # Fixture paths, built once. The joined strings are identical to spelling
    # out '../data/pseudo_dataset/imgs' etc. in full.
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = osp.join(data_root, 'imgs')
    ann_dir = osp.join(data_root, 'gts')

    # with img_dir and ann_dir relative to data_root
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with img_dir, ann_dir, split: the split file selects a subset
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        split='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but img_dir/ann_dir are abs path
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=data_root,
        img_dir=osp.abspath(img_dir),
        ann_dir=osp.abspath(ann_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True, no annotations
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=img_dir,
        img_suffix='img.jpg',
        test_mode=True)
    assert len(test_dataset) == 5

    # training data get (uses the absolute-path dataset built above)
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # test data get
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)

    # get gt seg map: one map per fixture sample
    gt_seg_maps = train_dataset.get_gt_seg_maps()
    assert len(gt_seg_maps) == 5

    def _check_eval_results(eval_results):
        """Assert ``evaluate`` returned a dict with the expected metric keys."""
        assert isinstance(eval_results, dict)
        for metric in ('mIoU', 'mAcc', 'aAcc'):
            assert metric in eval_results

    # evaluation with random predicted labels in [0, 7). A fixed seed keeps
    # the predictions, and therefore the computed metrics, reproducible.
    rng = np.random.RandomState(0)
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(rng.randint(low=0, high=7, size=(h, w)))
    _check_eval_results(train_dataset.evaluate(pseudo_results))

    # evaluation with CLASSES; 7 names presumably chosen to match the 7
    # predicted label values above.
    train_dataset.CLASSES = tuple(['a'] * 7)
    _check_eval_results(train_dataset.evaluate(pseudo_results))
|
|
|
|
|
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """``classes`` passed at construction must override the default CLASSES.

    The override is checked for a tuple, a list and a single-class list.
    Passing ``classes=None`` must keep the class-level default.

    Args:
        dataset: registry name of the dataset class under test.
        classes: custom class names to apply.
    """
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    def _build(custom_classes):
        """Construct a test-mode dataset with mocked paths and given classes."""
        return dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=custom_classes,
            test_mode=True)

    # Test setting classes as a tuple
    custom_dataset = _build(classes)
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == classes

    # Test setting classes as a list
    custom_dataset = _build(list(classes))
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == list(classes)

    # Test overriding with a single class (a subset of ``classes``)
    custom_dataset = _build([classes[0]])
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == [classes[0]]

    # Test default behavior: None keeps the class-level CLASSES
    custom_dataset = _build(None)
    assert custom_dataset.CLASSES == original_classes
|
|
|
|
|
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
    """Custom classes without a palette must still get one color per class.

    Each generated color must be a 3-channel value with every channel in
    the 0-255 range.
    """
    generated = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        test_mode=True).PALETTE

    assert len(generated) == 2
    for rgb in generated:
        assert len(rgb) == 3
        assert all(0 <= channel <= 255 for channel in rgb)
|
|
|
|
|
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
    """A user-supplied palette must be kept as the dataset's PALETTE."""
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        palette=[[100, 100, 100], [200, 200, 200]],
        test_mode=True)

    # A fresh literal, so the comparison is unaffected by anything the
    # dataset might do to the list it was given.
    expected = ([100, 100, 100], [200, 200, 200])
    assert tuple(dataset.PALETTE) == expected
|