# mmsegmentation/tests/test_datasets/test_dataset.py
#
# 354 lines
# 12 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock
import pytest
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
COCOStuffDataset, ISPRSDataset, LoveDADataset,
PascalVOCDataset, PotsdamDataset, iSAIDDataset)
from mmseg.registry import DATASETS
from mmseg.utils import get_classes, get_palette
def test_classes():
    """Check dataset ``METAINFO['classes']`` against ``get_classes``.

    Each dataset class is paired with the name(s) passed to ``get_classes``;
    every listed name must yield the same class list as the dataset's
    ``METAINFO``. An unsupported name must raise ``ValueError``.
    """
    dataset_aliases = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (COCOStuffDataset, ('cocostuff', )),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (ISPRSDataset, ('vaihingen', )),
        (iSAIDDataset, ('isaid', )),
    ]
    for dataset_cls, aliases in dataset_aliases:
        class_names = list(dataset_cls.METAINFO['classes'])
        for alias in aliases:
            assert class_names == get_classes(alias)

    with pytest.raises(ValueError):
        get_classes('unsupported')
def test_classes_file_path():
    """Check that ``metainfo['classes']`` can be given as a text-file path.

    Three files are written in turn, one class name per line:

    * the full Cityscapes class list -> classes match, ``label_map`` is None;
    * a sub-list of Cityscapes classes -> classes match, ``label_map`` is set;
    * a list containing an unknown class -> construction raises ValueError.

    The file lives in a temporary directory managed by a context manager so
    it is removed even when one of the assertions fails.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        classes_path = osp.join(tmp_dir, 'classes.txt')
        train_pipeline = []
        kwargs = dict(
            pipeline=train_pipeline,
            data_prefix=dict(img_path='./', seg_map_path='./'),
            metainfo=dict(classes=classes_path))

        # classes.txt with full categories
        categories = get_classes('cityscapes')
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))
        dataset = CityscapesDataset(**kwargs)
        assert list(dataset.metainfo['classes']) == categories
        assert dataset.label_map is None

        # classes.txt with sub categories
        categories = ['road', 'sidewalk', 'building']
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))
        dataset = CityscapesDataset(**kwargs)
        assert list(dataset.metainfo['classes']) == categories
        assert dataset.label_map is not None

        # classes.txt with unknown categories
        categories = ['road', 'sidewalk', 'unknown']
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))
        with pytest.raises(ValueError):
            CityscapesDataset(**kwargs)

    # Leaving the context manager removes the directory and the file in it.
    assert not osp.exists(classes_path)
def test_palette():
    """Check dataset ``METAINFO['palette']`` against ``get_palette``.

    Each dataset class is paired with the name(s) passed to ``get_palette``;
    every listed name must yield the same palette as the dataset's
    ``METAINFO``. An unsupported name must raise ``ValueError``.
    """
    dataset_aliases = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (COCOStuffDataset, ('cocostuff', )),
        (iSAIDDataset, ('isaid', )),
    ]
    for dataset_cls, aliases in dataset_aliases:
        palette = dataset_cls.METAINFO['palette']
        for alias in aliases:
            assert palette == get_palette(alias)

    with pytest.raises(ValueError):
        get_palette('unsupported')
def test_custom_dataset():
    """Exercise ``BaseSegDataset`` construction and item access.

    Builds the dataset from the pseudo dataset directory in several ways
    (relative prefixes under ``data_root``, with an ``ann_file`` split,
    without ``data_root``, with absolute prefixes, and in test mode) and
    checks the resulting lengths. Finally fetches the first item of the
    training and the test dataset and checks the returned paths.
    """
    # with 'img_path' and 'seg_map_path' in data_prefix
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        ann_file='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
    # abs path
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        img_suffix='img.jpg',
        test_mode=True,
        metainfo=dict(classes=('pseudo_class', )))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)
    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
    assert 'seg_map_path' in train_data and osp.isfile(
        train_data['seg_map_path'])

    # test data get
    # The checks must target ``test_data`` (not ``train_data``). The test
    # dataset was built with only ``img_path`` in ``data_prefix``, so only
    # the image path is checked here.
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    assert 'img_path' in test_data and osp.isfile(test_data['img_path'])
def test_ade():
    """ADE20KDataset should list five samples from the pseudo image dir."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_dataset/imgs')
    dataset = ADE20KDataset(pipeline=[], data_prefix=dict(img_path=img_dir))
    assert len(dataset) == 5
def test_cityscapes():
    """CityscapesDataset should list one sample from the pseudo val split."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here,
                       '../data/pseudo_cityscapes_dataset/leftImg8bit/val')
    ann_dir = osp.join(here, '../data/pseudo_cityscapes_dataset/gtFine/val')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = CityscapesDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_loveda():
    """LoveDADataset should list three samples from the pseudo dataset."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_loveda_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_loveda_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = LoveDADataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 3
def test_potsdam():
    """PotsdamDataset should list one sample from the pseudo dataset."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_potsdam_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_potsdam_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = PotsdamDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_vaihingen():
    """ISPRSDataset should list one sample from the pseudo Vaihingen data."""
    here = osp.dirname(__file__)
    img_dir = osp.join(here, '../data/pseudo_vaihingen_dataset/img_dir')
    ann_dir = osp.join(here, '../data/pseudo_vaihingen_dataset/ann_dir')
    prefix = dict(img_path=img_dir, seg_map_path=ann_dir)
    dataset = ISPRSDataset(pipeline=[], data_prefix=prefix)
    assert len(dataset) == 1
def test_isaid():
    """iSAIDDataset sample counts with and without an annotation file.

    Scanning the pseudo dataset directories yields two samples; restricting
    the dataset through ``splits/train.txt`` yields one.
    """
    here = osp.dirname(__file__)

    # Directory scan only.
    scanned = iSAIDDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(here, '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(here,
                                  '../data/pseudo_isaid_dataset/ann_dir')))
    assert len(scanned) == 2

    # Same directories, restricted by the split file.
    split_file = osp.join(here,
                          '../data/pseudo_isaid_dataset/splits/train.txt')
    from_split = iSAIDDataset(
        data_prefix=dict(
            img_path=osp.join(here, '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(here,
                                  '../data/pseudo_isaid_dataset/ann_dir')),
        ann_file=split_file)
    assert len(from_split) == 1
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('BaseSegDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """Custom ``metainfo['classes']`` must override a dataset's defaults.

    Args:
        dataset (str): Name of the dataset class, looked up in ``DATASETS``.
        classes (tuple[str]): Custom class names to pass through
            ``metainfo``.

    For each parametrized dataset the classes are passed as a tuple, as a
    list, and as a single-element list. With ``metainfo=None`` the plain
    ``BaseSegDataset`` is expected to raise ``AssertionError``, while the
    other datasets keep their ``METAINFO`` classes and have no label map.
    """
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.METAINFO.get('classes', None)

    # ``dataset_class`` is a class object, so it is compared by identity.
    # ``isinstance(dataset_class, ...)`` is always False for a class, and
    # ``isinstance(instance, BaseSegDataset)`` is True for every subclass
    # instance, so neither form distinguishes the cases intended here.
    is_base_dataset = dataset_class is BaseSegDataset

    with tempfile.TemporaryDirectory() as tmp_dir:
        if dataset_class is PascalVOCDataset:
            # PascalVOC gets a real (empty) annotation file path; the
            # temporary directory removes it when the test finishes.
            ann_file = osp.join(tmp_dir, 'ann_file.txt')
            with open(ann_file, 'w'):
                pass
        else:
            ann_file = MagicMock()

        def build_dataset(metainfo):
            """Construct ``dataset_class`` lazily with the given metainfo."""
            return dataset_class(
                data_prefix=dict(img_path=MagicMock()),
                ann_file=ann_file,
                metainfo=metainfo,
                test_mode=True,
                lazy_init=True)

        # Test setting classes as a tuple, as a list, and as a
        # single-class list.
        for custom_classes in (classes, list(classes), [classes[0]]):
            custom_dataset = build_dataset(dict(classes=custom_classes))
            assert custom_dataset.metainfo['classes'] != original_classes
            assert custom_dataset.metainfo['classes'] == custom_classes
            if not is_base_dataset:
                assert isinstance(custom_dataset.label_map, dict)

        # Test default behavior
        if is_base_dataset:
            with pytest.raises(AssertionError):
                build_dataset(None)
        else:
            custom_dataset = build_dataset(None)
            assert custom_dataset.METAINFO['classes'] == original_classes
            assert custom_dataset.label_map is None
def test_custom_dataset_random_palette_is_generated():
    """A palette is generated when only custom classes are supplied.

    With two classes and no palette in ``metainfo``, the dataset's palette
    must hold two colors, each with three channel values in ``[0, 255]``.
    """
    custom_metainfo = dict(classes=('bus', 'car'))
    dataset = BaseSegDataset(
        pipeline=[],
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=custom_metainfo,
        lazy_init=True,
        test_mode=True)

    assert len(dataset.metainfo['palette']) == 2
    for color in dataset.metainfo['palette']:
        assert len(color) == 3
        assert all(0 <= channel <= 255 for channel in color)
def test_custom_dataset_custom_palette():
    """A custom palette is kept, and a mismatched one is rejected.

    A palette with one color per class must be returned unchanged through
    ``metainfo``; a palette with fewer colors than classes must raise
    ``ValueError``.
    """
    matching_metainfo = dict(
        classes=('bus', 'car'),
        palette=[[100, 100, 100], [200, 200, 200]])
    dataset = BaseSegDataset(
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=matching_metainfo,
        lazy_init=True,
        test_mode=True)
    expected_palette = ([100, 100, 100], [200, 200, 200])
    assert tuple(dataset.metainfo['palette']) == expected_palette

    # test custom class and palette don't match
    mismatched_metainfo = dict(
        classes=('bus', 'car'), palette=[[200, 200, 200]])
    with pytest.raises(ValueError):
        BaseSegDataset(
            data_prefix=dict(img_path=MagicMock()),
            ann_file=MagicMock(),
            metainfo=mismatched_metainfo,
            lazy_init=True)