mmsegmentation/tests/test_datasets/test_dataset.py

354 lines
12 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock
import pytest
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
COCOStuffDataset, CustomDataset, ISPRSDataset,
LoveDADataset, PascalVOCDataset, PotsdamDataset,
iSAIDDataset)
def test_classes():
    """Compare each dataset's ``METAINFO['classes']`` with ``get_classes``.

    Every alias listed for a dataset must resolve to the same class names,
    and an unsupported dataset name must raise ``ValueError``.
    """
    alias_table = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (COCOStuffDataset, ('cocostuff', )),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (ISPRSDataset, ('vaihingen', )),
        (iSAIDDataset, ('isaid', )),
    ]
    for dataset_cls, aliases in alias_table:
        class_names = list(dataset_cls.METAINFO['classes'])
        for alias in aliases:
            assert class_names == get_classes(alias)

    with pytest.raises(ValueError):
        get_classes('unsupported')
def test_classes_file_path():
    """Check that ``metainfo['classes']`` may be a path to a text file.

    Three class files are written in turn and loaded through
    ``CityscapesDataset``:

    * the full Cityscapes class list, where ``label_map`` must be ``None``;
    * a three-class subset, where ``label_map`` must be set;
    * a list containing ``'unknown'``, where ``ValueError`` is expected.

    The class file is removed in a ``finally`` clause so that a failing
    assertion does not leave it behind in the temp directory.
    """
    train_pipeline = []
    # The temporary file is only used to obtain a unique name; the class
    # list itself is written next to it with a ``.txt`` suffix.
    with tempfile.NamedTemporaryFile() as tmp_file:
        classes_path = f'{tmp_file.name}.txt'
        kwargs = dict(
            pipeline=train_pipeline,
            data_prefix=dict(img_path='./', seg_map_path='./'),
            metainfo=dict(classes=classes_path))
        try:
            # classes.txt with full categories
            categories = get_classes('cityscapes')
            with open(classes_path, 'w') as f:
                f.write('\n'.join(categories))
            dataset = CityscapesDataset(**kwargs)
            assert list(dataset.metainfo['classes']) == categories
            assert dataset.label_map is None

            # classes.txt with sub categories
            categories = ['road', 'sidewalk', 'building']
            with open(classes_path, 'w') as f:
                f.write('\n'.join(categories))
            dataset = CityscapesDataset(**kwargs)
            assert list(dataset.metainfo['classes']) == categories
            assert dataset.label_map is not None

            # classes.txt with unknown categories
            categories = ['road', 'sidewalk', 'unknown']
            with open(classes_path, 'w') as f:
                f.write('\n'.join(categories))
            with pytest.raises(ValueError):
                CityscapesDataset(**kwargs)
        finally:
            # Runs on success and on failure alike.
            if osp.exists(classes_path):
                os.remove(classes_path)
    assert not osp.exists(classes_path)
def test_palette():
    """Compare each dataset's ``METAINFO['palette']`` with ``get_palette``.

    Every alias listed for a dataset must resolve to the same palette, and
    an unsupported dataset name must raise ``ValueError``.
    """
    alias_table = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (COCOStuffDataset, ('cocostuff', )),
        (iSAIDDataset, ('isaid', )),
    ]
    for dataset_cls, aliases in alias_table:
        palette = dataset_cls.METAINFO['palette']
        for alias in aliases:
            assert palette == get_palette(alias)

    with pytest.raises(ValueError):
        get_palette('unsupported')
def test_custom_dataset():
    """Build ``CustomDataset`` over the pseudo dataset in several ways.

    Covered configurations: relative prefixes under ``data_root``, the same
    with an ``ann_file`` split, absolute prefixes without ``data_root``,
    absolute prefixes together with ``data_root``, and ``test_mode=True``
    with only an image prefix. One sample is then fetched from a training
    dataset and one from the test-mode dataset.
    """
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')
    gt_dir = osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')

    # with 'img_path' and 'seg_map_path' in data_prefix
    train_dataset = CustomDataset(
        data_root=data_root,
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
    train_dataset = CustomDataset(
        data_root=data_root,
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        ann_file='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        data_prefix=dict(img_path=img_dir, seg_map_path=gt_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
    # abs path
    train_dataset = CustomDataset(
        data_root=data_root,
        data_prefix=dict(img_path=img_dir, seg_map_path=gt_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = CustomDataset(
        data_prefix=dict(img_path=img_dir),
        img_suffix='img.jpg',
        test_mode=True,
        metainfo=dict(classes=('pseudo_class', )))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)
    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
    assert 'seg_map_path' in train_data and osp.isfile(
        train_data['seg_map_path'])

    # test data get: inspect the sample drawn from ``test_dataset`` itself
    # rather than re-checking ``train_data``.
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    assert 'img_path' in test_data and osp.isfile(test_data['img_path'])
    # ``test_dataset`` was built without a 'seg_map_path' prefix, so no
    # annotation path is asserted for it.
    # NOTE(review): confirm whether test-mode samples are expected to carry
    # a 'seg_map_path' entry before adding such a check.
def test_ade():
    """Build ``ADE20KDataset`` over the pseudo image folder and count it."""
    img_dir = osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')
    dataset = ADE20KDataset(pipeline=[], data_prefix=dict(img_path=img_dir))
    assert len(dataset) == 5
def test_cityscapes():
    """Build ``CityscapesDataset`` over the pseudo val split and count it."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here,
                       '../data/pseudo_cityscapes_dataset/leftImg8bit/val')
    ann_dir = osp.join(here, '../data/pseudo_cityscapes_dataset/gtFine/val')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = CityscapesDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_loveda():
    """Build ``LoveDADataset`` over the pseudo LoveDA data and count it."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_loveda_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_loveda_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = LoveDADataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 3
def test_potsdam():
    """Build ``PotsdamDataset`` over the pseudo Potsdam data and count it."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_potsdam_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_potsdam_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = PotsdamDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_vaihingen():
    """Build ``ISPRSDataset`` over the pseudo Vaihingen data and count it."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_vaihingen_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_vaihingen_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = ISPRSDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_isaid():
    """Build ``iSAIDDataset`` with and without an ``ann_file`` split.

    Scanning the image/annotation folders must yield two samples; limiting
    the dataset with ``splits/train.txt`` must yield one.
    """
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_isaid_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_isaid_dataset/ann_dir')
    split_file = osp.join(here,
                          '../data/pseudo_isaid_dataset/splits/train.txt')

    # Folder scan, explicit empty pipeline.
    folder_dataset = iSAIDDataset(
        pipeline=[],
        data_prefix=dict(img_path=img_dir, seg_map_path=ann_dir))
    assert len(folder_dataset) == 2

    # Same folders restricted by the split file; no pipeline argument here.
    split_dataset = iSAIDDataset(
        data_prefix=dict(img_path=img_dir, seg_map_path=ann_dir),
        ann_file=split_file)
    assert len(split_dataset) == 1
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """Check that ``metainfo['classes']`` overrides a dataset's defaults.

    Args:
        dataset (str): Name under which the dataset class is looked up in
            ``DATASETS``.
        classes (tuple[str]): Class names passed as the override.

    The override is given as a tuple, as a list, and as a single-element
    list. Finally the dataset is built with ``metainfo=None``:
    ``CustomDataset`` is expected to raise ``AssertionError``, the other
    datasets are expected to keep their ``METAINFO`` classes and to have
    ``label_map`` set to ``None``.
    """
    dataset_class = DATASETS.get(dataset)
    # NOTE(review): ``dataset_class`` is a class object, so this
    # ``isinstance`` check looks like it can never be true and every dataset
    # ends up with the ``MagicMock`` ann_file. ``dataset_class is
    # PascalVOCDataset`` may have been intended -- confirm before changing.
    if isinstance(dataset_class, PascalVOCDataset):
        tmp_file = tempfile.NamedTemporaryFile()
        ann_file = f'{tmp_file.name}.txt'
    else:
        ann_file = MagicMock()
    original_classes = dataset_class.METAINFO.get('classes', None)

    def build(metainfo):
        # lazy_init=True with a mocked image prefix, as in every case below.
        return dataset_class(
            data_prefix=dict(img_path=MagicMock()),
            ann_file=ann_file,
            metainfo=metainfo,
            test_mode=True,
            lazy_init=True)

    # Each factory is called twice so the value handed to the dataset and
    # the value compared against are independent objects.
    override_factories = [
        lambda: classes,  # classes as a tuple
        lambda: list(classes),  # classes as a list
        lambda: [classes[0]],  # only the first class
    ]
    for make_classes in override_factories:
        custom_dataset = build(dict(classes=make_classes()))
        assert custom_dataset.metainfo['classes'] != original_classes
        assert custom_dataset.metainfo['classes'] == make_classes()
        if not isinstance(custom_dataset, CustomDataset):
            assert isinstance(custom_dataset.label_map, dict)

    # Test default behavior
    if dataset_class is CustomDataset:
        with pytest.raises(AssertionError):
            build(None)
    else:
        custom_dataset = build(None)
        assert custom_dataset.METAINFO['classes'] == original_classes
        assert custom_dataset.label_map is None
def test_custom_dataset_random_palette_is_generated():
    """Check that a palette exists when only ``classes`` is supplied.

    With two classes and no palette in ``metainfo``, the dataset must expose
    two colours, each made of three values in the range 0..255.
    """
    metainfo = dict(classes=('bus', 'car'))
    dataset = CustomDataset(
        pipeline=[],
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=metainfo,
        lazy_init=True,
        test_mode=True)

    assert len(dataset.metainfo['palette']) == 2
    for class_color in dataset.metainfo['palette']:
        assert len(class_color) == 3
        assert all(0 <= channel <= 255 for channel in class_color)
def test_custom_dataset_custom_palette():
    """Check that a palette supplied through ``metainfo`` is kept.

    A two-class, two-colour ``metainfo`` must be reflected unchanged in
    ``dataset.metainfo['palette']``; a two-class, one-colour ``metainfo``
    must raise ``ValueError``.
    """
    expected_palette = ([100, 100, 100], [200, 200, 200])
    matching_metainfo = dict(
        classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200, 200]])
    dataset = CustomDataset(
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=matching_metainfo,
        lazy_init=True,
        test_mode=True)
    assert tuple(dataset.metainfo['palette']) == expected_palette

    # test custom class and palette don't match
    mismatched_metainfo = dict(
        classes=('bus', 'car'), palette=[[200, 200, 200]])
    with pytest.raises(ValueError):
        dataset = CustomDataset(
            data_prefix=dict(img_path=MagicMock()),
            ann_file=MagicMock(),
            metainfo=mismatched_metainfo,
            lazy_init=True)