# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock

import pytest

from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
                            COCOStuffDataset, DecathlonDataset, ISPRSDataset,
                            LIPDataset, LoveDADataset, PascalVOCDataset,
                            PotsdamDataset, iSAIDDataset)
from mmseg.registry import DATASETS
from mmseg.utils import get_classes, get_palette


def test_classes():
    assert list(
        CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
    assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
        'voc') == get_classes('pascal_voc')
    assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
        'ade') == get_classes('ade20k')
    assert list(
        COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
    assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
    assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
    assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
    assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')

    with pytest.raises(ValueError):
        get_classes('unsupported')


def test_classes_file_path():
    tmp_file = tempfile.NamedTemporaryFile()
    classes_path = f'{tmp_file.name}.txt'
    train_pipeline = []
    kwargs = dict(
        pipeline=train_pipeline,
        data_prefix=dict(img_path='./', seg_map_path='./'),
        metainfo=dict(classes=classes_path))

    # classes.txt with full categories
    categories = get_classes('cityscapes')
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    dataset = CityscapesDataset(**kwargs)
    assert list(dataset.metainfo['classes']) == categories
    assert dataset.label_map is None

    # classes.txt with sub categories
    categories = ['road', 'sidewalk', 'building']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    dataset = CityscapesDataset(**kwargs)
    assert list(dataset.metainfo['classes']) == categories
    assert dataset.label_map is not None

    # classes.txt with unknown categories
    categories = ['road', 'sidewalk', 'unknown']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    with pytest.raises(ValueError):
        CityscapesDataset(**kwargs)

    tmp_file.close()
    os.remove(classes_path)
    assert not osp.exists(classes_path)


def test_palette():
    assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
    assert PascalVOCDataset.METAINFO['palette'] == get_palette(
        'voc') == get_palette('pascal_voc')
    assert ADE20KDataset.METAINFO['palette'] == get_palette(
        'ade') == get_palette('ade20k')
    assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
    assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
    assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
    assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')

    with pytest.raises(ValueError):
        get_palette('unsupported')


def test_custom_dataset():
    # with 'img_path' and 'seg_map_path' in data_prefix
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        ann_file='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
    # abs path
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        img_suffix='img.jpg',
        test_mode=True,
        metainfo=dict(classes=('pseudo_class', )))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)
    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
    assert 'seg_map_path' in train_data and osp.isfile(
        train_data['seg_map_path'])

    # test data get (the test_mode dataset above is built from images only,
    # so there is no segmentation map to check)
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    assert 'img_path' in test_data and osp.isfile(test_data['img_path'])


def test_ade():
    test_dataset = ADE20KDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
    assert len(test_dataset) == 5


def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/gtFine/val')))
    assert len(test_dataset) == 1


def test_loveda():
    test_dataset = LoveDADataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_loveda_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_loveda_dataset/ann_dir')))
    assert len(test_dataset) == 3


def test_potsdam():
    test_dataset = PotsdamDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_potsdam_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_potsdam_dataset/ann_dir')))
    assert len(test_dataset) == 1


def test_vaihingen():
    test_dataset = ISPRSDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_vaihingen_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_vaihingen_dataset/ann_dir')))
    assert len(test_dataset) == 1


def test_isaid():
    test_dataset = iSAIDDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/ann_dir')))
    assert len(test_dataset) == 2

    test_dataset = iSAIDDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/ann_dir')),
        ann_file=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_isaid_dataset/splits/train.txt'))
    assert len(test_dataset) == 1


def test_decathlon():
    data_root = osp.join(osp.dirname(__file__), '../data')
    # test load training dataset
    test_dataset = DecathlonDataset(
        pipeline=[],
        data_root=data_root,
        ann_file='dataset.json')
    assert len(test_dataset) == 1

    # test load test dataset
    test_dataset = DecathlonDataset(
        pipeline=[],
        data_root=data_root,
        ann_file='dataset.json',
        test_mode=True)
    assert len(test_dataset) == 3


def test_lip():
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_lip_dataset')
    # test load training dataset
    train_dataset = LIPDataset(
        pipeline=[],
        data_root=data_root,
        data_prefix=dict(
            img_path='train_images', seg_map_path='train_segmentations'))
    assert len(train_dataset) == 1

    # test load validation dataset
    test_dataset = LIPDataset(
        pipeline=[],
        data_root=data_root,
        data_prefix=dict(
            img_path='val_images', seg_map_path='val_segmentations'))
    assert len(test_dataset) == 1


@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('BaseSegDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    dataset_class = DATASETS.get(dataset)
    # use a txt-style ann_file for PascalVOCDataset and a mock for the rest
    if dataset_class is PascalVOCDataset:
        tmp_file = tempfile.NamedTemporaryFile()
        ann_file = f'{tmp_file.name}.txt'
    else:
        ann_file = MagicMock()

    original_classes = dataset_class.METAINFO.get('classes', None)

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=MagicMock()),
        ann_file=ann_file,
        metainfo=dict(classes=classes),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == classes
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test setting classes as a list
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=MagicMock()),
        ann_file=ann_file,
        metainfo=dict(classes=list(classes)),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == list(classes)
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test overriding with a single class
    custom_dataset = dataset_class(
        ann_file=ann_file,
        data_prefix=dict(img_path=MagicMock()),
        metainfo=dict(classes=[classes[0]]),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == [classes[0]]
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test default behavior
    if dataset_class is BaseSegDataset:
        with pytest.raises(AssertionError):
            custom_dataset = dataset_class(
                ann_file=ann_file,
                data_prefix=dict(img_path=MagicMock()),
                metainfo=None,
                test_mode=True,
                lazy_init=True)
    else:
        custom_dataset = dataset_class(
            data_prefix=dict(img_path=MagicMock()),
            ann_file=ann_file,
            metainfo=None,
            test_mode=True,
            lazy_init=True)

        assert custom_dataset.METAINFO['classes'] == original_classes
        assert custom_dataset.label_map is None


def test_custom_dataset_random_palette_is_generated():
    dataset = BaseSegDataset(
        pipeline=[],
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=dict(classes=('bus', 'car')),
        lazy_init=True,
        test_mode=True)
    assert len(dataset.metainfo['palette']) == 2
    for class_color in dataset.metainfo['palette']:
        assert len(class_color) == 3
        assert all(x >= 0 and x <= 255 for x in class_color)


def test_custom_dataset_custom_palette():
    dataset = BaseSegDataset(
        data_prefix=dict(img_path=MagicMock()),
        ann_file=MagicMock(),
        metainfo=dict(
            classes=('bus', 'car'),
            palette=[[100, 100, 100],
                     [200, 200, 200]]),
        lazy_init=True,
        test_mode=True)
    assert tuple(dataset.metainfo['palette']) == tuple(
        [[100, 100, 100], [200, 200, 200]])

    # test custom class and palette don't match
    with pytest.raises(ValueError):
        dataset = BaseSegDataset(
            data_prefix=dict(img_path=MagicMock()),
            ann_file=MagicMock(),
            metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
            lazy_init=True)
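

# ---------------------------------------------------------------------------
# The helper below is not part of the original test suite; it is a minimal,
# illustrative sketch of how the BaseSegDataset arguments exercised above fit
# together. It reuses the pseudo dataset layout assumed by
# ``test_custom_dataset`` (``imgs/*img.jpg`` images with matching
# ``gts/*gt.png`` maps); the helper name is arbitrary and pytest does not
# collect it.
def _example_iterate_pseudo_dataset():
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    dataset = BaseSegDataset(
        data_root=data_root,
        data_prefix=dict(img_path='imgs/', seg_map_path='gts/'),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    # with the default (empty) pipeline each sample is a plain dict that
    # carries the image path and the matching segmentation map path
    samples = [dataset[idx] for idx in range(len(dataset))]
    return [(s['img_path'], s['seg_map_path']) for s in samples]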