# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import pickle
import sys
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
import mat4py
import numpy as np
from mmengine.logging import MMLogger
from mmpretrain.registry import DATASETS, TRANSFORMS
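# Root of the small on-disk fixtures (annotation files, classes.txt and a few
# sample images) shared by the dataset tests below.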
ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))
class TestBaseDataset(TestCase):
DATASET_TYPE = 'BaseDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.json')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test loading metainfo from ann_file
cfg = {**self.DEFAULT_ARGS, 'metainfo': None, 'classes': None}
dataset = dataset_class(**cfg)
self.assertEqual(
dataset.CLASSES,
dataset_class.METAINFO.get('classes', ('first', 'second')))
self.assertFalse(dataset.test_mode)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
cat_ids = dataset.get_cat_ids(0)
self.assertIsInstance(cat_ids, list)
self.assertEqual(len(cat_ids), 1)
self.assertIsInstance(cat_ids[0], int)
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
head = 'Dataset ' + dataset.__class__.__name__
self.assertIn(head, repr(dataset))
if dataset.CLASSES is not None:
num_classes = len(dataset.CLASSES)
self.assertIn(f'Number of categories: \t{num_classes}',
repr(dataset))
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
TRANSFORMS.register_module(name='test_mock', module=MagicMock)
cfg = {**self.DEFAULT_ARGS, 'pipeline': [dict(type='test_mock')]}
dataset = dataset_class(**cfg)
self.assertIn('With transforms', repr(dataset))
del TRANSFORMS.module_dict['test_mock']
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Annotation file: \t{dataset.ann_file}', repr(dataset))
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
repr(dataset))
class TestCustomDataset(TestBaseDataset):
DATASET_TYPE = 'CustomDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test load without ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a', '1.JPG'),
'gt_label': 0
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b', 'subb', '3.jpg'),
'gt_label': 1
}.items())
# test load without ann_file and without labels
# (no specific folder structures)
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': '',
'with_label': False,
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, '3.jpeg'),
}.items())
self.assertGreaterEqual(
dataset.get_data_info(1).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a', '1.JPG'),
}.items())
self.assertGreaterEqual(
dataset.get_data_info(3).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b', 'subb', '3.jpg'),
}.items())
# test ann_file assertion
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': ['ann_file.txt'],
}
with self.assertRaisesRegex(TypeError, 'expected str'):
dataset_class(**cfg)
# test load with ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_root': ASSETS_ROOT,
'ann_file': 'ann.txt',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
'gt_label': 1
}.items())
np.testing.assert_equal(dataset.get_gt_labels(), np.array([0, 1, 1]))
# test load with absolute ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_root': '',
'data_prefix': '',
'ann_file': osp.join(ASSETS_ROOT, 'ann.txt'),
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': 'a/1.JPG',
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': 'b/subb/3.jpg',
'gt_label': 1
}.items())
# test load with absolute ann_file and without label
cfg = {
**self.DEFAULT_ARGS,
'data_root': '',
'data_prefix': '',
'ann_file': osp.join(ASSETS_ROOT, 'ann_without_labels.txt'),
'with_label': False,
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': 'a/1.JPG',
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': 'b/subb/3.jpg',
}.items())
# test extensions filter
cfg = {
**self.DEFAULT_ARGS, 'data_prefix': dict(img_path=ASSETS_ROOT),
'ann_file': '',
'extensions': ('.txt', )
}
with self.assertRaisesRegex(RuntimeError,
'Supported extensions are: .txt'):
dataset_class(**cfg)
cfg = {
**self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT,
'ann_file': '',
'extensions': ('.jpeg', )
}
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertIn('Supported extensions are: .jpeg', log.output[0])
self.assertEqual(len(dataset), 1)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b', '2.jpeg'),
'gt_label': 1
}.items())
# test classes check
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'classes': ('apple', 'banana'),
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('apple', 'banana'))
cfg['classes'] = ['apple', 'banana', 'dog']
with self.assertRaisesRegex(AssertionError,
r"\(2\) doesn't match .* classes \(3\)"):
dataset_class(**cfg)
class TestImageNet(TestCustomDataset):
DATASET_TYPE = 'ImageNet'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.meta_folder = 'meta'
cls.train_file = 'train.txt'
cls.val_file = 'val.txt'
cls.test_file = 'test.txt'
cls.categories = ['cat', 'dog']
os.mkdir(osp.join(cls.root, cls.meta_folder))
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
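# Minimal ImageNet-style metadata: the train/val lists hold
# "<filename> <label>" pairs, while the test list holds filenames only.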
with open(osp.join(cls.root, cls.meta_folder, cls.train_file),
'w') as f:
f.write('\n'.join([
'1.jpg 0',
'2.jpg 1',
'3.jpg 1',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.val_file), 'w') as f:
f.write('\n'.join([
'11.jpg 0',
'22.jpg 1',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.test_file),
'w') as f:
f.write('\n'.join([
'aa.jpg',
'bb.jpg',
]))
def test_initialize(self):
super().test_initialize()
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'val']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
cfg['classes'] = self.categories
dataset = dataset_class(**cfg)
self.assertEqual(dataset.data_root, self.root)
# Test split="test"
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'test'
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'INFO') as log:
dataset = dataset_class(**cfg)
self.assertFalse(dataset.with_label)
self.assertIn('Since the ImageNet1k test set', log.output[0])
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, 'train', '1.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test split="val"
cfg = {**self.DEFAULT_ARGS, 'split': 'val'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, 'val', '11.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, 'test', 'aa.jpg'))
# test override classes
cfg = {
**self.DEFAULT_ARGS,
'classes': ['cat', 'dog'],
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ('cat', 'dog'))
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
class TestImageNet21k(TestCustomDataset):
DATASET_TYPE = 'ImageNet21k'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.meta_folder = 'meta'
cls.train_file = 'train.txt'
os.mkdir(osp.join(cls.root, cls.meta_folder))
with open(osp.join(cls.root, cls.meta_folder, cls.train_file),
'w') as f:
f.write('\n'.join([
'cat/a.jpg 0',
'cat/b.jpg 0',
'dog/a.jpg 1',
'dog/b.jpg 1',
]))
cls.DEFAULT_ARGS = dict(
data_root=cls.root,
classes=['cat', 'dog'],
ann_file='meta/train.txt')
def test_initialize(self):
super().test_initialize()
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'train'
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'train')
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# The multi_label option is not implemented yet.
cfg = {**self.DEFAULT_ARGS, 'multi_label': True}
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
dataset_class(**cfg)
# Warn about ann_file
cfg = {**self.DEFAULT_ARGS, 'ann_file': '', 'lazy_init': True}
ann_path = osp.join(self.root, self.meta_folder, self.train_file)
os.rename(ann_path, ann_path + 'copy')
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'INFO') as log:
dataset_class(**cfg)
self.assertIn('specify the `ann_file`', log.output[0])
os.rename(ann_path + 'copy', ann_path)
# Warn about classes
cfg = {**self.DEFAULT_ARGS, 'classes': None}
with self.assertLogs(logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `classes`', log.output[0])
# Test split='train'
cfg = {**self.DEFAULT_ARGS, 'split': 'train', 'classes': None}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
class TestPlaces205(TestCustomDataset):
DATASET_TYPE = 'Places205'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test classes number
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': '',
}
with self.assertRaisesRegex(AssertionError,
r"\(2\) doesn't match .* classes \(205\)"):
dataset_class(**cfg)
# test override classes
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'classes': ['cat', 'dog'],
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ('cat', 'dog'))
class TestCIFAR10(TestBaseDataset):
DATASET_TYPE = 'CIFAR10'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
dataset_class = DATASETS.get(cls.DATASET_TYPE)
base_folder = osp.join(cls.root, dataset_class.base_folder)
os.mkdir(base_folder)
cls.fake_imgs = np.random.randint(
0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8)
cls.fake_labels = np.random.randint(0, 10, size=(6, ))
cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
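# The fake batches mimic the CIFAR pickle layout: a 'data' array of flattened
# uint8 images plus a label list ('labels' in the train batches and
# 'fine_labels' in the test batch, so both keys get exercised). The class
# train/test lists and md5 are patched below so no download is attempted.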
batch1 = dict(
data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist())
with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f:
f.write(pickle.dumps(batch1))
batch2 = dict(
data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist())
with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f:
f.write(pickle.dumps(batch2))
test_batch = dict(
data=cls.fake_imgs[4:], fine_labels=cls.fake_labels[4:].tolist())
with open(osp.join(base_folder, 'test_batch'), 'wb') as f:
f.write(pickle.dumps(test_batch))
meta = {dataset_class.meta['key']: cls.fake_classes}
meta_filename = dataset_class.meta['filename']
with open(osp.join(base_folder, meta_filename), 'wb') as f:
f.write(pickle.dumps(meta))
dataset_class.train_list = [['data_batch_1', None],
['data_batch_2', None]]
dataset_class.test_list = [['test_batch', None]]
dataset_class.meta['md5'] = None
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test with valid split
splits = ['train', 'test']
test_modes = [False, True]
for split in splits:
for test_mode in test_modes:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
cfg['test_mode'] = test_mode
if split == 'train' and test_mode:
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
self.assertIn('training set will be used', log.output[0])
else:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
# Test without dataset path
with self.assertRaisesRegex(RuntimeError, 'specify the dataset path'):
dataset = dataset_class()
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 4)
self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0]
fake_img = self.fake_imgs[0].reshape(3, 32, 32).transpose(1, 2, 0)
np.testing.assert_equal(data_info['img'], fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_labels[0])
# Test with split='test'
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img = self.fake_imgs[4].reshape(3, 32, 32).transpose(1, 2, 0)
np.testing.assert_equal(data_info['img'], fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_labels[4])
# Test load meta
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
dataset._metainfo = {}
dataset.full_init()
self.assertEqual(dataset.CLASSES, self.fake_classes)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
dataset._metainfo = {}
dataset.meta['filename'] = 'invalid'
with self.assertRaisesRegex(RuntimeError, 'not found or corrupted'):
dataset.full_init()
# Test automatically download
with patch('mmpretrain.datasets.cifar.download_and_extract_archive'
) as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'split': 'test'}
dataset = dataset_class(**cfg)
dataset.test_list = [['invalid_batch', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
mock.assert_called_once_with(
dataset.url,
dataset.data_prefix['root'],
filename=dataset.filename,
md5=dataset.tgz_md5)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'split': 'test',
'download': False
}
dataset = dataset_class(**cfg)
dataset.test_list = [['test_batch', 'invalid_md5']]
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/cifar'
}
dataset = dataset_class(**cfg)
dataset._check_integrity = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/cifar'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestCIFAR100(TestCIFAR10):
DATASET_TYPE = 'CIFAR100'
class TestMultiLabelDataset(TestBaseDataset):
DATASET_TYPE = 'MultiLabelDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='multi_label_ann.json')
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
cat_ids = dataset.get_cat_ids(0)
self.assertEqual(cat_ids, [0])
cat_ids = dataset.get_cat_ids(1)
self.assertEqual(cat_ids, [1])
cat_ids = dataset.get_cat_ids(2)
self.assertEqual(cat_ids, [0, 1])
class TestVOC(TestBaseDataset):
DATASET_TYPE = 'VOC'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
data_root = tmpdir.name
cls.DEFAULT_ARGS = dict(data_root=data_root, split='trainval')
cls.image_folder = osp.join(data_root, 'JPEGImages')
cls.ann_folder = osp.join(data_root, 'Annotations')
cls.image_set_folder = osp.join(data_root, 'ImageSets', 'Main')
os.makedirs(cls.image_set_folder)
os.mkdir(cls.image_folder)
os.mkdir(cls.ann_folder)
cls.fake_img_paths = [f'{i}' for i in range(6)]
cls.fake_labels = [[
np.random.randint(10) for _ in range(np.random.randint(1, 4))
] for _ in range(6)]
cls.fake_classes = [f'C_{i}' for i in range(10)]
train_list = [i for i in range(0, 4)]
test_list = [i for i in range(4, 6)]
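# The split files list bare image ids; each Annotations/<id>.xml written
# below holds <object> entries whose <name> carries the class and whose
# <difficult> flag is 0 or 1.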
with open(osp.join(cls.image_set_folder, 'trainval.txt'), 'w') as f:
for train_item in train_list:
f.write(str(train_item) + '\n')
with open(osp.join(cls.image_set_folder, 'test.txt'), 'w') as f:
for test_item in test_list:
f.write(str(test_item) + '\n')
with open(osp.join(cls.image_set_folder, 'full_path_test.txt'),
'w') as f:
for test_item in test_list:
f.write(osp.join(cls.image_folder, str(test_item)) + '\n')
for train_item in train_list:
with open(osp.join(cls.ann_folder, f'{train_item}.xml'), 'w') as f:
temple = '<object><name>C_{}</name>{}</object>'
ann_data = ''.join([
temple.format(label, '<difficult>0</difficult>')
for label in cls.fake_labels[train_item]
])
# append the same objects again, marked as difficult
ann_data += ''.join([
temple.format(label, '<difficult>1</difficult>')
for label in cls.fake_labels[train_item]
])
xml_ann_data = f'<annotation>{ann_data}</annotation>'
f.write(xml_ann_data + '\n')
for test_item in test_list:
with open(osp.join(cls.ann_folder, f'{test_item}.xml'), 'w') as f:
temple = '<object><name>C_{}</name>{}</object>'
ann_data = ''.join([
temple.format(label, '<difficult>0</difficult>')
for label in cls.fake_labels[test_item]
])
xml_ann_data = f'<annotation>{ann_data}</annotation>'
f.write(xml_ann_data + '\n')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding CLASSES by classes file
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['trainval', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
# Test split='trainval' and test_mode = True
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'trainval'
cfg['test_mode'] = True
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'trainval')
self.assertEqual(dataset.test_mode, True)
self.assertIn('The trainval set will be used', log.output[0])
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {'classes': self.fake_classes, **self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
cat_ids = dataset.get_cat_ids(0)
self.assertIsInstance(cat_ids, list)
self.assertIsInstance(cat_ids[0], int)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 4)
self.assertEqual(len(dataset.CLASSES), 20)
cfg = {
'classes': self.fake_classes,
'lazy_init': True,
**self.DEFAULT_ARGS
}
dataset = dataset_class(**cfg)
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[0])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[0]))
# Test with split='test'
cfg['split'] = 'test'
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[4]))
# Test with test_mode=True and ann_path = None
cfg['split'] = ''
cfg['image_set_path'] = 'ImageSets/Main/test.txt'
cfg['test_mode'] = True
cfg['data_prefix'] = 'JPEGImages'
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(data_info['gt_label'], None)
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_root': 's3://openmmlab/voc'
}
petrel_mock = MagicMock()
sys.modules['petrel_client'] = petrel_mock
dataset = dataset_class(**cfg)
petrel_mock.client.Client.assert_called()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Path of image set: \t{dataset.image_set_path}',
repr(dataset))
self.assertIn(f'Prefix of dataset: \t{dataset.data_root}',
repr(dataset))
self.assertIn(f'Prefix of annotations: \t{dataset.ann_prefix}',
repr(dataset))
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestMNIST(TestBaseDataset):
DATASET_TYPE = 'MNIST'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
data_prefix = tmpdir.name
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
dataset_class = DATASETS.get(cls.DATASET_TYPE)
def rm_suffix(s):
return s[:s.rfind('.')]
train_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[0][0]))
train_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[1][0]))
test_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[0][0]))
test_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[1][0]))
cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)
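# Write minimal IDX files: a 4-byte magic (0x00, 0x00, dtype code, ndims)
# followed by one big-endian uint32 per dimension, then the raw uint8 data.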
for file in [train_image_file, test_image_file]:
magic = b'\x00\x00\x08\x03' # num_dims = 3, type = uint8
head = b'\x00\x00\x00\x01' + b'\x00\x00\x00\x1c' * 2 # (1, 28, 28)
data = magic + head + cls.fake_img.flatten().tobytes()
with open(file, 'wb') as f:
f.write(data)
for file in [train_label_file, test_label_file]:
magic = b'\x00\x00\x08\x01' # num_dims = 1, type = uint8
head = b'\x00\x00\x00\x01' # (1, )
data = magic + head + cls.fake_label.tobytes()
with open(file, 'wb') as f:
f.write(data)
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test with valid split
splits = ['train', 'test']
test_modes = [False, True]
for split in splits:
for test_mode in test_modes:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
cfg['test_mode'] = test_mode
if split == 'train' and test_mode:
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
self.assertIn('training set will be used', log.output[0])
else:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 1)
self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test with split='test'
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test automatically download
with patch('mmpretrain.datasets.mnist.download_and_extract_archive'
) as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'split': 'test'}
dataset = dataset_class(**cfg)
dataset.train_list = [['invalid_train_file', None]]
dataset.test_list = [['invalid_test_file', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
calls = [
call(
osp.join(dataset.url_prefix, dataset.train_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.train_list[0][0],
md5=None),
call(
osp.join(dataset.url_prefix, dataset.test_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.test_list[0][0],
md5=None)
]
mock.assert_has_calls(calls)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'split': 'test',
'download': False
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/mnist'
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestFashionMNIST(TestMNIST):
DATASET_TYPE = 'FashionMNIST'
class TestCUB(TestBaseDataset):
DATASET_TYPE = 'CUB'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.ann_file = 'images.txt'
cls.image_folder = 'images'
cls.image_class_labels_file = 'image_class_labels.txt'
cls.train_test_split_file = 'train_test_split.txt'
cls.DEFAULT_ARGS = dict(
data_root=cls.root, split='train', test_mode=False)
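# CUB metadata: images.txt maps "<image_id> <path>", image_class_labels.txt
# maps "<image_id> <1-based class>", and train_test_split.txt flags each id
# with 1 (training set) or 0 (test set).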
with open(osp.join(cls.root, cls.ann_file), 'w') as f:
f.write('\n'.join([
'1 1.txt',
'2 2.txt',
'3 3.txt',
]))
with open(osp.join(cls.root, cls.image_class_labels_file), 'w') as f:
f.write('\n'.join([
'1 2',
'2 3',
'3 1',
]))
with open(osp.join(cls.root, cls.train_test_split_file), 'w') as f:
f.write('\n'.join([
'1 0',
'2 1',
'3 1',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test with valid split
splits = ['train', 'test']
test_modes = [False, True]
for split in splits:
for test_mode in test_modes:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
cfg['test_mode'] = test_mode
if split == 'train' and test_mode:
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file,
osp.join(self.root, self.ann_file))
self.assertIn('training set will be used', log.output[0])
else:
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.test_mode, test_mode)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file,
osp.join(self.root, self.ann_file))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '2.txt'))
self.assertEqual(data_info['gt_label'], 3 - 1)
# Test with split='test'
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1.txt'))
self.assertEqual(data_info['gt_label'], 2 - 1)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestMultiTaskDataset(TestCase):
DATASET_TYPE = 'MultiTaskDataset'
DEFAULT_ARGS = dict(
data_root=ASSETS_ROOT,
ann_file=osp.join(ASSETS_ROOT, 'multi-task.json'),
pipeline=[])
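# multi-task.json stores each sample's gt_label as a dict keyed by task name
# ('gender' and 'wear' here), which the parse/load tests below rely on.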
def test_metainfo(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
metainfo = {'tasks': ['gender', 'wear']}
self.assertDictEqual(dataset.metainfo, metainfo)
self.assertFalse(dataset.test_mode)
def test_parse_data_info(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.parse_data_info({
'img_path': 'a.jpg',
'gt_label': {
'gender': 0
}
})
self.assertDictContainsSubset(
{
'img_path': os.path.join(ASSETS_ROOT, 'a.jpg'),
'gt_label': {
'gender': 0
}
}, data)
np.testing.assert_equal(data['gt_label']['gender'], 0)
# Test missing path
with self.assertRaisesRegex(AssertionError, 'have `img_path` field'):
dataset.parse_data_info(
{'gt_label': {
'gender': 0,
'wear': [1, 0, 1, 0]
}})
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
task_doc = ('For 2 tasks\n gender \n wear ')
self.assertIn(task_doc, repr(dataset))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
self.assertIsInstance(data, list)
np.testing.assert_equal(len(data), 3)
np.testing.assert_equal(data[0]['gt_label'], {'gender': 0})
np.testing.assert_equal(data[1]['gt_label'], {
'gender': 0,
'wear': [1, 0, 1, 0]
})
class TestInShop(TestBaseDataset):
DATASET_TYPE = 'InShop'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.list_eval_partition = 'Eval/list_eval_partition.txt'
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
cls.ann_file = osp.join(cls.root, cls.list_eval_partition)
os.makedirs(osp.join(cls.root, 'Eval'))
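# list_eval_partition.txt: a count line, a header line, then
# "<image_name> <item_id> <evaluation_status>" rows; the status column
# selects the train/query/gallery split.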
with open(cls.ann_file, 'w') as f:
f.write('\n'.join([
'8',
'image_name item_id evaluation_status',
f'{osp.join("img", "02_1_front.jpg")} id_00000002 train',
f'{osp.join("img", "02_2_side.jpg")} id_00000002 train',
f'{osp.join("img", "12_3_back.jpg")} id_00007982 gallery',
f'{osp.join("img", "12_7_addition.jpg")} id_00007982 gallery',
f'{osp.join("img", "13_1_front.jpg")} id_00007982 query',
f'{osp.join("img", "13_2_side.jpg")} id_00007983 gallery',
f'{osp.join("img", "13_3_back.jpg")} id_00007983 query ',
f'{osp.join("img", "13_7_additional.jpg")} id_00007983 query',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with split='train'
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'train')
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
# Test with split='query'
cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'query')
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
# Test with split='gallery'
cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'gallery')
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
# Test with an invalid split
cfg = {**self.DEFAULT_ARGS, 'split': 'other'}
with self.assertRaisesRegex(AssertionError, "'split' of `InS"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with split='train'
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, 'Img', 'img', '02_1_front.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split='query'
cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, 'Img', 'img', '13_1_front.jpg'))
self.assertEqual(data_info['gt_label'], [0, 1])
# Test with split='gallery'
cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, 'Img', 'img', '12_3_back.jpg'))
self.assertEqual(data_info['sample_idx'], 0)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestFlowers102(TestBaseDataset):
DATASET_TYPE = 'Flowers102'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
cls.ann_file = osp.join(cls.root, 'imagelabels.mat')
cls.train_test_split_file = osp.join(cls.root, 'setid.mat')
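# imagelabels.mat stores a 1-based class label per image; setid.mat lists the
# 1-based image ids of the train (trnid), val (valid) and test (tstid) splits.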
mat4py.savemat(cls.ann_file,
{'labels': [1, 1, 2, 2, 2, 3, 3, 4, 4, 5]})
mat4py.savemat(cls.train_test_split_file, {
'trnid': [1, 3, 5],
'valid': [7, 9],
'tstid': [2, 4, 6, 8, 10],
})
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'val', 'trainval', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with split="train"
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'jpg', 'image_00001.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="val"
cfg = {**self.DEFAULT_ARGS, 'split': 'val'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'jpg', 'image_00007.jpg'))
self.assertEqual(data_info['gt_label'], 2)
# Test with split="trainval"
cfg = {**self.DEFAULT_ARGS, 'split': 'trainval'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 5)
data_info = dataset[2]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'jpg', 'image_00005.jpg'))
self.assertEqual(data_info['gt_label'], 1)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 5)
data_info = dataset[2]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'jpg', 'image_00006.jpg'))
self.assertEqual(data_info['gt_label'], 2)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestOxfordIIITPet(TestBaseDataset):
DATASET_TYPE = 'OxfordIIITPet'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.trainval_file = 'trainval.txt'
cls.image_folder = 'images'
cls.meta_folder = 'annotations'
cls.test_file = 'test.txt'
os.mkdir(osp.join(cls.root, cls.meta_folder))
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='trainval')
with open(osp.join(cls.root, cls.meta_folder, cls.trainval_file),
'w') as f:
f.write('\n'.join([
'Abyssinian_100 1 1 1',
'american_bulldog_100 2 2 1',
'basset_hound_126 4 2 3',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.test_file),
'w') as f:
f.write('\n'.join([
'Abyssinian_204 1 1 1',
'american_bulldog_208 2 2 1',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['trainval', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.image_folder, 'Abyssinian_100.jpg'))
self.assertEqual(data_info['gt_label'], 1 - 1)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.image_folder, 'Abyssinian_204.jpg'))
self.assertEqual(data_info['gt_label'], 1 - 1)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestDTD(TestBaseDataset):
DATASET_TYPE = 'DTD'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
cls.meta_folder = 'imdb'
os.makedirs(osp.join(cls.root, cls.meta_folder))
cls.ann_file = osp.join(cls.root, cls.meta_folder, 'imdb.mat')
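# imdb.mat: 'name' holds the image file names, 'class' the 1-based labels,
# and 'set' the split of each image (1 = train, 2 = val, 3 = test).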
mat4py.savemat(
cls.ann_file, {
'images': {
'name': [
'1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg', '6.jpg',
'7.jpg', '8.jpg', '9.jpg', '10.jpg'
],
'class': [1, 1, 2, 2, 2, 3, 3, 4, 4, 5],
'set': [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]
}
})
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'val', 'trainval', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test with split="train"
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'images', '1.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="val"
cfg = {**self.DEFAULT_ARGS, 'split': 'val'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'images', '2.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="trainval"
cfg = {**self.DEFAULT_ARGS, 'split': 'trainval'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 7)
data_info = dataset[2]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'images', '4.jpg'))
self.assertEqual(data_info['gt_label'], 1)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'images', '3.jpg'))
self.assertEqual(data_info['gt_label'], 1)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestFGVCAircraft(TestBaseDataset):
DATASET_TYPE = 'FGVCAircraft'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
os.makedirs(osp.join(cls.root, 'data'))
cls.train_file = osp.join('data', 'images_variant_train.txt')
cls.val_file = osp.join('data', 'images_variant_val.txt')
cls.trainval_file = osp.join('data', 'images_variant_trainval.txt')
cls.test_file = osp.join('data', 'images_variant_test.txt')
cls.image_folder = osp.join('data', 'images')
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='trainval')
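# Each images_variant_*.txt row is "<image_id> <variant name>"; the variant
# maps to a label and the image resolves to data/images/<image_id>.jpg.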
with open(osp.join(cls.root, cls.train_file), 'w') as f:
f.write('\n'.join([
'1025794 707-320',
'1019011 727-200',
]))
with open(osp.join(cls.root, cls.val_file), 'w') as f:
f.write('\n'.join([
'0209554 737-200',
]))
with open(osp.join(cls.root, cls.trainval_file), 'w') as f:
f.write('\n'.join([
'1025794 707-320',
'1019011 727-200',
'0209554 737-200',
]))
with open(osp.join(cls.root, cls.test_file), 'w') as f:
f.write('\n'.join([
'1514522 707-320',
'0116175 727-200',
'0713752 737-200',
'2126017 737-300',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'val', 'trainval', 'test']
ann_files = [
self.train_file, self.val_file, self.trainval_file, self.test_file
]
for i, split in enumerate(splits):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file,
osp.join(self.root, ann_files[i]))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior (split="trainval")
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1025794.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="train"
cfg = {**self.DEFAULT_ARGS, 'split': 'train'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1025794.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="val"
cfg = {**self.DEFAULT_ARGS, 'split': 'val'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '0209554.jpg'))
self.assertEqual(data_info['gt_label'], 2)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1514522.jpg'))
self.assertEqual(data_info['gt_label'], 0)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestStanfordCars(TestBaseDataset):
DATASET_TYPE = 'StanfordCars'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.ann_file = osp.join(cls.root, 'cars_annos.mat')
cls.meta_folder = 'devkit'
cls.train_ann_file = osp.join(cls.root, cls.meta_folder,
'cars_train_annos.mat')
cls.test_ann_file = osp.join(cls.root, cls.meta_folder,
'cars_test_annos_withlabels.mat')
cls.train_folder = 'cars_train'
cls.test_folder = 'cars_test'
os.makedirs(osp.join(cls.root, cls.meta_folder))
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
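# Two layouts are faked here: a combined cars_annos.mat at the dataset root
# (with a per-image 'test' flag), and the devkit layout whose per-split .mat
# files pair with the cars_train/ and cars_test/ image folders.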
mat4py.savemat(
cls.ann_file, {
'annotations': {
'relative_im_path':
['car_ims/001.jpg', 'car_ims/002.jpg', 'car_ims/003.jpg'],
'class': [1, 2, 3],
'test': [0, 0, 1]
}
})
mat4py.savemat(
cls.train_ann_file, {
'annotations': {
'fname': ['001.jpg', '002.jpg', '012.jpg'],
'class': [10, 15, 150],
}
})
mat4py.savemat(
cls.test_ann_file, {
'annotations': {
'fname': ['025.jpg', '111.jpg', '222.jpg'],
'class': [150, 1, 15],
}
})
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test the first layout (combined cars_annos.mat at the dataset root)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.ann_file)
# Test the second layout (devkit per-split annotation files)
os.rename(self.ann_file, self.ann_file + 'copy')
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'train'
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'train')
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.train_ann_file)
# Test valid splits
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'test'
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, 'test')
self.assertEqual(dataset.data_root, self.root)
self.assertEqual(dataset.ann_file, self.test_ann_file)
# wrong dataset organization
os.rename(self.train_ann_file, self.train_ann_file + 'copy')
os.rename(self.test_ann_file, self.test_ann_file + 'copy')
with self.assertRaisesRegex(RuntimeError,
'The dataset is incorrectly organized'):
cfg = {**self.DEFAULT_ARGS}
dataset_class(**cfg)
with self.assertRaisesRegex(RuntimeError,
'The dataset is incorrectly organized'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'test'
dataset_class(**cfg)
os.rename(self.train_ann_file + 'copy', self.train_ann_file)
os.rename(self.test_ann_file + 'copy', self.test_ann_file)
os.rename(self.ann_file + 'copy', self.ann_file)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test the first layout (combined cars_annos.mat)
# Test default behavior
assert osp.exists(osp.join(self.root, 'cars_annos.mat')), osp.join(
self.root, 'cars_annos.mat')
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, 'car_ims/001.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, 'car_ims/003.jpg'))
self.assertEqual(data_info['gt_label'], 2)
# Test the second layout (devkit per-split annotation files)
os.rename(self.ann_file, self.ann_file + 'copy')
# Test with split="train"
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.train_folder, '001.jpg'))
self.assertEqual(data_info['gt_label'], 9)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.test_folder, '025.jpg'))
self.assertEqual(data_info['gt_label'], 149)
os.rename(self.ann_file + 'copy', self.ann_file)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestCaltech101(TestBaseDataset):
DATASET_TYPE = 'Caltech101'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.image_folder = '101_ObjectCategories'
cls.meta_folder = 'meta'
cls.train_file = 'train.txt'
cls.test_file = 'test.txt'
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
os.makedirs(osp.join(cls.root, cls.meta_folder))
with open(osp.join(cls.root, cls.meta_folder, cls.train_file),
'w') as f:
f.write('\n'.join([
'1.jpg 0',
'2.jpg 1',
'3.jpg 2',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.test_file),
'w') as f:
f.write('\n'.join([
'100.jpg 99',
'101.jpg 100',
'102.jpg 101',
'103.jpg 101',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '100.jpg'))
self.assertEqual(data_info['gt_label'], 99)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestFood101(TestBaseDataset):
DATASET_TYPE = 'Food101'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.image_folder = 'images'
cls.meta_folder = 'meta'
cls.train_file = 'train.txt'
cls.test_file = 'test.txt'
os.makedirs(osp.join(cls.root, cls.meta_folder))
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
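# meta/train.txt and meta/test.txt list "<class_name>/<image_id>" entries;
# the class name resolves to a label and the image lives at
# images/<class_name>/<image_id>.jpg.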
with open(osp.join(cls.root, cls.meta_folder, cls.train_file),
'w') as f:
f.write('\n'.join([
'apple_pie/0001',
'baby_back_ribs/0002',
'baklava/0003',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.test_file),
'w') as f:
f.write('\n'.join([
'beef_carpaccio/0004',
'beef_tartare/0005',
'beet_salad/0006',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.image_folder, 'apple_pie', '0001.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split="test"
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.image_folder, 'beef_carpaccio',
'0004.jpg'))
self.assertEqual(data_info['gt_label'], 3)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestSUN397(TestBaseDataset):
DATASET_TYPE = 'SUN397'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.train_file = 'Training_01.txt'
cls.test_file = 'Testing_01.txt'
cls.data_prefix = 'SUN397'
cls.meta_folder = 'Partitions'
os.makedirs(osp.join(cls.root, cls.meta_folder))
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
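# The partition files list image paths such as '/a/abbey/...'; the class-name
# portion of each path is mapped to its index in the dataset's 397-class
# METAINFO list to produce the label.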
with open(osp.join(cls.root, cls.meta_folder, cls.train_file),
'w') as f:
f.write('\n'.join([
'/a/abbey/sun_aqswjsnjlrfzzhiz.jpg',
'/a/airplane_cabin/sun_blczihbhbntqccux.jpg',
'/a/assembly_line/sun_ajckcfldgdrdjogj.jpg',
]))
with open(osp.join(cls.root, cls.meta_folder, cls.test_file),
'w') as f:
f.write('\n'.join([
'/a/abbey/sun_ajkqrqitspwywirx.jpg',
'/a/airplane_cabin/sun_aqylhacwdsqfjuuu.jpg',
'/a/auto_factory/sun_apfsprenzdnzbhmt.jpg',
'/b/baggage_claim/sun_avittiqqaiibgcau.jpg',
]))
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test invalid split
with self.assertRaisesRegex(AssertionError, 'The split must be'):
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = 'unknown'
dataset_class(**cfg)
# Test valid splits
splits = ['train', 'test']
for split in splits:
cfg = {**self.DEFAULT_ARGS}
cfg['split'] = split
dataset = dataset_class(**cfg)
self.assertEqual(dataset.split, split)
self.assertEqual(dataset.data_root, self.root)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.data_prefix,
'a/abbey/sun_aqswjsnjlrfzzhiz.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with split='test'
cfg = {**self.DEFAULT_ARGS, 'split': 'test'}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 4)
data_info = dataset[-1]
self.assertEqual(
data_info['img_path'],
osp.join(self.root, self.data_prefix,
'b/baggage_claim/sun_avittiqqaiibgcau.jpg'))
self.assertEqual(data_info['gt_label'], 26)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()