# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import os.path as osp
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, BDD100KDataset,
|
|
CityscapesDataset, COCOStuffDataset,
|
|
DecathlonDataset, DSDLSegDataset, ISPRSDataset,
|
|
LIPDataset, LoveDADataset, MapillaryDataset_v1,
|
|
MapillaryDataset_v2, PascalVOCDataset,
|
|
PotsdamDataset, REFUGEDataset, SynapseDataset,
|
|
iSAIDDataset)
|
|
from mmseg.registry import DATASETS
|
|
from mmseg.utils import get_classes, get_palette
|
|
|
|
try:
|
|
from dsdl.dataset import DSDLDataset
|
|
except ImportError:
|
|
DSDLDataset = None
|
|
|
|
|
|
def test_classes():
    """Class names in each dataset's METAINFO must match ``get_classes``."""
    cases = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (COCOStuffDataset, ('cocostuff', )),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (ISPRSDataset, ('vaihingen', )),
        (iSAIDDataset, ('isaid', )),
        (MapillaryDataset_v1, ('mapillary_v1', )),
        (MapillaryDataset_v2, ('mapillary_v2', )),
        (BDD100KDataset, ('bdd100k', )),
    ]
    for dataset_cls, aliases in cases:
        expected = get_classes(aliases[0])
        # every registered alias must resolve to the same class list
        for alias in aliases[1:]:
            assert get_classes(alias) == expected
        assert list(dataset_cls.METAINFO['classes']) == expected

    # unknown dataset names are rejected
    with pytest.raises(ValueError):
        get_classes('unsupported')
|
|
|
|
|
|
def test_classes_file_path():
    """``metainfo['classes']`` may point at a text file with one class name
    per line; the dataset must load it and build ``label_map`` accordingly.
    """
    # A temporary directory guarantees cleanup even when an assertion fails
    # mid-test; the original manual close()/os.remove() tail leaked the
    # classes file on any early failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        classes_path = osp.join(tmp_dir, 'classes.txt')
        train_pipeline = []
        kwargs = dict(
            pipeline=train_pipeline,
            data_prefix=dict(img_path='./', seg_map_path='./'),
            metainfo=dict(classes=classes_path))

        # classes.txt with full categories: no remapping is needed
        categories = get_classes('cityscapes')
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))
        dataset = CityscapesDataset(**kwargs)
        assert list(dataset.metainfo['classes']) == categories
        assert dataset.label_map is None

        # classes.txt with sub categories: a label_map must be built
        categories = ['road', 'sidewalk', 'building']
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))
        dataset = CityscapesDataset(**kwargs)
        assert list(dataset.metainfo['classes']) == categories
        assert dataset.label_map is not None

        # classes.txt with unknown categories: construction must fail
        categories = ['road', 'sidewalk', 'unknown']
        with open(classes_path, 'w') as f:
            f.write('\n'.join(categories))

        with pytest.raises(ValueError):
            CityscapesDataset(**kwargs)
|
|
|
|
|
|
def test_palette():
    """Palettes in each dataset's METAINFO must match ``get_palette``."""
    cases = [
        (CityscapesDataset, ('cityscapes', )),
        (PascalVOCDataset, ('voc', 'pascal_voc')),
        (ADE20KDataset, ('ade', 'ade20k')),
        (LoveDADataset, ('loveda', )),
        (PotsdamDataset, ('potsdam', )),
        (COCOStuffDataset, ('cocostuff', )),
        (iSAIDDataset, ('isaid', )),
        (MapillaryDataset_v1, ('mapillary_v1', )),
        (MapillaryDataset_v2, ('mapillary_v2', )),
        (BDD100KDataset, ('bdd100k', )),
    ]
    for dataset_cls, aliases in cases:
        expected = get_palette(aliases[0])
        # every registered alias must resolve to the same palette
        for alias in aliases[1:]:
            assert get_palette(alias) == expected
        assert list(dataset_cls.METAINFO['palette']) == expected

    # unknown dataset names are rejected
    with pytest.raises(ValueError):
        get_palette('unsupported')
|
|
|
|
|
|
def test_custom_dataset():
    """Exercise ``BaseSegDataset`` path resolution and sample loading.

    Covers relative prefixes under ``data_root``, an ``ann_file`` split,
    absolute prefixes with and without ``data_root``, and ``test_mode``.
    """
    data_dir = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = osp.join(data_dir, 'imgs')
    gt_dir = osp.join(data_dir, 'gts')

    # with 'img_path' and 'seg_map_path' in data_prefix
    train_dataset = BaseSegDataset(
        data_root=data_dir,
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
    train_dataset = BaseSegDataset(
        data_root=data_dir,
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        ann_file='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = BaseSegDataset(
        data_prefix=dict(img_path=img_dir, seg_map_path=gt_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
    # abs path
    train_dataset = BaseSegDataset(
        data_root=data_dir,
        data_prefix=dict(img_path=img_dir, seg_map_path=gt_dir),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = BaseSegDataset(
        data_prefix=dict(img_path=img_dir),
        img_suffix='img.jpg',
        test_mode=True,
        metainfo=dict(classes=('pseudo_class', )))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)
    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
    assert 'seg_map_path' in train_data and osp.isfile(
        train_data['seg_map_path'])

    # test data get
    # BUG FIX: the original block re-asserted on ``train_data`` here, so the
    # test-mode sample was never actually validated. ``test_dataset`` was
    # built with only an ``img_path`` prefix, so just the image is checked.
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    assert 'img_path' in test_data and osp.isfile(test_data['img_path'])
|
|
|
|
|
|
def test_ade():
    """Smoke-test ADE20KDataset on the pseudo-dataset fixtures."""
    img_dir = osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')
    test_dataset = ADE20KDataset(
        pipeline=[], data_prefix=dict(img_path=img_dir))
    assert len(test_dataset) == 5
|
|
|
|
|
|
def test_cityscapes():
    """Smoke-test CityscapesDataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_cityscapes_dataset')
    test_dataset = CityscapesDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'leftImg8bit/val'),
            seg_map_path=osp.join(data_dir, 'gtFine/val')))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_loveda():
    """Smoke-test LoveDADataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_loveda_dataset')
    test_dataset = LoveDADataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 3
|
|
|
|
|
|
def test_potsdam():
    """Smoke-test PotsdamDataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_potsdam_dataset')
    test_dataset = PotsdamDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_vaihingen():
    """Smoke-test ISPRSDataset (Vaihingen) on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_vaihingen_dataset')
    test_dataset = ISPRSDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_synapse():
    """Smoke-test SynapseDataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_synapse_dataset')
    test_dataset = SynapseDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 2
|
|
|
|
|
|
def test_refuge():
    """Smoke-test REFUGEDataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_refuge_dataset')
    test_dataset = REFUGEDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_isaid():
    """Smoke-test iSAIDDataset with and without an explicit split file."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_isaid_dataset')

    # full directory scan
    test_dataset = iSAIDDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')))
    assert len(test_dataset) == 2

    # restricted to the samples listed in the train split
    test_dataset = iSAIDDataset(
        data_prefix=dict(
            img_path=osp.join(data_dir, 'img_dir'),
            seg_map_path=osp.join(data_dir, 'ann_dir')),
        ann_file=osp.join(data_dir, 'splits/train.txt'))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_decathlon():
    """DecathlonDataset reads train/test splits from ``dataset.json``."""
    data_root = osp.join(osp.dirname(__file__), '../data')
    common = dict(pipeline=[], data_root=data_root, ann_file='dataset.json')

    # test load training dataset
    test_dataset = DecathlonDataset(**common)
    assert len(test_dataset) == 1

    # test load test dataset
    test_dataset = DecathlonDataset(test_mode=True, **common)
    assert len(test_dataset) == 3
|
|
|
|
|
|
def test_lip():
    """LIPDataset loads the train and val splits of the pseudo dataset."""
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_lip_dataset')
    # both fixture splits contain exactly one sample
    for split in ('train', 'val'):
        dataset = LIPDataset(
            pipeline=[],
            data_root=data_root,
            data_prefix=dict(
                img_path=f'{split}_images',
                seg_map_path=f'{split}_segmentations'))
        assert len(dataset) == 1
|
|
|
|
|
|
def test_mapillary():
    """Smoke-test MapillaryDataset_v1 on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_mapillary_dataset')
    test_dataset = MapillaryDataset_v1(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'images'),
            seg_map_path=osp.join(data_dir, 'v1.2')))
    assert len(test_dataset) == 1
|
|
|
|
|
|
def test_bdd100k():
    """Smoke-test BDD100KDataset on the pseudo-dataset fixtures."""
    data_dir = osp.join(osp.dirname(__file__),
                        '../data/pseudo_bdd100k_dataset')
    test_dataset = BDD100KDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(data_dir, 'images/10k/val'),
            seg_map_path=osp.join(data_dir, 'labels/sem_seg/masks/val')))
    assert len(test_dataset) == 3
|
|
|
|
|
|
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('BaseSegDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    """Custom ``metainfo['classes']`` must override a dataset's defaults.

    For concrete datasets (everything except the plain ``BaseSegDataset``,
    which has no default classes) overriding with a subset of the default
    classes must also produce a ``label_map``.
    """
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.METAINFO.get('classes', None)

    tmp_file = tempfile.NamedTemporaryFile()
    ann_file = tmp_file.name
    img_path = tempfile.mkdtemp()

    # BUG FIX: the original guard was ``not isinstance(custom_dataset,
    # BaseSegDataset)``, which is always False because every dataset here
    # is a BaseSegDataset (sub)instance — the label_map assertions were
    # dead code. Comparing the class itself restores them for the concrete
    # datasets while still skipping the plain BaseSegDataset.
    def _assert_label_map_built(ds):
        if dataset_class is not BaseSegDataset:
            assert isinstance(ds.label_map, dict)

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=img_path),
        ann_file=ann_file,
        metainfo=dict(classes=classes),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == classes
    _assert_label_map_built(custom_dataset)

    # Test setting classes as a list
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=img_path),
        ann_file=ann_file,
        metainfo=dict(classes=list(classes)),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == list(classes)
    _assert_label_map_built(custom_dataset)

    # Test overriding with a single class
    custom_dataset = dataset_class(
        ann_file=ann_file,
        data_prefix=dict(img_path=img_path),
        metainfo=dict(classes=[classes[0]]),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == [classes[0]]
    _assert_label_map_built(custom_dataset)

    # Test default behavior: BaseSegDataset has no default classes and must
    # refuse construction without metainfo; concrete datasets fall back to
    # their METAINFO defaults with no label_map.
    if dataset_class is BaseSegDataset:
        with pytest.raises(AssertionError):
            custom_dataset = dataset_class(
                ann_file=ann_file,
                data_prefix=dict(img_path=img_path),
                metainfo=None,
                test_mode=True,
                lazy_init=True)
    else:
        custom_dataset = dataset_class(
            data_prefix=dict(img_path=img_path),
            ann_file=ann_file,
            metainfo=None,
            test_mode=True,
            lazy_init=True)

        assert custom_dataset.METAINFO['classes'] == original_classes
        assert custom_dataset.label_map is None
|
|
|
|
|
|
def test_custom_dataset_random_palette_is_generated():
    """When only classes are given, a per-class RGB palette is generated."""
    dataset = BaseSegDataset(
        pipeline=[],
        data_prefix=dict(img_path=tempfile.mkdtemp()),
        ann_file=tempfile.mkdtemp(),
        metainfo=dict(classes=('bus', 'car')),
        lazy_init=True,
        test_mode=True)
    palette = dataset.metainfo['palette']
    # one color per class, each a valid 8-bit RGB triple
    assert len(palette) == 2
    for color in palette:
        assert len(color) == 3
        assert all(0 <= channel <= 255 for channel in color)
|
|
|
|
|
|
def test_custom_dataset_custom_palette():
    """An explicit palette is kept verbatim and must match the classes."""
    custom_palette = [[100, 100, 100], [200, 200, 200]]
    dataset = BaseSegDataset(
        data_prefix=dict(img_path=tempfile.mkdtemp()),
        ann_file=tempfile.mkdtemp(),
        metainfo=dict(classes=('bus', 'car'), palette=custom_palette),
        lazy_init=True,
        test_mode=True)
    assert tuple(dataset.metainfo['palette']) == tuple(custom_palette)

    # test custom class and palette don't match
    with pytest.raises(ValueError):
        dataset = BaseSegDataset(
            data_prefix=dict(img_path=tempfile.mkdtemp()),
            ann_file=tempfile.mkdtemp(),
            metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
            lazy_init=True)
|
|
|
|
|
|
def test_dsdlseg_dataset():
    """Smoke-test DSDLSegDataset when the optional `dsdl` package exists."""
    if DSDLDataset is None:
        # BUG FIX: the original created ``ImportWarning(...)`` without ever
        # raising or emitting it — a dead statement. Actually emit the
        # warning so the skipped coverage is visible in the test output.
        import warnings
        warnings.warn('Package `dsdl` is not installed.', ImportWarning)
        return

    dataset = DSDLSegDataset(
        data_root='tests/data/dsdl_seg', ann_file='set-train/train.yaml')
    assert len(dataset) == 3
    assert len(dataset.metainfo['classes']) == 21
|