# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest.mock import MagicMock

import pytest
import torch

from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
                              ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS


def function_pipeline(data_info):
    return data_info


@TRANSFORMS.register_module()
class CallableTransform:

    def __call__(self, data_info):
        return data_info


@TRANSFORMS.register_module()
class NotCallableTransform:
    pass


@DATASETS.register_module()
class CustomDataset(BaseDataset):
    pass


class TestBaseDataset:

    def setup(self):
        self.data_info = dict(
            filename='test_img.jpg', height=604, width=640, sample_idx=0)
        self.imgs = torch.rand((2, 3, 32, 32))
        self.ori_meta = BaseDataset.METAINFO
        self.ori_parse_data_info = BaseDataset.parse_data_info
        BaseDataset.parse_data_info = MagicMock(return_value=self.data_info)
        self.pipeline = MagicMock(return_value=dict(imgs=self.imgs))

    def teardown(self):
        BaseDataset.METAINFO = self.ori_meta
        BaseDataset.parse_data_info = self.ori_parse_data_info

    def test_init(self):
        # test the instantiation of self.base_dataset
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json')
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')

        # test the instantiation of self.base_dataset with
        # `serialize_data=False`
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            serialize_data=False)
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert not hasattr(dataset, 'data_address')
        assert len(dataset) == 3
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset with lazy init
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
        assert not dataset._fully_initialized
        assert not dataset.data_list

        # test the instantiation of self.base_dataset if ann_file does not
        # exist.
        with pytest.raises(FileNotFoundError):
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/not_existed_annotation.json')
        # Use the default value of ann_file, i.e., ''
        with pytest.raises(TypeError):
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'))

        # test the instantiation of self.base_dataset when the ann_file is
        # wrong
        with pytest.raises(ValueError):
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/annotation_wrong_keys.json')
        with pytest.raises(TypeError):
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/annotation_wrong_format.json')
        with pytest.raises(TypeError):
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img=['img']),
                ann_file='annotations/annotation_wrong_format.json')

        # test the instantiation of self.base_dataset when `parse_data_info`
        # returns `list[dict]`
        BaseDataset.parse_data_info = MagicMock(
            return_value=[self.data_info,
                          self.data_info.copy()])
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset.pipeline = self.pipeline
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')
        assert len(dataset) == 6
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset when `parse_data_info`
        # returns unsupported data.
        with pytest.raises(TypeError):
            BaseDataset.parse_data_info = MagicMock(return_value='xxx')
            dataset = BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json')
        with pytest.raises(TypeError):
            BaseDataset.parse_data_info = MagicMock(
                return_value=[self.data_info, 'xxx'])
            BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json')
        # test the instantiation of self.base_dataset without `ann_file`
        BaseDataset.parse_data_info = self.ori_parse_data_info
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='',
            serialize_data=False,
            lazy_init=True)
        assert not dataset.ann_file

    def test_meta(self):
        # test dataset.metainfo with setting the metainfo from annotation file
        # as the metainfo of self.base_dataset.
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')

        assert dataset.metainfo == dict(
            dataset_type='test_dataset', task_name='test_task', empty_list=[])

        # test dataset.metainfo with setting METAINFO in self.base_dataset
        dataset_type = 'new_dataset'
        BaseDataset.METAINFO = dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))

        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=('dog', 'cat'),
            empty_list=[])

        # test dataset.metainfo with passing metainfo into self.base_dataset
        metainfo = dict(classes=('dog', ), task_name='new_task')
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo)
        assert BaseDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])
        # reset `base_dataset.METAINFO`; `dataset.metainfo` should not
        # change
        BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
        assert BaseDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])

        # test dataset.metainfo with passing metainfo containing a file into
        # self.base_dataset
        metainfo = dict(
            classes=osp.join(
                osp.dirname(__file__), '../data/meta/classes.txt'))
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo)
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=['dog'],
            empty_list=[])

        # test dataset.metainfo with passing unsupported metainfo into
        # self.base_dataset
        with pytest.raises(TypeError):
            metainfo = 'dog'
            dataset = BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json',
                metainfo=metainfo)

        # test dataset.metainfo with passing metainfo into self.base_dataset
        # and lazy_init is True
        metainfo = dict(classes=('dog', ))
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo,
            lazy_init=True)
        # 'task_name' and 'empty_list' are not in dataset.metainfo
        assert dataset.metainfo == dict(
            dataset_type=dataset_type, classes=('dog', ))

        # test whether self.base_dataset.METAINFO is changed when a custom
        # dataset inherits self.base_dataset
        # test reset METAINFO in ToyDataset.
        class ToyDataset(BaseDataset):
            METAINFO = dict(xxx='xxx')

        assert ToyDataset.METAINFO == dict(xxx='xxx')
        assert BaseDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

        # test update METAINFO in ToyDataset.
        class ToyDataset(BaseDataset):
            METAINFO = copy.deepcopy(BaseDataset.METAINFO)
            METAINFO['classes'] = ('bird', )

        assert ToyDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('bird', ))
        assert BaseDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_length(self, lazy_init):
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert len(dataset) == 3
        else:
            # test `__len__()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_list
            # call `full_init()` automatically
            assert len(dataset) == 3
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')

    def test_compose(self):
        # test callable transform
        transforms = [function_pipeline]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test transform built from cfg_dict
        transforms = [dict(type='CallableTransform')]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test returning None in advance
        none_func = MagicMock(return_value=None)
        transforms = [none_func, function_pipeline]
        compose = Compose(transforms=transforms)
        assert compose(dict(img=self.imgs)) is None
        # test repr
        repr_str = f'Compose(\n' \
                   f'    {none_func}\n' \
                   f'    {function_pipeline}\n' \
                   f')'
        assert repr(compose) == repr_str
        # a non-callable transform will raise an error
        with pytest.raises(TypeError):
            transforms = [dict(type='NotCallableTransform')]
            Compose(transforms)

        # transform must be callable or dict
        with pytest.raises(TypeError):
            Compose([1])

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_getitem(self, lazy_init):
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
        dataset.pipeline = self.pipeline
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert dataset[0] == dict(imgs=self.imgs)
        else:
            # Test `__getitem__()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_list
            # Call `full_init()` automatically
            assert dataset[0] == dict(imgs=self.imgs)
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')

        # Test with test mode
        dataset.test_mode = False
        assert dataset[0] == dict(imgs=self.imgs)
        # Test the case where a valid image cannot be obtained.
        dataset.prepare_data = MagicMock(return_value=None)
        with pytest.raises(Exception):
            dataset[0]
        # Test getting a valid image via `_rand_another`

        def fake_prepare_data(idx):
            if idx == 0:
                return None
            else:
                return 1

        dataset.prepare_data = fake_prepare_data
        dataset[0]
        dataset.test_mode = True
        with pytest.raises(Exception):
            dataset[0]

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_get_data_info(self, lazy_init):
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)

        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert dataset.get_data_info(0) == self.data_info
        else:
            # test `get_data_info()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_list
            # call `full_init()` automatically
            assert dataset.get_data_info(0) == self.data_info
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
        # Test parse_data_info with `data_prefix`
        BaseDataset.parse_data_info = self.ori_parse_data_info
        data_root = osp.join(osp.dirname(__file__), '../data/')
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img_path='imgs'),
            ann_file='annotations/dummy_annotation.json')
        data_info = dataset.get_data_info(0)
        assert data_info['img_path'] == osp.join(data_root, 'imgs',
                                                 'test_img.jpg')

    def test_force_full_init(self):
        with pytest.raises(AttributeError):

            class ClassWithoutFullInit:

                @force_full_init
                def foo(self):
                    pass

            class_without_full_init = ClassWithoutFullInit()
            class_without_full_init.foo()

    def test_full_init(self):
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
        dataset.pipeline = self.pipeline
        # test `full_init()` when lazy_init is True
        assert not dataset._fully_initialized
        assert not dataset.data_list
        # call `full_init()` manually
        dataset.full_init()
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert len(dataset) == 3
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=False)

        dataset.pipeline = self.pipeline
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert len(dataset) == 3
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset when passing indices
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json')
        dataset_sliced = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            indices=1)
        assert dataset_sliced[0] == dataset[0]
        assert len(dataset_sliced) == 1

    @pytest.mark.parametrize(
        'lazy_init, serialize_data',
        ([True, False], [False, True], [True, True], [False, False]))
    def test_get_subset_(self, lazy_init, serialize_data):
        # Test positive int indices.
        indices = 2
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init,
            serialize_data=serialize_data)

        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 2
        for i in range(len(dataset_copy)):
            ori_data = dataset[i]
            assert dataset_copy[i] == ori_data

        # Test negative int indices.
        indices = -2
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 2
        for i in range(len(dataset_copy)):
            ori_data = dataset[i + 1]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data

        # If indices is 0, return an empty dataset.
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(0)
        assert len(dataset_copy) == 0

        # Test list indices with a positive element.
        indices = [1]
        dataset_copy = copy.deepcopy(dataset)
        ori_data = dataset[1]
        ori_data['sample_idx'] = 0
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 1
        assert dataset_copy[0] == ori_data

        # Test list indices with a negative element.
        indices = [-1]
        dataset_copy = copy.deepcopy(dataset)
        ori_data = dataset[2]
        ori_data['sample_idx'] = 0
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 1
        assert dataset_copy[0] == ori_data

        # Test empty list.
        indices = []
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 0
        # Test list with multiple positive indices.
        indices = [0, 1, 2]
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        for i in range(len(dataset_copy)):
            ori_data = dataset[i]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data
        # Test list with multiple negative indices.
        indices = [-1, -2, 0]
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        for i in range(len(dataset_copy)):
            ori_data = dataset[len(dataset) - i - 1]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data

        with pytest.raises(TypeError):
            dataset.get_subset_(dict())

    @pytest.mark.parametrize(
        'lazy_init, serialize_data',
        ([True, False], [False, True], [True, True], [False, False]))
    def test_get_subset(self, lazy_init, serialize_data):
        # Test positive indices.
        indices = 2
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init,
            serialize_data=serialize_data)
        dataset_sliced = dataset.get_subset(indices)
        assert len(dataset_sliced) == 2
        assert dataset_sliced[0] == dataset[0]
        for i in range(len(dataset_sliced)):
            assert dataset_sliced[i] == dataset[i]
        # Test negative indices.
        indices = -2
        dataset_sliced = dataset.get_subset(indices)
        assert len(dataset_sliced) == 2
        for i in range(len(dataset_sliced)):
            ori_data = dataset[i + 1]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data
        # If indices is 0 or an empty list, return an empty dataset.
        assert len(dataset.get_subset(0)) == 0
        assert len(dataset.get_subset([])) == 0
        # test list indices.
        indices = [1]
        dataset_sliced = dataset.get_subset(indices)
        ori_data = dataset[1]
        ori_data['sample_idx'] = 0
        assert len(dataset_sliced) == 1
        assert dataset_sliced[0] == ori_data
        # Test list with multiple positive indices.
        indices = [0, 1, 2]
        dataset_sliced = dataset.get_subset(indices)
        for i in range(len(dataset_sliced)):
            ori_data = dataset[i]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data
        # Test list with multiple negative indices.
        indices = [-1, -2, 0]
        dataset_sliced = dataset.get_subset(indices)
        for i in range(len(dataset_sliced)):
            ori_data = dataset[len(dataset) - i - 1]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data

    def test_rand_another(self):
        # test the instantiation of self.base_dataset when passing indices
        dataset = BaseDataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            indices=1)
        assert dataset._rand_another() >= 0
        assert dataset._rand_another() < len(dataset)


class TestConcatDataset:

    def setup(self):
        dataset = BaseDataset

        # create dataset_a
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))

        self.dataset_a = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

        # create dataset_b
        data_info = dict(filename='gray.jpg', height=288, width=512)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset_b = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
        # test init
        self.cat_datasets = ConcatDataset(
            datasets=[self.dataset_a, self.dataset_b])

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg_b = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
        cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
        assert len(cat_datasets) == len(self.cat_datasets)
        for i in range(len(cat_datasets)):
            assert (cat_datasets.get_data_info(i) ==
                    self.cat_datasets.get_data_info(i))
            assert (cat_datasets[i] == self.cat_datasets[i])

        with pytest.raises(TypeError):
            ConcatDataset(datasets=[0])

    def test_full_init(self):
        # test init with lazy_init=True
        self.cat_datasets.full_init()
        assert len(self.cat_datasets) == 6
        self.cat_datasets.full_init()
        self.cat_datasets._fully_initialized = False
        self.cat_datasets[1]
        assert len(self.cat_datasets) == 6

        with pytest.raises(NotImplementedError):
            self.cat_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.cat_datasets.get_subset(1)
        # Different meta information will raise an error.
        with pytest.raises(ValueError):
            dataset_b = BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json',
                metainfo=dict(classes=('cat')))
            ConcatDataset(datasets=[self.dataset_a, dataset_b])

    def test_metainfo(self):
        assert self.cat_datasets.metainfo == self.dataset_a.metainfo

    def test_length(self):
        assert len(self.cat_datasets) == (
            len(self.dataset_a) + len(self.dataset_b))

    def test_getitem(self):
        assert (
            self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
        assert (self.cat_datasets[0]['imgs'] !=
                self.dataset_b[0]['imgs']).all()

        assert (
            self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
        assert (self.cat_datasets[-1]['imgs'] !=
                self.dataset_a[-1]['imgs']).all()

    def test_get_data_info(self):
        assert self.cat_datasets.get_data_info(
            0) == self.dataset_a.get_data_info(0)
        assert self.cat_datasets.get_data_info(
            0) != self.dataset_b.get_data_info(0)

        assert self.cat_datasets.get_data_info(
            -1) == self.dataset_b.get_data_info(-1)
        assert self.cat_datasets.get_data_info(
            -1) != self.dataset_a.get_data_info(-1)

    def test_get_ori_dataset_idx(self):
        assert self.cat_datasets._get_ori_dataset_idx(3) == (
            1, 3 - len(self.dataset_a))
        assert self.cat_datasets._get_ori_dataset_idx(-1) == (
            1, len(self.dataset_b) - 1)
        with pytest.raises(ValueError):
            assert self.cat_datasets._get_ori_dataset_idx(-10)


class TestRepeatDataset:

    def setup(self):
        dataset = BaseDataset
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

        self.repeat_times = 5
        # test init
        self.repeat_datasets = RepeatDataset(
            dataset=self.dataset, times=self.repeat_times)

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        repeat_dataset = RepeatDataset(
            dataset=dataset_cfg, times=self.repeat_times)
        repeat_dataset.dataset.pipeline = self.dataset.pipeline
        assert len(repeat_dataset) == len(self.repeat_datasets)
        for i in range(len(repeat_dataset)):
            assert (repeat_dataset.get_data_info(i) ==
                    self.repeat_datasets.get_data_info(i))
            assert (repeat_dataset[i] == self.repeat_datasets[i])

        with pytest.raises(TypeError):
            RepeatDataset(dataset=[0], times=5)

    def test_full_init(self):
        self.repeat_datasets.full_init()
        assert len(
            self.repeat_datasets) == self.repeat_times * len(self.dataset)
        self.repeat_datasets.full_init()
        self.repeat_datasets._fully_initialized = False
        self.repeat_datasets[1]
        assert len(self.repeat_datasets) == \
            self.repeat_times * len(self.dataset)

        with pytest.raises(NotImplementedError):
            self.repeat_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.repeat_datasets.get_subset(1)

    def test_metainfo(self):
        assert self.repeat_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(
            self.repeat_datasets) == len(self.dataset) * self.repeat_times

    def test_getitem(self):
        for i in range(self.repeat_times):
            assert self.repeat_datasets[len(self.dataset) *
                                        i] == self.dataset[0]

    def test_get_data_info(self):
        for i in range(self.repeat_times):
            assert self.repeat_datasets.get_data_info(
                len(self.dataset) * i) == self.dataset.get_data_info(0)


class TestClassBalancedDataset:

    def setup(self):
        dataset = BaseDataset
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        dataset.get_cat_ids = MagicMock(return_value=[0])
        self.dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

        self.repeat_indices = [0, 0, 1, 1, 1]
        # test init
        self.cls_banlanced_datasets = ClassBalancedDataset(
            dataset=self.dataset, oversample_thr=1e-3)
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        cls_banlanced_datasets = ClassBalancedDataset(
            dataset=dataset_cfg, oversample_thr=1e-3)
        cls_banlanced_datasets.repeat_indices = self.repeat_indices
        cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
        assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
        for i in range(len(cls_banlanced_datasets)):
            assert (cls_banlanced_datasets.get_data_info(i) ==
                    self.cls_banlanced_datasets.get_data_info(i))
            assert (
                cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])

        with pytest.raises(TypeError):
            ClassBalancedDataset(dataset=[0], times=5)

    def test_full_init(self):
        self.cls_banlanced_datasets.full_init()
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
        # Reinit `repeat_indices`.
        self.cls_banlanced_datasets._fully_initialized = False
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
        assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)

        with pytest.raises(NotImplementedError):
            self.cls_banlanced_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.cls_banlanced_datasets.get_subset(1)

    def test_metainfo(self):
        assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)

    def test_getitem(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets[i] == self.dataset[
                self.repeat_indices[i]]

    def test_get_data_info(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets.get_data_info(
                i) == self.dataset.get_data_info(self.repeat_indices[i])

    def test_get_cat_ids(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets.get_cat_ids(
                i) == self.dataset.get_cat_ids(self.repeat_indices[i])