mmengine/tests/test_data/test_base_dataset.py

656 lines
26 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest.mock import MagicMock
import pytest
import torch
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import TRANSFORMS
def function_pipeline(data_info):
    """Identity transform used to test ``Compose`` with a plain function.

    Args:
        data_info: The data handed to the transform.

    Returns:
        The very same ``data_info`` object, untouched.
    """
    return data_info
@TRANSFORMS.register_module()
class CallableTransform:
    """Registered transform whose instances act as an identity function."""

    def __call__(self, data_info):
        """Return ``data_info`` unchanged."""
        return data_info
@TRANSFORMS.register_module()
class NotCallableTransform:
    """Registered class that deliberately defines no ``__call__``.

    ``test_compose`` expects ``Compose`` to raise ``TypeError`` when it is
    configured with this type.
    """
    pass
class TestBaseDataset:
    """Tests for ``BaseDataset``, ``Compose`` and ``force_full_init``.

    ``META`` and ``parse_annotations`` are patched directly on the dataset
    class, and some tests (e.g. ``test_init``) leave ``parse_annotations``
    returning deliberately invalid data. Every test that builds a dataset
    therefore calls ``_init_dataset()`` first so that it does not depend on
    the order in which the tests run.
    """
    dataset_type = BaseDataset
    data_info = dict(
        filename='test_img.jpg', height=604, width=640, sample_idx=0)
    imgs = torch.rand((2, 3, 32, 32))
    pipeline = MagicMock(return_value=dict(imgs=imgs))
    META: dict = dict()
    parse_annotations = MagicMock(return_value=data_info)

    def _init_dataset(self):
        """Reset the class-level ``META`` / ``parse_annotations`` patches."""
        self.dataset_type.META = self.META
        self.dataset_type.parse_annotations = self.parse_annotations

    def _build_dataset(self, **kwargs):
        """Instantiate ``self.dataset_type`` with the shared test defaults.

        ``data_root``, ``data_prefix`` and ``ann_file`` default to the values
        used by most tests; any keyword argument overrides or extends them.

        Returns:
            The dataset returned by ``self.dataset_type(**kwargs)``.
        """
        kwargs.setdefault(
            'data_root', osp.join(osp.dirname(__file__), '../data/'))
        kwargs.setdefault('data_prefix', dict(img='imgs'))
        kwargs.setdefault('ann_file', 'annotations/dummy_annotation.json')
        return self.dataset_type(**kwargs)

    def test_init(self):
        self._init_dataset()
        # test the instantiation of self.base_dataset
        dataset = self._build_dataset()
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_infos')
        assert hasattr(dataset, 'data_address')

        dataset = self._build_dataset(data_prefix=dict(img=None))
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_infos')
        assert hasattr(dataset, 'data_address')

        # test the instantiation of self.base_dataset with
        # `serialize_data=False`
        dataset = self._build_dataset(serialize_data=False)
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_infos')
        assert not hasattr(dataset, 'data_address')
        assert len(dataset) == 2
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset with lazy init
        dataset = self._build_dataset(lazy_init=True)
        assert not dataset._fully_initialized
        assert not dataset.data_infos

        # test the instantiation of self.base_dataset if ann_file does not
        # exist.
        with pytest.raises(FileNotFoundError):
            self._build_dataset(
                ann_file='annotations/not_existed_annotation.json')

        # test the instantiation of self.base_dataset when the ann_file is
        # wrong
        with pytest.raises(ValueError):
            self._build_dataset(
                ann_file='annotations/annotation_wrong_keys.json')
        with pytest.raises(TypeError):
            self._build_dataset(
                ann_file='annotations/annotation_wrong_format.json')
        with pytest.raises(TypeError):
            self._build_dataset(
                data_prefix=dict(img=['img']),
                ann_file='annotations/annotation_wrong_format.json')

        # test the instantiation of self.base_dataset when `parse_annotations`
        # return `list[dict]`
        self.dataset_type.parse_annotations = MagicMock(
            return_value=[self.data_info,
                          self.data_info.copy()])
        dataset = self._build_dataset()
        dataset.pipeline = self.pipeline
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_infos')
        assert hasattr(dataset, 'data_address')
        assert len(dataset) == 4
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset when `parse_annotations`
        # return unsupported data. Only the constructor call sits inside
        # `pytest.raises`, so the expected error cannot come from the patching.
        self.dataset_type.parse_annotations = MagicMock(return_value='xxx')
        with pytest.raises(TypeError):
            self._build_dataset()
        self.dataset_type.parse_annotations = MagicMock(
            return_value=[self.data_info, 'xxx'])
        with pytest.raises(TypeError):
            self._build_dataset()

    def test_meta(self):
        self._init_dataset()
        # test dataset.meta with setting the meta from annotation file as the
        # meta of self.base_dataset
        dataset = self._build_dataset()
        assert dataset.meta == dict(
            dataset_type='test_dataset', task_name='test_task', empty_list=[])

        # test dataset.meta with setting META in self.base_dataset
        dataset_type = 'new_dataset'
        self.dataset_type.META = dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))
        dataset = self._build_dataset()
        assert dataset.meta == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=('dog', 'cat'),
            empty_list=[])

        # test dataset.meta with passing meta into self.base_dataset
        meta = dict(classes=('dog', ), task_name='new_task')
        dataset = self._build_dataset(meta=meta)
        assert self.dataset_type.META == dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))
        assert dataset.meta == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])

        # reset `base_dataset.META`, the `dataset.meta` should not change
        self.dataset_type.META['classes'] = ('dog', 'cat', 'fish')
        assert self.dataset_type.META == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
        assert dataset.meta == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])

        # test dataset.meta with passing meta containing a file into
        # self.base_dataset
        meta = dict(
            classes=osp.join(
                osp.dirname(__file__), '../data/meta/classes.txt'))
        dataset = self._build_dataset(meta=meta)
        assert dataset.meta == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=['dog'],
            empty_list=[])

        # test dataset.meta with passing unsupported meta into
        # self.base_dataset
        with pytest.raises(TypeError):
            self._build_dataset(meta='dog')

        # test dataset.meta with passing meta into self.base_dataset and
        # lazy_init is True
        meta = dict(classes=('dog', ))
        dataset = self._build_dataset(meta=meta, lazy_init=True)
        # 'task_name' and 'empty_list' not in dataset.meta
        assert dataset.meta == dict(
            dataset_type=dataset_type, classes=('dog', ))

        # test whether self.base_dataset.META is changed when a customize
        # dataset inherit self.base_dataset
        # test reset META in ToyDataset.
        class ToyDataset(self.dataset_type):
            META = dict(xxx='xxx')

        assert ToyDataset.META == dict(xxx='xxx')
        assert self.dataset_type.META == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

        # test update META in ToyDataset.
        class ToyDataset(self.dataset_type):
            META = copy.deepcopy(self.dataset_type.META)
            META['classes'] = ('bird', )

        assert ToyDataset.META == dict(
            dataset_type=dataset_type, classes=('bird', ))
        assert self.dataset_type.META == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_length(self, lazy_init):
        self._init_dataset()
        dataset = self._build_dataset(lazy_init=lazy_init)
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')
            assert len(dataset) == 2
        else:
            # test `__len__()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_infos
            # call `full_init()` automatically
            assert len(dataset) == 2
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')

    def test_compose(self):
        # test callable transform
        transforms = [function_pipeline]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test transform build from cfg_dict
        transforms = [dict(type='CallableTransform')]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test return None in advance
        none_func = MagicMock(return_value=None)
        transforms = [none_func, function_pipeline]
        compose = Compose(transforms=transforms)
        assert compose(dict(img=self.imgs)) is None
        # test repr
        # NOTE(review): the spacing inside these literals is reproduced as
        # found; confirm it matches the indentation `Compose.__repr__` emits.
        repr_str = (f'Compose(\n'
                    f' {none_func}\n'
                    f' {function_pipeline}\n'
                    f')')
        assert repr(compose) == repr_str
        # non-callable transform will raise error
        with pytest.raises(TypeError):
            transforms = [dict(type='NotCallableTransform')]
            Compose(transforms)

        # transform must be callable or dict
        with pytest.raises(TypeError):
            Compose([1])

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_getitem(self, lazy_init):
        self._init_dataset()
        dataset = self._build_dataset(lazy_init=lazy_init)
        dataset.pipeline = self.pipeline
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')
            assert dataset[0] == dict(imgs=self.imgs)
        else:
            # test `__getitem__()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_infos
            # call `full_init()` automatically
            assert dataset[0] == dict(imgs=self.imgs)
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')

        # test with test mode
        dataset.test_mode = True
        assert dataset[0] == dict(imgs=self.imgs)

        # a pipeline that always returns None never yields a valid sample
        dataset.pipeline = MagicMock(return_value=None)
        # test cannot get a valid image.
        dataset.test_mode = False
        with pytest.raises(Exception):
            dataset[0]

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_get_data_info(self, lazy_init):
        self._init_dataset()
        dataset = self._build_dataset(lazy_init=lazy_init)
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')
            assert dataset.get_data_info(0) == self.data_info
        else:
            # test `get_data_info()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_infos
            # call `full_init()` automatically
            assert dataset.get_data_info(0) == self.data_info
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')

    def test_force_full_init(self):
        # a class decorated with `force_full_init` but lacking `full_init`
        # is expected to fail with AttributeError
        with pytest.raises(AttributeError):

            class ClassWithoutFullInit:

                @force_full_init
                def foo(self):
                    pass

            class_without_full_init = ClassWithoutFullInit()
            class_without_full_init.foo()

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_full_init(self, lazy_init):
        self._init_dataset()
        dataset = self._build_dataset(lazy_init=lazy_init)
        dataset.pipeline = self.pipeline
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')
            assert len(dataset) == 2
            assert dataset[0] == dict(imgs=self.imgs)
            assert dataset.get_data_info(0) == self.data_info
        else:
            # test `full_init()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_infos
            # call `full_init()` manually
            dataset.full_init()
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_infos')
            assert len(dataset) == 2
            assert dataset[0] == dict(imgs=self.imgs)
            assert dataset.get_data_info(0) == self.data_info

    def test_slice_data(self):
        self._init_dataset()
        # test the instantiation of self.base_dataset when passing num_samples
        dataset = self._build_dataset(
            data_prefix=dict(img=None), num_samples=1)
        assert len(dataset) == 1

    def test_rand_another(self):
        self._init_dataset()
        # test that `_rand_another()` returns an index inside the dataset
        dataset = self._build_dataset(
            data_prefix=dict(img=None), num_samples=1)
        rand_idx = dataset._rand_another()
        assert 0 <= rand_idx < len(dataset)
class TestConcatDataset:
    """Tests for ``ConcatDataset`` built from two patched ``BaseDataset``s."""

    def _build_datasets(self):
        """Create the two source datasets used by every test.

        ``BaseDataset.parse_annotations`` is patched on the class before each
        construction, so ``dataset_a`` is built before the mock is replaced
        for ``dataset_b``.

        Returns:
            tuple: ``(dataset_a, dataset_b)``. Only ``dataset_b`` receives an
            explicit ``meta``.
        """
        data_root = osp.join(osp.dirname(__file__), '../data/')
        # create dataset_a
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        BaseDataset.parse_annotations = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        dataset_a = BaseDataset(
            data_root=data_root,
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

        # create dataset_b
        data_info = dict(filename='gray.jpg', height=288, width=512)
        BaseDataset.parse_annotations = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        dataset_b = BaseDataset(
            data_root=data_root,
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            meta=dict(classes=('dog', 'cat')))
        dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
        return dataset_a, dataset_b

    def _init_dataset(self):
        """Build both source datasets and their eager concatenation."""
        self.dataset_a, self.dataset_b = self._build_datasets()
        # test init
        self.cat_datasets = ConcatDataset(
            datasets=[self.dataset_a, self.dataset_b])

    def test_full_init(self):
        dataset_a, dataset_b = self._build_datasets()
        # test init with lazy_init=True
        cat_datasets = ConcatDataset(
            datasets=[dataset_a, dataset_b], lazy_init=True)
        cat_datasets.full_init()
        assert len(cat_datasets) == 4
        # calling `full_init()` a second time must leave the length unchanged
        cat_datasets.full_init()
        # clear the flag and index the dataset; presumably this sends
        # `__getitem__` through the initialization path again
        cat_datasets._fully_initialized = False
        cat_datasets[1]
        assert len(cat_datasets) == 4

    def test_meta(self):
        self._init_dataset()
        assert self.cat_datasets.meta == self.dataset_a.meta
        # meta of self.cat_datasets is from the first dataset when
        # concatenating datasets with different metas.
        assert self.cat_datasets.meta != self.dataset_b.meta

    def test_length(self):
        self._init_dataset()
        assert len(self.cat_datasets) == (
            len(self.dataset_a) + len(self.dataset_b))

    def test_getitem(self):
        self._init_dataset()
        assert (
            self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
        assert (self.cat_datasets[0]['imgs'] !=
                self.dataset_b[0]['imgs']).all()

        assert (
            self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
        assert (self.cat_datasets[-1]['imgs'] !=
                self.dataset_a[-1]['imgs']).all()

    def test_get_data_info(self):
        self._init_dataset()
        assert self.cat_datasets.get_data_info(
            0) == self.dataset_a.get_data_info(0)
        assert self.cat_datasets.get_data_info(
            0) != self.dataset_b.get_data_info(0)

        assert self.cat_datasets.get_data_info(
            -1) == self.dataset_b.get_data_info(-1)
        assert self.cat_datasets.get_data_info(
            -1) != self.dataset_a.get_data_info(-1)

    def test_get_ori_dataset_idx(self):
        self._init_dataset()
        assert self.cat_datasets._get_ori_dataset_idx(3) == (
            1, 3 - len(self.dataset_a))
        assert self.cat_datasets._get_ori_dataset_idx(-1) == (
            1, len(self.dataset_b) - 1)
        # an index beyond the total length is rejected
        with pytest.raises(ValueError):
            self.cat_datasets._get_ori_dataset_idx(-10)
class TestRepeatDataset:
    """Tests for ``RepeatDataset`` wrapping a patched ``BaseDataset``."""

    def _build_dataset(self):
        """Create the wrapped dataset with mocked parsing and pipeline.

        Returns:
            A ``BaseDataset`` built on the dummy annotation file, whose
            ``pipeline`` is a ``MagicMock`` returning a fixed ``imgs`` dict.
        """
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        BaseDataset.parse_annotations = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
        return dataset

    def _init_dataset(self):
        """Build the wrapped dataset and its eager ``RepeatDataset``."""
        self.dataset = self._build_dataset()
        self.repeat_times = 5
        # test init
        self.repeat_datasets = RepeatDataset(
            dataset=self.dataset, times=self.repeat_times)

    def test_full_init(self):
        dataset = self._build_dataset()
        repeat_times = 5
        # test init with lazy_init=True
        repeat_datasets = RepeatDataset(
            dataset=dataset, times=repeat_times, lazy_init=True)
        repeat_datasets.full_init()
        assert len(repeat_datasets) == repeat_times * len(dataset)
        # calling `full_init()` a second time must leave the length unchanged
        repeat_datasets.full_init()
        # clear the flag and index the dataset; presumably this sends
        # `__getitem__` through the initialization path again
        repeat_datasets._fully_initialized = False
        repeat_datasets[1]
        assert len(repeat_datasets) == repeat_times * len(dataset)

    def test_meta(self):
        self._init_dataset()
        assert self.repeat_datasets.meta == self.dataset.meta

    def test_length(self):
        self._init_dataset()
        assert len(
            self.repeat_datasets) == len(self.dataset) * self.repeat_times

    def test_getitem(self):
        self._init_dataset()
        # the first sample of every repetition maps back to sample 0
        for i in range(self.repeat_times):
            assert self.repeat_datasets[len(self.dataset) *
                                        i] == self.dataset[0]

    def test_get_data_info(self):
        self._init_dataset()
        for i in range(self.repeat_times):
            assert self.repeat_datasets.get_data_info(
                len(self.dataset) * i) == self.dataset.get_data_info(0)
class TestClassBalancedDataset:
    """Tests for ``ClassBalancedDataset`` wrapping a patched ``BaseDataset``.

    The tests overwrite ``repeat_indices`` on the wrapper with a fixed list
    so that the expected length and index mapping are known.
    """

    def _build_dataset(self):
        """Create the wrapped dataset with mocked parsing, cat ids, pipeline.

        Returns:
            A ``BaseDataset`` built on the dummy annotation file. Both
            ``parse_annotations`` and ``get_cat_ids`` are patched on the
            ``BaseDataset`` class before construction.
        """
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        BaseDataset.parse_annotations = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        BaseDataset.get_cat_ids = MagicMock(return_value=[0])
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
        return dataset

    def _init_dataset(self):
        """Build the wrapped dataset and its eager class-balanced wrapper."""
        self.dataset = self._build_dataset()
        self.repeat_indices = [0, 0, 1, 1, 1]
        # test init
        self.cls_balanced_datasets = ClassBalancedDataset(
            dataset=self.dataset, oversample_thr=1e-3)
        self.cls_balanced_datasets.repeat_indices = self.repeat_indices

    def test_full_init(self):
        dataset = self._build_dataset()
        repeat_indices = [0, 0, 1, 1, 1]
        # test init with lazy_init=True
        cls_balanced_datasets = ClassBalancedDataset(
            dataset=dataset, oversample_thr=1e-3, lazy_init=True)
        cls_balanced_datasets.full_init()
        cls_balanced_datasets.repeat_indices = repeat_indices
        assert len(cls_balanced_datasets) == len(repeat_indices)
        # calling `full_init()` a second time, then clear the flag and index
        # the dataset; presumably this sends `__getitem__` through the
        # initialization path again
        cls_balanced_datasets.full_init()
        cls_balanced_datasets._fully_initialized = False
        cls_balanced_datasets[1]
        # restore the fixed indices before checking the length again
        cls_balanced_datasets.repeat_indices = repeat_indices
        assert len(cls_balanced_datasets) == len(repeat_indices)

    def test_meta(self):
        self._init_dataset()
        assert self.cls_balanced_datasets.meta == self.dataset.meta

    def test_length(self):
        self._init_dataset()
        assert len(self.cls_balanced_datasets) == len(self.repeat_indices)

    def test_getitem(self):
        self._init_dataset()
        for i, ori_idx in enumerate(self.repeat_indices):
            assert self.cls_balanced_datasets[i] == self.dataset[ori_idx]

    def test_get_data_info(self):
        self._init_dataset()
        for i, ori_idx in enumerate(self.repeat_indices):
            assert self.cls_balanced_datasets.get_data_info(
                i) == self.dataset.get_data_info(ori_idx)

    def test_get_cat_ids(self):
        self._init_dataset()
        for i, ori_idx in enumerate(self.repeat_indices):
            assert self.cls_balanced_datasets.get_cat_ids(
                i) == self.dataset.get_cat_ids(ori_idx)