Peng Lu 788b37f78f
[Feature] Support NYU depth estimation dataset ()
Thanks for your contribution; we appreciate it a lot. The following
instructions will help keep your pull request healthy and make it
easier to get feedback. If you do not understand some items, don't
worry: just open the pull request and ask the maintainers for help.

## Motivation

Please describe the motivation for this PR and the goal you want to
achieve with it.

## Modification

Please briefly describe the modifications made in this PR.
1. add the `NYUDataset` class (see the usage sketch below)
2. add a script to preprocess the NYU dataset
3. add transforms for loading depth maps
4. add docs & unit tests
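
As a quick illustration, here is a minimal usage sketch of the new dataset class, adapted from the unit test in this PR (the `data_root` and directory names below are placeholders, not a prescribed layout). Note that depth estimation datasets take a `depth_map_path` prefix where segmentation datasets use `seg_map_path`:

```python
from mmseg.datasets import NYUDataset

# Placeholder layout: 'images/' holds the RGB frames and 'annotations/'
# holds the corresponding depth maps.
dataset = NYUDataset(
    data_root='data/nyu',
    data_prefix=dict(img_path='images', depth_map_path='annotations'),
)
sample = dataset[0]
assert 'depth_map_path' in sample  # each sample records its depth map path
```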

## BC-breaking (Optional)

Does the modification introduce changes that break the backward
compatibility of downstream repos?
If so, please describe how it breaks compatibility and how downstream
projects should modify their code to stay compatible with this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases
here and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit tests to ensure correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, e.g. docstrings or
example tutorials.
2023-08-17 11:39:44 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile

import pytest

from mmseg.datasets import (ADE20KDataset, BaseSegDataset, BDD100KDataset,
                            CityscapesDataset, COCOStuffDataset,
                            DecathlonDataset, DSDLSegDataset, ISPRSDataset,
                            LIPDataset, LoveDADataset, MapillaryDataset_v1,
                            MapillaryDataset_v2, NYUDataset, PascalVOCDataset,
                            PotsdamDataset, REFUGEDataset, SynapseDataset,
                            iSAIDDataset)
from mmseg.registry import DATASETS
from mmseg.utils import get_classes, get_palette

try:
    from dsdl.dataset import DSDLDataset
except ImportError:
    DSDLDataset = None


def test_classes():
    assert list(
        CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
    assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
        'voc') == get_classes('pascal_voc')
    assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
        'ade') == get_classes('ade20k')
    assert list(
        COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
    assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
    assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
    assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
    assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
    assert list(
        MapillaryDataset_v1.METAINFO['classes']) == get_classes('mapillary_v1')
    assert list(
        MapillaryDataset_v2.METAINFO['classes']) == get_classes('mapillary_v2')
    assert list(BDD100KDataset.METAINFO['classes']) == get_classes('bdd100k')

    with pytest.raises(ValueError):
        get_classes('unsupported')


def test_classes_file_path():
    tmp_file = tempfile.NamedTemporaryFile()
    classes_path = f'{tmp_file.name}.txt'
    train_pipeline = []
    kwargs = dict(
        pipeline=train_pipeline,
        data_prefix=dict(img_path='./', seg_map_path='./'),
        metainfo=dict(classes=classes_path))

    # classes.txt with full categories
    categories = get_classes('cityscapes')
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    dataset = CityscapesDataset(**kwargs)
    assert list(dataset.metainfo['classes']) == categories
    assert dataset.label_map is None

    # classes.txt with sub categories
    categories = ['road', 'sidewalk', 'building']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    dataset = CityscapesDataset(**kwargs)
    assert list(dataset.metainfo['classes']) == categories
    assert dataset.label_map is not None

    # classes.txt with unknown categories
    categories = ['road', 'sidewalk', 'unknown']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    with pytest.raises(ValueError):
        CityscapesDataset(**kwargs)

    tmp_file.close()
    os.remove(classes_path)
    assert not osp.exists(classes_path)


def test_palette():
    assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
    assert PascalVOCDataset.METAINFO['palette'] == get_palette(
        'voc') == get_palette('pascal_voc')
    assert ADE20KDataset.METAINFO['palette'] == get_palette(
        'ade') == get_palette('ade20k')
    assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
    assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
    assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
    assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
    assert list(
        MapillaryDataset_v1.METAINFO['palette']) == get_palette('mapillary_v1')
    assert list(
        MapillaryDataset_v2.METAINFO['palette']) == get_palette('mapillary_v2')
    assert list(BDD100KDataset.METAINFO['palette']) == get_palette('bdd100k')

    with pytest.raises(ValueError):
        get_palette('unsupported')


def test_custom_dataset():
    # with 'img_path' and 'seg_map_path' in data_prefix
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path='imgs/',
            seg_map_path='gts/',
        ),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        ann_file='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
    # abs path
    train_dataset = BaseSegDataset(
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
            seg_map_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = BaseSegDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        img_suffix='img.jpg',
        test_mode=True,
        metainfo=dict(classes=('pseudo_class', )))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)
    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
    assert 'seg_map_path' in train_data and osp.isfile(
        train_data['seg_map_path'])

    # test data get; the test dataset is built without a seg map prefix,
    # so only the image path is checked
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    assert 'img_path' in test_data and osp.isfile(test_data['img_path'])


def test_ade():
    test_dataset = ADE20KDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
    assert len(test_dataset) == 5


def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/leftImg8bit/val'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_cityscapes_dataset/gtFine/val')))
    assert len(test_dataset) == 1


def test_loveda():
    test_dataset = LoveDADataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_loveda_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_loveda_dataset/ann_dir')))
    assert len(test_dataset) == 3


def test_potsdam():
    test_dataset = PotsdamDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_potsdam_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_potsdam_dataset/ann_dir')))
    assert len(test_dataset) == 1


def test_vaihingen():
    test_dataset = ISPRSDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_vaihingen_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_vaihingen_dataset/ann_dir')))
    assert len(test_dataset) == 1


def test_synapse():
    test_dataset = SynapseDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_synapse_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_synapse_dataset/ann_dir')))
    assert len(test_dataset) == 2


def test_refuge():
    test_dataset = REFUGEDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_refuge_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_refuge_dataset/ann_dir')))
    assert len(test_dataset) == 1


def test_isaid():
    test_dataset = iSAIDDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/ann_dir')))
    assert len(test_dataset) == 2

    test_dataset = iSAIDDataset(
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_isaid_dataset/ann_dir')),
        ann_file=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_isaid_dataset/splits/train.txt'))
    assert len(test_dataset) == 1


def test_decathlon():
    data_root = osp.join(osp.dirname(__file__), '../data')
    # test load training dataset
    test_dataset = DecathlonDataset(
        pipeline=[], data_root=data_root, ann_file='dataset.json')
    assert len(test_dataset) == 1

    # test load test dataset
    test_dataset = DecathlonDataset(
        pipeline=[],
        data_root=data_root,
        ann_file='dataset.json',
        test_mode=True)
    assert len(test_dataset) == 3


def test_lip():
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_lip_dataset')
    # test loading the training split
    train_dataset = LIPDataset(
        pipeline=[],
        data_root=data_root,
        data_prefix=dict(
            img_path='train_images', seg_map_path='train_segmentations'))
    assert len(train_dataset) == 1

    # test loading the validation split
    test_dataset = LIPDataset(
        pipeline=[],
        data_root=data_root,
        data_prefix=dict(
            img_path='val_images', seg_map_path='val_segmentations'))
    assert len(test_dataset) == 1


def test_mapillary():
    test_dataset = MapillaryDataset_v1(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_mapillary_dataset/images'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_mapillary_dataset/v1.2')))
    assert len(test_dataset) == 1


def test_bdd100k():
    test_dataset = BDD100KDataset(
        pipeline=[],
        data_prefix=dict(
            img_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_bdd100k_dataset/images/10k/val'),
            seg_map_path=osp.join(
                osp.dirname(__file__),
                '../data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val')))
    assert len(test_dataset) == 3


@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('BaseSegDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.METAINFO.get('classes', None)

    tmp_file = tempfile.NamedTemporaryFile()
    ann_file = tmp_file.name
    img_path = tempfile.mkdtemp()

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=img_path),
        ann_file=ann_file,
        metainfo=dict(classes=classes),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == classes
    # BaseSegDataset has no default classes, so a label_map is only built
    # for the concrete dataset classes
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test setting classes as a list
    custom_dataset = dataset_class(
        data_prefix=dict(img_path=img_path),
        ann_file=ann_file,
        metainfo=dict(classes=list(classes)),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == list(classes)
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test overriding with a subset of the default classes
    custom_dataset = dataset_class(
        ann_file=ann_file,
        data_prefix=dict(img_path=img_path),
        metainfo=dict(classes=[classes[0]]),
        test_mode=True,
        lazy_init=True)

    assert custom_dataset.metainfo['classes'] != original_classes
    assert custom_dataset.metainfo['classes'] == [classes[0]]
    if dataset_class is not BaseSegDataset:
        assert isinstance(custom_dataset.label_map, dict)

    # Test default behavior
    if dataset_class is BaseSegDataset:
        with pytest.raises(AssertionError):
            custom_dataset = dataset_class(
                ann_file=ann_file,
                data_prefix=dict(img_path=img_path),
                metainfo=None,
                test_mode=True,
                lazy_init=True)
    else:
        custom_dataset = dataset_class(
            data_prefix=dict(img_path=img_path),
            ann_file=ann_file,
            metainfo=None,
            test_mode=True,
            lazy_init=True)

        assert custom_dataset.METAINFO['classes'] == original_classes
        assert custom_dataset.label_map is None


def test_custom_dataset_random_palette_is_generated():
    dataset = BaseSegDataset(
        pipeline=[],
        data_prefix=dict(img_path=tempfile.mkdtemp()),
        ann_file=tempfile.mkdtemp(),
        metainfo=dict(classes=('bus', 'car')),
        lazy_init=True,
        test_mode=True)
    assert len(dataset.metainfo['palette']) == 2
    for class_color in dataset.metainfo['palette']:
        assert len(class_color) == 3
        assert all(x >= 0 and x <= 255 for x in class_color)


def test_custom_dataset_custom_palette():
    dataset = BaseSegDataset(
        data_prefix=dict(img_path=tempfile.mkdtemp()),
        ann_file=tempfile.mkdtemp(),
        metainfo=dict(
            classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200,
                                                               200]]),
        lazy_init=True,
        test_mode=True)
    assert tuple(dataset.metainfo['palette']) == tuple([[100, 100, 100],
                                                        [200, 200, 200]])

    # test custom class and palette don't match
    with pytest.raises(ValueError):
        dataset = BaseSegDataset(
            data_prefix=dict(img_path=tempfile.mkdtemp()),
            ann_file=tempfile.mkdtemp(),
            metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
            lazy_init=True)


def test_dsdlseg_dataset():
    if DSDLDataset is None:
        # ``dsdl`` is an optional dependency; skip the test when it is absent
        pytest.skip('Package `dsdl` is not installed.')
    dataset = DSDLSegDataset(
        data_root='tests/data/dsdl_seg', ann_file='set-train/train.yaml')
    assert len(dataset) == 3
    assert len(dataset.metainfo['classes']) == 21


def test_nyu_dataset():
    # depth estimation datasets use ``depth_map_path`` where segmentation
    # datasets use ``seg_map_path``
    dataset = NYUDataset(
        data_root='tests/data/pseudo_nyu_dataset',
        data_prefix=dict(img_path='images', depth_map_path='annotations'),
    )
    assert len(dataset) == 1

    data = dataset[0]
    assert data.get('depth_map_path', None) is not None
    assert data.get('category_id', -1) == 26