mmpretrain/tests/test_datasets/test_datasets.py

877 lines
32 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import pickle
import sys
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
2020-07-01 16:09:06 +08:00
import numpy as np
2022-07-12 16:10:59 +08:00
from mmengine.logging import MMLogger
from mmengine.registry import TRANSFORMS
2020-07-01 16:09:06 +08:00
from mmcls.registry import DATASETS
2022-07-12 16:10:59 +08:00
from mmcls.utils import register_all_modules
2022-07-12 16:10:59 +08:00
# Called once at import time, before any test class is defined. Presumably
# this populates the mmcls registries so the `DATASETS.get(...)` lookups used
# throughout the tests can resolve dataset names -- TODO confirm.
register_all_modules()
# Absolute path of `../data/dataset` relative to this file's directory; used
# by the tests below as `data_root` / `data_prefix` and to locate assets.
ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))
class TestBaseDataset(TestCase):
    """Shared test cases for the dataset registered as ``BaseDataset``.

    The dataset class is looked up in ``DATASETS`` by ``DATASET_TYPE`` and
    built from ``DEFAULT_ARGS``. Subclasses override these two attributes
    (and individual tests) to reuse the checks for other dataset types.
    """
    DATASET_TYPE = 'BaseDataset'
    DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.json')

    def test_initialize(self):
        """Check how ``CLASSES`` is resolved from the constructor arguments."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test loading metainfo from ann_file
        cfg = {**self.DEFAULT_ARGS, 'metainfo': None, 'classes': None}
        dataset = dataset_class(**cfg)
        self.assertEqual(
            dataset.CLASSES,
            dataset_class.METAINFO.get('classes', ('first', 'second')))
        self.assertFalse(dataset.test_mode)

        # Test overriding metainfo by `metainfo` argument
        cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))

        # Test overriding metainfo by `classes` argument
        cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))

        # `classes` may also be a path to a text file of class names
        classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
        cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))
        self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})

        # Test invalid classes
        cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
        with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
            dataset_class(**cfg)

    def test_get_cat_ids(self):
        """``get_cat_ids`` returns a one-element list holding an int."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        dataset = dataset_class(**self.DEFAULT_ARGS)

        cat_ids = dataset.get_cat_ids(0)
        self.assertIsInstance(cat_ids, list)
        self.assertEqual(len(cat_ids), 1)
        self.assertIsInstance(cat_ids[0], int)

    def test_repr(self):
        """Check the pieces of ``repr(dataset)`` before and after init."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
        dataset = dataset_class(**cfg)

        head = 'Dataset ' + dataset.__class__.__name__
        self.assertIn(head, repr(dataset))

        if dataset.CLASSES is not None:
            num_classes = len(dataset.CLASSES)
            self.assertIn(f'Number of categories: \t{num_classes}',
                          repr(dataset))
        else:
            self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))

        self.assertIn("Haven't been initialized", repr(dataset))
        dataset.full_init()
        self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))

        TRANSFORMS.register_module(name='test_mock', module=MagicMock)
        try:
            cfg = {**self.DEFAULT_ARGS, 'pipeline': [dict(type='test_mock')]}
            dataset = dataset_class(**cfg)
            self.assertIn('With transforms', repr(dataset))
        finally:
            # Always drop the temporary transform, even if the checks above
            # fail: this test is inherited by every subclass, and a leftover
            # registration would leak into their runs.
            TRANSFORMS.module_dict.pop('test_mock', None)

    def test_extra_repr(self):
        """The repr lists the annotation file and the image prefix."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
        dataset = dataset_class(**cfg)

        self.assertIn(f'Annotation file: \t{dataset.ann_file}', repr(dataset))
        self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
                      repr(dataset))
class TestCustomDataset(TestBaseDataset):
    """Test cases for the dataset registered as ``CustomDataset``."""
    DATASET_TYPE = 'CustomDataset'
    DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')

    def test_initialize(self):
        """Check the ways of overriding ``CLASSES`` at construction time."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)
        bus_car = ('bus', 'car')

        # Test overriding metainfo by `metainfo` argument
        ds = dataset_cls(**dict(self.DEFAULT_ARGS,
                                metainfo={'classes': ('bus', 'car')}))
        self.assertEqual(ds.CLASSES, bus_car)

        # Test overriding metainfo by `classes` argument
        ds = dataset_cls(**dict(self.DEFAULT_ARGS, classes=['bus', 'car']))
        self.assertEqual(ds.CLASSES, bus_car)

        # Same, but with `classes` given as a path to a text file
        ds = dataset_cls(**dict(
            self.DEFAULT_ARGS, classes=osp.join(ASSETS_ROOT, 'classes.txt')))
        self.assertEqual(ds.CLASSES, bus_car)
        self.assertEqual(ds.class_to_idx, {'bus': 0, 'car': 1})

        # Test invalid classes
        bad_kwargs = dict(self.DEFAULT_ARGS, classes=dict(classes=1))
        with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
            dataset_cls(**bad_kwargs)

    def test_load_data_list(self):
        """Check loading samples with and without an annotation file."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)

        # test load without ann_file
        ds = dataset_cls(**dict(
            self.DEFAULT_ARGS, data_prefix=ASSETS_ROOT, ann_file=''))
        self.assertEqual(len(ds), 3)
        self.assertEqual(ds.CLASSES, ('a', 'b'))  # auto infer classes
        expected = {
            'img_path': osp.join(ASSETS_ROOT, 'a', '1.JPG'),
            'gt_label': 0
        }
        self.assertGreaterEqual(ds.get_data_info(0).items(), expected.items())
        expected = {
            'img_path': osp.join(ASSETS_ROOT, 'b', 'subb', '3.jpg'),
            'gt_label': 1
        }
        self.assertGreaterEqual(ds.get_data_info(2).items(), expected.items())

        # test ann_file assertion
        kwargs = dict(
            self.DEFAULT_ARGS,
            data_prefix=ASSETS_ROOT,
            ann_file=['ann_file.txt'])
        with self.assertRaisesRegex(TypeError, 'expected str'):
            dataset_cls(**kwargs)

        # test load with ann_file
        ds = dataset_cls(**dict(
            self.DEFAULT_ARGS, data_root=ASSETS_ROOT, ann_file='ann.txt'))
        self.assertEqual(len(ds), 3)
        # custom dataset won't infer CLASSES from ann_file
        self.assertIsNone(ds.CLASSES)
        expected = {
            'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
            'gt_label': 0,
        }
        self.assertGreaterEqual(ds.get_data_info(0).items(), expected.items())
        expected = {
            'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
            'gt_label': 1
        }
        self.assertGreaterEqual(ds.get_data_info(2).items(), expected.items())
        np.testing.assert_equal(ds.get_gt_labels(), np.array([0, 1, 1]))

        # test load with absolute ann_file
        ds = dataset_cls(**dict(
            self.DEFAULT_ARGS,
            data_root='',
            data_prefix='',
            ann_file=osp.join(ASSETS_ROOT, 'ann.txt')))
        self.assertEqual(len(ds), 3)
        # custom dataset won't infer CLASSES from ann_file
        self.assertIsNone(ds.CLASSES)
        expected = {'img_path': 'a/1.JPG', 'gt_label': 0}
        self.assertGreaterEqual(ds.get_data_info(0).items(), expected.items())
        expected = {'img_path': 'b/subb/3.jpg', 'gt_label': 1}
        self.assertGreaterEqual(ds.get_data_info(2).items(), expected.items())

        # test extensions filter
        kwargs = dict(
            self.DEFAULT_ARGS,
            data_prefix=dict(img_path=ASSETS_ROOT),
            ann_file='',
            extensions=('.txt', ))
        with self.assertRaisesRegex(RuntimeError,
                                    'Supported extensions are: .txt'):
            dataset_cls(**kwargs)

        kwargs = dict(
            self.DEFAULT_ARGS,
            data_prefix=ASSETS_ROOT,
            ann_file='',
            extensions=('.jpeg', ))
        current_logger = MMLogger.get_current_instance()
        with self.assertLogs(current_logger, 'WARN') as captured:
            ds = dataset_cls(**kwargs)
        self.assertIn('Supported extensions are: .jpeg', captured.output[0])
        self.assertEqual(len(ds), 1)
        expected = {
            'img_path': osp.join(ASSETS_ROOT, 'b', '2.jpeg'),
            'gt_label': 1
        }
        self.assertGreaterEqual(ds.get_data_info(0).items(), expected.items())

        # test classes check
        kwargs = dict(
            self.DEFAULT_ARGS,
            data_prefix=ASSETS_ROOT,
            classes=('apple', 'banana'),
            ann_file='')
        ds = dataset_cls(**kwargs)
        self.assertEqual(ds.CLASSES, ('apple', 'banana'))

        kwargs['classes'] = ['apple', 'banana', 'dog']
        with self.assertRaisesRegex(AssertionError,
                                    r"\(2\) doesn't match .* classes \(3\)"):
            dataset_cls(**kwargs)
class TestImageNet(TestCustomDataset):
    """Test cases for the dataset registered as ``ImageNet``."""
    DATASET_TYPE = 'ImageNet'
    DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')

    def test_load_data_list(self):
        """Check the class-count validation and overriding ``classes``."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)

        # test classes number
        kwargs = dict(self.DEFAULT_ARGS, data_prefix=ASSETS_ROOT, ann_file='')
        with self.assertRaisesRegex(
                AssertionError, r"\(2\) doesn't match .* classes \(1000\)"):
            dataset_cls(**kwargs)

        # test override classes
        kwargs = dict(
            self.DEFAULT_ARGS,
            data_prefix=ASSETS_ROOT,
            classes=['cat', 'dog'],
            ann_file='')
        ds = dataset_cls(**kwargs)
        self.assertEqual(len(ds), 3)
        self.assertEqual(ds.CLASSES, ('cat', 'dog'))
class TestImageNet21k(TestCustomDataset):
    """Test cases for the dataset registered as ``ImageNet21k``."""
    DATASET_TYPE = 'ImageNet21k'
    DEFAULT_ARGS = dict(
        data_root=ASSETS_ROOT, classes=['cat', 'dog'], ann_file='ann.txt')

    def test_load_data_list(self):
        """Check the unsupported option and the two configuration warnings."""
        super().test_initialize()
        dataset_cls = DATASETS.get(self.DATASET_TYPE)

        # The multi_label option is not implemented now.
        kwargs = dict(self.DEFAULT_ARGS, multi_label=True)
        with self.assertRaisesRegex(NotImplementedError, 'not supported'):
            dataset_cls(**kwargs)

        current_logger = MMLogger.get_current_instance()

        # Warn about ann_file
        kwargs = dict(self.DEFAULT_ARGS, ann_file='')
        with self.assertLogs(current_logger, 'WARN') as captured:
            dataset_cls(**kwargs)
        self.assertIn('specify the `ann_file`', captured.output[0])

        # Warn about classes
        kwargs = dict(self.DEFAULT_ARGS, classes=None)
        with self.assertLogs(current_logger, 'WARN') as captured:
            dataset_cls(**kwargs)
        self.assertIn('specify the `classes`', captured.output[0])
class TestCIFAR10(TestBaseDataset):
    """Test cases for the dataset registered as ``CIFAR10``.

    ``setUpClass`` writes small pickled fake batches into a temporary
    directory and points the dataset class at them, so no download happens.
    """
    DATASET_TYPE = 'CIFAR10'

    @classmethod
    def setUpClass(cls) -> None:
        """Create the fake batch and meta files used by all tests."""
        super().setUpClass()

        tmpdir = tempfile.TemporaryDirectory()
        cls.tmpdir = tmpdir
        data_prefix = tmpdir.name
        cls.DEFAULT_ARGS = dict(
            data_prefix=data_prefix, pipeline=[], test_mode=False)

        dataset_class = DATASETS.get(cls.DATASET_TYPE)
        base_folder = osp.join(data_prefix, dataset_class.base_folder)
        os.mkdir(base_folder)

        # Six fake samples: four for the train batches, two for the test batch
        cls.fake_imgs = np.random.randint(
            0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8)
        cls.fake_labels = np.random.randint(0, 10, size=(6, ))
        cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

        batch1 = dict(
            data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist())
        with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f:
            f.write(pickle.dumps(batch1))

        batch2 = dict(
            data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist())
        with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f:
            f.write(pickle.dumps(batch2))

        # The test batch stores its labels under `fine_labels`
        test_batch = dict(
            data=cls.fake_imgs[4:], fine_labels=cls.fake_labels[4:].tolist())
        with open(osp.join(base_folder, 'test_batch'), 'wb') as f:
            f.write(pickle.dumps(test_batch))

        meta = {dataset_class.meta['key']: cls.fake_classes}
        meta_filename = dataset_class.meta['filename']
        with open(osp.join(base_folder, meta_filename), 'wb') as f:
            f.write(pickle.dumps(meta))

        # Point the dataset class at the fake files and disable md5 checks
        dataset_class.train_list = [['data_batch_1', None],
                                    ['data_batch_2', None]]
        dataset_class.test_list = [['test_batch', None]]
        dataset_class.meta['md5'] = None

    def test_initialize(self):
        """Check the ways of overriding ``CLASSES`` at construction time."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test overriding metainfo by `metainfo` argument
        cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))

        # Test overriding metainfo by `classes` argument
        cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))

        classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
        cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))
        self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})

        # Test invalid classes
        cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
        with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
            dataset_class(**cfg)

    def test_load_data_list(self):
        """Check data loading, meta loading, download and backend errors."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test default behavior
        dataset = dataset_class(**self.DEFAULT_ARGS)
        self.assertEqual(len(dataset), 4)
        self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])

        data_info = dataset[0]
        fake_img = self.fake_imgs[0].reshape(3, 32, 32).transpose(1, 2, 0)
        np.testing.assert_equal(data_info['img'], fake_img)
        np.testing.assert_equal(data_info['gt_label'], self.fake_labels[0])

        # Test with test_mode=True
        cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 2)

        data_info = dataset[0]
        fake_img = self.fake_imgs[4].reshape(3, 32, 32).transpose(1, 2, 0)
        np.testing.assert_equal(data_info['img'], fake_img)
        np.testing.assert_equal(data_info['gt_label'], self.fake_labels[4])

        # Test load meta
        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
        dataset = dataset_class(**cfg)
        dataset._metainfo = {}
        dataset.full_init()
        self.assertEqual(dataset.CLASSES, self.fake_classes)

        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
        dataset = dataset_class(**cfg)
        dataset._metainfo = {}
        # `meta` looks like a dict shared at class level (setUpClass reads it
        # from the class), so scope the override with patch.dict instead of
        # leaving 'invalid' behind for other tests.
        with patch.dict(dataset.meta, {'filename': 'invalid'}):
            with self.assertRaisesRegex(RuntimeError,
                                        'not found or corrupted'):
                dataset.full_init()

        # Test automatically download
        with patch(
                'mmcls.datasets.cifar.download_and_extract_archive') as mock:
            cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
            dataset = dataset_class(**cfg)
            dataset.test_list = [['invalid_batch', None]]
            with self.assertRaisesRegex(AssertionError, 'Download failed'):
                dataset.full_init()
            mock.assert_called_once_with(
                dataset.url,
                dataset.data_prefix['root'],
                filename=dataset.filename,
                md5=dataset.tgz_md5)

        with self.assertRaisesRegex(RuntimeError, '`download=True`'):
            cfg = {
                **self.DEFAULT_ARGS, 'lazy_init': True,
                'test_mode': True,
                'download': False
            }
            dataset = dataset_class(**cfg)
            dataset.test_list = [['test_batch', 'invalid_md5']]
            dataset.full_init()

        # Test different backend
        cfg = {
            **self.DEFAULT_ARGS, 'lazy_init': True,
            'data_prefix': 'http://openmmlab/cifar'
        }
        dataset = dataset_class(**cfg)
        dataset._check_integrity = MagicMock(return_value=False)
        with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/cifar'):
            dataset.full_init()

    def test_extra_repr(self):
        """The repr lists the root data prefix."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
        dataset = dataset_class(**cfg)

        self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
                      repr(dataset))

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmpdir.cleanup()
class TestCIFAR100(TestCIFAR10):
    """Re-run every ``TestCIFAR10`` case against the ``CIFAR100`` dataset."""

    DATASET_TYPE = 'CIFAR100'
class TestMultiLabelDataset(TestBaseDataset):
    """Test cases for the dataset registered as ``MultiLabelDataset``."""
    DATASET_TYPE = 'MultiLabelDataset'
    DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='multi_label_ann.json')

    def test_get_cat_ids(self):
        """``get_cat_ids`` returns the label list of the requested sample."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)

        # NOTE: these were `assertTrue(cat_ids, [...])`, which treats the
        # expected value as the failure message and never compares anything.
        cat_ids = dataset.get_cat_ids(0)
        self.assertEqual(cat_ids, [0])

        cat_ids = dataset.get_cat_ids(1)
        self.assertEqual(cat_ids, [1])

        # The original re-queried index 1 here while expecting a different
        # label. Assumes the third sample in `multi_label_ann.json` is
        # labelled [0, 1] -- TODO confirm against the asset file.
        cat_ids = dataset.get_cat_ids(2)
        self.assertEqual(cat_ids, [0, 1])
class TestVOC(TestBaseDataset):
    """Test cases for the dataset registered as ``VOC``.

    ``setUpClass`` builds a minimal fake dataset layout (image-set lists and
    XML annotations) inside a temporary directory.
    """
    DATASET_TYPE = 'VOC'

    @classmethod
    def setUpClass(cls) -> None:
        """Create the fake image-set files and XML annotation files."""
        super().setUpClass()

        tmpdir = tempfile.TemporaryDirectory()
        cls.tmpdir = tmpdir
        data_root = tmpdir.name

        cls.DEFAULT_ARGS = dict(
            data_root=data_root,
            image_set_path='ImageSets/train.txt',
            data_prefix=dict(img_path='JPEGImages', ann_path='Annotations'),
            pipeline=[],
            test_mode=False)

        cls.image_folder = osp.join(data_root, 'JPEGImages')
        cls.ann_folder = osp.join(data_root, 'Annotations')
        cls.image_set_folder = osp.join(data_root, 'ImageSets')
        os.mkdir(cls.image_set_folder)
        os.mkdir(cls.image_folder)
        os.mkdir(cls.ann_folder)

        # Six samples, each with one to three random labels in [0, 10)
        cls.fake_img_paths = [f'{i}' for i in range(6)]
        cls.fake_labels = [[
            np.random.randint(10) for _ in range(np.random.randint(1, 4))
        ] for _ in range(6)]
        cls.fake_classes = [f'C_{i}' for i in range(10)]
        train_list = list(range(0, 4))
        test_list = list(range(4, 6))

        with open(osp.join(cls.image_set_folder, 'train.txt'), 'w') as f:
            for train_item in train_list:
                f.write(str(train_item) + '\n')

        with open(osp.join(cls.image_set_folder, 'test.txt'), 'w') as f:
            for test_item in test_list:
                f.write(str(test_item) + '\n')

        with open(osp.join(cls.image_set_folder, 'full_path_test.txt'),
                  'w') as f:
            for test_item in test_list:
                f.write(osp.join(cls.image_folder, str(test_item)) + '\n')

        for train_item in train_list:
            with open(osp.join(cls.ann_folder, f'{train_item}.xml'), 'w') as f:
                temple = '<object><name>C_{}</name>{}</object>'
                ann_data = ''.join([
                    temple.format(label, '<difficult>0</difficult>')
                    for label in cls.fake_labels[train_item]
                ])
                # add difficult label
                ann_data += ''.join([
                    temple.format(label, '<difficult>1</difficult>')
                    for label in cls.fake_labels[train_item]
                ])
                xml_ann_data = f'<annotation>{ann_data}</annotation>'
                f.write(xml_ann_data + '\n')

        for test_item in test_list:
            with open(osp.join(cls.ann_folder, f'{test_item}.xml'), 'w') as f:
                temple = '<object><name>C_{}</name>{}</object>'
                ann_data = ''.join([
                    temple.format(label, '<difficult>0</difficult>')
                    for label in cls.fake_labels[test_item]
                ])
                xml_ann_data = f'<annotation>{ann_data}</annotation>'
                f.write(xml_ann_data + '\n')

    def test_initialize(self):
        """Check the ways of overriding ``CLASSES`` at construction time."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test overriding metainfo by `classes` argument
        cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))

        # Test overriding CLASSES by classes file
        classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
        cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.CLASSES, ('bus', 'car'))
        self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})

        # Test invalid classes
        cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
        with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
            dataset_class(**cfg)

    def test_get_cat_ids(self):
        """``get_cat_ids`` returns a list whose first item is an int."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {'classes': self.fake_classes, **self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)

        cat_ids = dataset.get_cat_ids(0)
        self.assertIsInstance(cat_ids, list)
        self.assertIsInstance(cat_ids[0], int)

    def test_load_data_list(self):
        """Check loading in train/test mode and with a non-local backend."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test default behavior
        dataset = dataset_class(**self.DEFAULT_ARGS)
        self.assertEqual(len(dataset), 4)
        self.assertEqual(len(dataset.CLASSES), 20)

        cfg = {
            'classes': self.fake_classes,
            'lazy_init': True,
            **self.DEFAULT_ARGS
        }
        dataset = dataset_class(**cfg)

        self.assertIn("Haven't been initialized", repr(dataset))
        dataset.full_init()
        self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))

        data_info = dataset[0]
        fake_img_path = osp.join(self.image_folder, self.fake_img_paths[0])
        self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
        self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[0]))

        # Test with test_mode=True
        cfg['image_set_path'] = 'ImageSets/test.txt'
        cfg['test_mode'] = True
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 2)

        data_info = dataset[0]
        fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
        self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
        self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[4]))

        # Test with test_mode=True and ann_path = None
        cfg['image_set_path'] = 'ImageSets/test.txt'
        cfg['test_mode'] = True
        cfg['data_prefix'] = 'JPEGImages'
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 2)

        data_info = dataset[0]
        fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
        self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
        self.assertIsNone(data_info['gt_label'])

        # Test different backend
        cfg = {
            **self.DEFAULT_ARGS, 'lazy_init': True,
            'data_root': 's3://openmmlab/voc'
        }
        petrel_mock = MagicMock()
        # Install the fake `petrel_client` only for this check and restore
        # whatever was in sys.modules before, so the mock cannot leak into
        # other tests running in the same process.
        missing = object()
        previous = sys.modules.get('petrel_client', missing)
        sys.modules['petrel_client'] = petrel_mock
        try:
            dataset = dataset_class(**cfg)
            petrel_mock.client.Client.assert_called()
        finally:
            if previous is missing:
                sys.modules.pop('petrel_client', None)
            else:
                sys.modules['petrel_client'] = previous

    def test_extra_repr(self):
        """The repr lists the image-set path and the dataset prefixes."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)

        self.assertIn(f'Path of image set: \t{dataset.image_set_path}',
                      repr(dataset))
        self.assertIn(f'Prefix of dataset: \t{dataset.data_root}',
                      repr(dataset))
        self.assertIn(f'Prefix of annotations: \t{dataset.ann_prefix}',
                      repr(dataset))
        self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
                      repr(dataset))

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmpdir.cleanup()
class TestMNIST(TestBaseDataset):
    """Test cases for the dataset registered as ``MNIST``.

    ``setUpClass`` writes one fake image and one fake label into image and
    label files inside a temporary directory.
    """
    DATASET_TYPE = 'MNIST'

    @classmethod
    def setUpClass(cls) -> None:
        """Write the fake train/test image and label files."""
        super().setUpClass()

        cls.tmpdir = tempfile.TemporaryDirectory()
        root = cls.tmpdir.name
        cls.DEFAULT_ARGS = dict(
            data_prefix=root, pipeline=[], test_mode=False)

        dataset_cls = DATASETS.get(cls.DATASET_TYPE)

        def strip_ext(name):
            # Drop everything from the last '.' onwards
            return name[:name.rfind('.')]

        train_images = osp.join(root, strip_ext(dataset_cls.train_list[0][0]))
        train_labels = osp.join(root, strip_ext(dataset_cls.train_list[1][0]))
        test_images = osp.join(root, strip_ext(dataset_cls.test_list[0][0]))
        test_labels = osp.join(root, strip_ext(dataset_cls.test_list[1][0]))

        cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
        cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)

        for path in (train_images, test_images):
            img_magic = b'\x00\x00\x08\x03'  # num_dims = 3, type = uint8
            img_dims = b'\x00\x00\x00\x01' + b'\x00\x00\x00\x1c' * 2  # (1, 28, 28)
            payload = img_magic + img_dims + cls.fake_img.flatten().tobytes()
            with open(path, 'wb') as out:
                out.write(payload)

        for path in (train_labels, test_labels):
            label_magic = b'\x00\x00\x08\x01'  # num_dims = 1, type = uint8
            label_dims = b'\x00\x00\x00\x01'  # (1, )
            payload = label_magic + label_dims + cls.fake_label.tobytes()
            with open(path, 'wb') as out:
                out.write(payload)

    def test_load_data_list(self):
        """Check data loading, the download path and backend errors."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)

        # Test default behavior
        ds = dataset_cls(**self.DEFAULT_ARGS)
        self.assertEqual(len(ds), 1)
        self.assertEqual(ds.CLASSES, dataset_cls.METAINFO['classes'])

        sample = ds[0]
        np.testing.assert_equal(sample['img'], self.fake_img)
        np.testing.assert_equal(sample['gt_label'], self.fake_label)

        # Test with test_mode=True
        ds = dataset_cls(**dict(self.DEFAULT_ARGS, test_mode=True))
        self.assertEqual(len(ds), 1)

        sample = ds[0]
        np.testing.assert_equal(sample['img'], self.fake_img)
        np.testing.assert_equal(sample['gt_label'], self.fake_label)

        # Test automatically download
        with patch(
                'mmcls.datasets.mnist.download_and_extract_archive') as mock:
            ds = dataset_cls(
                **dict(self.DEFAULT_ARGS, lazy_init=True, test_mode=True))
            ds.train_list = [['invalid_train_file', None]]
            ds.test_list = [['invalid_test_file', None]]
            with self.assertRaisesRegex(AssertionError, 'Download failed'):
                ds.full_init()
            train_name = ds.train_list[0][0]
            test_name = ds.test_list[0][0]
            expected_calls = [
                call(
                    osp.join(ds.url_prefix, train_name),
                    download_root=ds.data_prefix['root'],
                    filename=train_name,
                    md5=None),
                call(
                    osp.join(ds.url_prefix, test_name),
                    download_root=ds.data_prefix['root'],
                    filename=test_name,
                    md5=None)
            ]
            mock.assert_has_calls(expected_calls)

        with self.assertRaisesRegex(RuntimeError, '`download=True`'):
            ds = dataset_cls(**dict(
                self.DEFAULT_ARGS,
                lazy_init=True,
                test_mode=True,
                download=False))
            ds._check_exists = MagicMock(return_value=False)
            ds.full_init()

        # Test different backend
        ds = dataset_cls(**dict(
            self.DEFAULT_ARGS,
            lazy_init=True,
            data_prefix='http://openmmlab/mnist'))
        ds._check_exists = MagicMock(return_value=False)
        with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
            ds.full_init()

    def test_extra_repr(self):
        """The repr lists the root data prefix."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)
        ds = dataset_cls(**dict(self.DEFAULT_ARGS, lazy_init=True))

        text = repr(ds)
        self.assertIn(f'Prefix of data: \t{ds.data_prefix["root"]}', text)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmpdir.cleanup()
class FashionMNIST(TestMNIST):
    """Re-run every ``TestMNIST`` case against the ``FashionMNIST`` dataset."""

    DATASET_TYPE = 'FashionMNIST'
class TestCUB(TestBaseDataset):
    """Test cases for the dataset registered as ``CUB``.

    ``setUpClass`` writes small annotation, label and split text files into
    a temporary directory used as ``data_root``.
    """
    DATASET_TYPE = 'CUB'

    @classmethod
    def setUpClass(cls) -> None:
        """Write the fake annotation, class-label and split files."""
        super().setUpClass()

        cls.tmpdir = tempfile.TemporaryDirectory()
        cls.root = cls.tmpdir.name
        cls.ann_file = 'ann_file.txt'
        cls.image_folder = 'images'
        cls.image_class_labels_file = 'classes.txt'
        cls.train_test_split_file = 'split.txt'
        cls.train_test_split_file2 = 'split2.txt'

        cls.DEFAULT_ARGS = dict(
            data_root=cls.root,
            test_mode=False,
            data_prefix=cls.image_folder,
            pipeline=[],
            ann_file=cls.ann_file,
            image_class_labels_file=cls.image_class_labels_file,
            train_test_split_file=cls.train_test_split_file)

        # File name -> rows; rows are joined with '\n' (no trailing newline)
        fixtures = [
            (cls.ann_file, ['1 1.txt', '2 2.txt', '3 3.txt']),
            (cls.image_class_labels_file, ['1 2', '2 3', '3 1']),
            (cls.train_test_split_file, ['1 0', '2 1', '3 1']),
            # Deliberately one row short, for the mismatch check
            (cls.train_test_split_file2, ['1 0', '2 1']),
        ]
        for file_name, rows in fixtures:
            with open(osp.join(cls.root, file_name), 'w') as out:
                out.write('\n'.join(rows))

    def test_load_data_list(self):
        """Check the train/test split and the line-count validation."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)

        # Test default behavior
        ds = dataset_cls(**self.DEFAULT_ARGS)
        self.assertEqual(len(ds), 2)

        sample = ds[0]
        self.assertEqual(sample['img_path'],
                         osp.join(self.root, self.image_folder, '2.txt'))
        self.assertEqual(sample['gt_label'], 3 - 1)

        # Test with test_mode=True
        ds = dataset_cls(**dict(self.DEFAULT_ARGS, test_mode=True))
        self.assertEqual(len(ds), 1)

        sample = ds[0]
        self.assertEqual(sample['img_path'],
                         osp.join(self.root, self.image_folder, '1.txt'))
        self.assertEqual(sample['gt_label'], 2 - 1)

        # Test if the numbers of line are not match
        kwargs = dict(
            self.DEFAULT_ARGS,
            train_test_split_file=self.train_test_split_file2)
        with self.assertRaisesRegex(AssertionError,
                                    'sample_ids should be same'):
            dataset_cls(**kwargs)

    def test_extra_repr(self):
        """The repr lists the dataset root."""
        dataset_cls = DATASETS.get(self.DATASET_TYPE)
        ds = dataset_cls(**dict(self.DEFAULT_ARGS))

        self.assertIn(f'Root of dataset: \t{ds.data_root}', repr(ds))

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmpdir.cleanup()