Colle bac181f393
[Feature] Support Multi-task. (#1229)
* unit test for multi_task_head

* [Feature] MultiTaskHead (#628, #481)

* [Fix] lint for multi_task_head

* [Feature] Add `MultiTaskDataset` to support multi-task training.

* Update MultiTaskClsHead

* Update docs

* [CI] Add test mim CI. (#879)

* [Fix] Remove duplicated wide-resnet metafile.

* [Feature] Support MPS device. (#894)

* [Feature] Support MPS device.

* Add `auto_select_device`

* Add unit tests

* [Fix] Fix Albu crash bug. (#918)

* Fix albu bug: using albu will change the label from array(x) to array([x]) and crash the training

* Fix common

* Use copy in case of a potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf

* Bump version to 0.23.2. (#937)

* [Improve] Use `forward_dummy` to calculate FLOPS. (#953)

* Update README

* [Docs] Fix typo for wrong reference. (#1036)

* [Doc] Fix typo in tutorial 2 (#1043)

* [Docs] Fix a typo in ImageClassifier (#1050)

* add mask to loss

* add another pipeline

* adapt the pipeline if there is no mask

* switch mask and task

* first version of multi data sample

* fix problem with attribute by using getattr

* rm img_label suffix, fix 'LabelData' object has no attribute 'gt_label'

* training without evaluation

* first version work

* add other metrics

* delete evaluation from dataset

* fix linter

* fix linter

* multi metrics

* first version of test

* change evaluate metric

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_models/test_heads.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* add tests

* add test for multidatasample

* create a generic test

* create a generic test

* create a generic test

* change multi data sample

* correct test

* test

* add new test

* add test for dataset

* correct test

* correct test

* correct test

* correct test

* fix: #5

* run yapf

* fix linter

* fix linter

* fix linter

* fix isort

* fix isort

* fix docformatter

* fix docformatter

* fix linter

* fix linter

* fix data sample

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/structures/multi_task_data_sample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update tests/test_structures/test_datasample.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update data sample

* update head

* update head

* update multi data sample

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* fix linter

* update head

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix problem when we don't set pred or gt

* fix linter

* fix: #2

* fix : linter

* update multi head

* fix linter

* fix linter

* update data sample

* update data sample

* fix: linter

* update test

* test pipeline

* update pipeline

* update test

* update dataset

* update dataset

* fix linter

* fix linter

* update formatting

* add test for multi-task-eval

* update formatting

* fix linter

* update test

* update

* add test

* update metrics

* update metrics

* add doc for functions

* fix linter

* training for multitask 1.x

* fix linter

* run flake8

* run linter

* update test

* add mask in evaluation

* update metric doc

* update metric doc

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* Update mmcls/evaluation/metrics/multi_task.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update metric doc

* update metric doc

* Fix cannot import name MultiTaskDataSample

* fix test_datasets

* fix test_datasets

* fix linter

* add an example of multitask

* change name of configs dataset

* Refactor the multi-task support

* correct test and metric

* add test to multidatasample

* add test to multidatasample

* correct test

* correct metrics and clshead

* Update mmcls/models/heads/cls_head.py

Co-authored-by: Colle <piercus@users.noreply.github.com>

* update cls_head.py documentation

* lint

* lint

* fix: lint

* fix linter

* add eval mask

* fix documentation

* fix: single_label.py back to 1.x

* Update mmcls/models/heads/multi_task_head.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Remove multi-task configs.

Co-authored-by: mzr1996 <mzr1996@163.com>
Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Ming-Hsuan-Tu <alec.tu@acer.com>
Co-authored-by: Lei Lei <18294546+Crescent-Saturn@users.noreply.github.com>
Co-authored-by: WRH <12756472+wangruohui@users.noreply.github.com>
Co-authored-by: marouaneamz <maroineamil99@gmail.com>
Co-authored-by: marouane amzil <53240092+marouaneamz@users.noreply.github.com>
2022-12-30 10:36:00 +08:00
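
The `MultiTaskDataset` tests at the end of this file read an annotation asset named `multi-task.json`. The shipped asset is not reproduced here, but judging from the assertions in `TestMultiTaskDataset`, an annotation file of roughly the following shape would satisfy them; the concrete values below are illustrative assumptions, not the actual test data:

    # Hypothetical multi-task annotation file, written as Python for illustration.
    # `metainfo.tasks` lists the task names; each `data_list` entry pairs an image
    # with a dict of per-task labels, and a task may be absent for a given sample.
    import json

    fake_ann = {
        'metainfo': {'tasks': ['gender', 'wear']},
        'data_list': [
            {'img_path': 'a.jpg', 'gt_label': {'gender': 0}},
            {'img_path': 'b.jpg', 'gt_label': {'gender': 0, 'wear': [1, 0, 1, 0]}},
            {'img_path': 'c.jpg', 'gt_label': {'wear': [0, 1, 0, 1]}},
        ],
    }

    with open('multi-task.json', 'w') as f:
        json.dump(fake_ann, f)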

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import pickle
import sys
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
import numpy as np
from mmengine.logging import MMLogger
from mmengine.registry import TRANSFORMS
from mmcls.registry import DATASETS
from mmcls.utils import register_all_modules
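# Register all mmcls modules so the DATASETS / TRANSFORMS registries used
# below can resolve dataset and transform types by name.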
register_all_modules()
ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))
class TestBaseDataset(TestCase):
DATASET_TYPE = 'BaseDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.json')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test loading metainfo from ann_file
cfg = {**self.DEFAULT_ARGS, 'metainfo': None, 'classes': None}
dataset = dataset_class(**cfg)
self.assertEqual(
dataset.CLASSES,
dataset_class.METAINFO.get('classes', ('first', 'second')))
self.assertFalse(dataset.test_mode)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
cat_ids = dataset.get_cat_ids(0)
self.assertIsInstance(cat_ids, list)
self.assertEqual(len(cat_ids), 1)
self.assertIsInstance(cat_ids[0], int)
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
head = 'Dataset ' + dataset.__class__.__name__
self.assertIn(head, repr(dataset))
if dataset.CLASSES is not None:
num_classes = len(dataset.CLASSES)
self.assertIn(f'Number of categories: \t{num_classes}',
repr(dataset))
else:
self.assertIn('The `CLASSES` meta info is not set.', repr(dataset))
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
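        # Temporarily register a mock transform so the repr check below can
        # show a pipeline entry; it is unregistered again right after.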
TRANSFORMS.register_module(name='test_mock', module=MagicMock)
cfg = {**self.DEFAULT_ARGS, 'pipeline': [dict(type='test_mock')]}
dataset = dataset_class(**cfg)
self.assertIn('With transforms', repr(dataset))
del TRANSFORMS.module_dict['test_mock']
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Annotation file: \t{dataset.ann_file}', repr(dataset))
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
repr(dataset))
class TestCustomDataset(TestBaseDataset):
DATASET_TYPE = 'CustomDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test load without ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a', '1.JPG'),
'gt_label': 0
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b', 'subb', '3.jpg'),
'gt_label': 1
}.items())
# test ann_file assertion
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': ['ann_file.txt'],
}
with self.assertRaisesRegex(TypeError, 'expected str'):
dataset_class(**cfg)
# test load with ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_root': ASSETS_ROOT,
'ann_file': 'ann.txt',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
        self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
'gt_label': 1
}.items())
np.testing.assert_equal(dataset.get_gt_labels(), np.array([0, 1, 1]))
# test load with absolute ann_file
cfg = {
**self.DEFAULT_ARGS,
'data_root': '',
'data_prefix': '',
'ann_file': osp.join(ASSETS_ROOT, 'ann.txt'),
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
# custom dataset won't infer CLASSES from ann_file
        self.assertIsNone(dataset.CLASSES)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': 'a/1.JPG',
'gt_label': 0,
}.items())
self.assertGreaterEqual(
dataset.get_data_info(2).items(), {
'img_path': 'b/subb/3.jpg',
'gt_label': 1
}.items())
# test extensions filter
cfg = {
**self.DEFAULT_ARGS, 'data_prefix': dict(img_path=ASSETS_ROOT),
'ann_file': '',
'extensions': ('.txt', )
}
with self.assertRaisesRegex(RuntimeError,
'Supported extensions are: .txt'):
dataset_class(**cfg)
cfg = {
**self.DEFAULT_ARGS, 'data_prefix': ASSETS_ROOT,
'ann_file': '',
'extensions': ('.jpeg', )
}
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset = dataset_class(**cfg)
self.assertIn('Supported extensions are: .jpeg', log.output[0])
self.assertEqual(len(dataset), 1)
self.assertGreaterEqual(
dataset.get_data_info(0).items(), {
'img_path': osp.join(ASSETS_ROOT, 'b', '2.jpeg'),
'gt_label': 1
}.items())
# test classes check
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'classes': ('apple', 'banana'),
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('apple', 'banana'))
cfg['classes'] = ['apple', 'banana', 'dog']
with self.assertRaisesRegex(AssertionError,
r"\(2\) doesn't match .* classes \(3\)"):
dataset_class(**cfg)
class TestImageNet(TestCustomDataset):
DATASET_TYPE = 'ImageNet'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='ann.txt')
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# test classes number
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'ann_file': '',
}
with self.assertRaisesRegex(
AssertionError, r"\(2\) doesn't match .* classes \(1000\)"):
dataset_class(**cfg)
# test override classes
cfg = {
**self.DEFAULT_ARGS,
'data_prefix': ASSETS_ROOT,
'classes': ['cat', 'dog'],
'ann_file': '',
}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
self.assertEqual(dataset.CLASSES, ('cat', 'dog'))
class TestImageNet21k(TestCustomDataset):
DATASET_TYPE = 'ImageNet21k'
DEFAULT_ARGS = dict(
data_root=ASSETS_ROOT, classes=['cat', 'dog'], ann_file='ann.txt')
def test_load_data_list(self):
super().test_initialize()
dataset_class = DATASETS.get(self.DATASET_TYPE)
        # The multi_label option is not implemented yet.
cfg = {**self.DEFAULT_ARGS, 'multi_label': True}
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
dataset_class(**cfg)
# Warn about ann_file
cfg = {**self.DEFAULT_ARGS, 'ann_file': ''}
logger = MMLogger.get_current_instance()
with self.assertLogs(logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `ann_file`', log.output[0])
# Warn about classes
cfg = {**self.DEFAULT_ARGS, 'classes': None}
with self.assertLogs(logger, 'WARN') as log:
dataset_class(**cfg)
self.assertIn('specify the `classes`', log.output[0])
class TestCIFAR10(TestBaseDataset):
DATASET_TYPE = 'CIFAR10'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
data_prefix = tmpdir.name
cls.DEFAULT_ARGS = dict(
data_prefix=data_prefix, pipeline=[], test_mode=False)
dataset_class = DATASETS.get(cls.DATASET_TYPE)
base_folder = osp.join(data_prefix, dataset_class.base_folder)
os.mkdir(base_folder)
cls.fake_imgs = np.random.randint(
0, 255, size=(6, 3 * 32 * 32), dtype=np.uint8)
cls.fake_labels = np.random.randint(0, 10, size=(6, ))
cls.fake_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
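        # Serialize fake batches in the pickled dict format the CIFAR loader
        # expects; the test batch deliberately uses the CIFAR-100 style
        # 'fine_labels' key, presumably to cover that code path as well.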
batch1 = dict(
data=cls.fake_imgs[:2], labels=cls.fake_labels[:2].tolist())
with open(osp.join(base_folder, 'data_batch_1'), 'wb') as f:
f.write(pickle.dumps(batch1))
batch2 = dict(
data=cls.fake_imgs[2:4], labels=cls.fake_labels[2:4].tolist())
with open(osp.join(base_folder, 'data_batch_2'), 'wb') as f:
f.write(pickle.dumps(batch2))
test_batch = dict(
data=cls.fake_imgs[4:], fine_labels=cls.fake_labels[4:].tolist())
with open(osp.join(base_folder, 'test_batch'), 'wb') as f:
f.write(pickle.dumps(test_batch))
meta = {dataset_class.meta['key']: cls.fake_classes}
meta_filename = dataset_class.meta['filename']
with open(osp.join(base_folder, meta_filename), 'wb') as f:
f.write(pickle.dumps(meta))
dataset_class.train_list = [['data_batch_1', None],
['data_batch_2', None]]
dataset_class.test_list = [['test_batch', None]]
dataset_class.meta['md5'] = None
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `metainfo` argument
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 4)
self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0]
fake_img = self.fake_imgs[0].reshape(3, 32, 32).transpose(1, 2, 0)
np.testing.assert_equal(data_info['img'], fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_labels[0])
# Test with test_mode=True
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img = self.fake_imgs[4].reshape(3, 32, 32).transpose(1, 2, 0)
np.testing.assert_equal(data_info['img'], fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_labels[4])
# Test load meta
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
dataset._metainfo = {}
dataset.full_init()
self.assertEqual(dataset.CLASSES, self.fake_classes)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
dataset._metainfo = {}
dataset.meta['filename'] = 'invalid'
with self.assertRaisesRegex(RuntimeError, 'not found or corrupted'):
dataset.full_init()
        # Test automatic download
with patch(
'mmcls.datasets.cifar.download_and_extract_archive') as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
dataset = dataset_class(**cfg)
dataset.test_list = [['invalid_batch', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
mock.assert_called_once_with(
dataset.url,
dataset.data_prefix['root'],
filename=dataset.filename,
md5=dataset.tgz_md5)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'test_mode': True,
'download': False
}
dataset = dataset_class(**cfg)
dataset.test_list = [['test_batch', 'invalid_md5']]
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/cifar'
}
dataset = dataset_class(**cfg)
dataset._check_integrity = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/cifar'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestCIFAR100(TestCIFAR10):
DATASET_TYPE = 'CIFAR100'
class TestMultiLabelDataset(TestBaseDataset):
DATASET_TYPE = 'MultiLabelDataset'
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='multi_label_ann.json')
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
cat_ids = dataset.get_cat_ids(0)
        self.assertEqual(cat_ids, [0])
        cat_ids = dataset.get_cat_ids(1)
        self.assertEqual(cat_ids, [1])
        cat_ids = dataset.get_cat_ids(2)
        self.assertEqual(cat_ids, [0, 1])
class TestVOC(TestBaseDataset):
DATASET_TYPE = 'VOC'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
data_root = tmpdir.name
cls.DEFAULT_ARGS = dict(
data_root=data_root,
image_set_path='ImageSets/train.txt',
data_prefix=dict(img_path='JPEGImages', ann_path='Annotations'),
pipeline=[],
test_mode=False)
cls.image_folder = osp.join(data_root, 'JPEGImages')
cls.ann_folder = osp.join(data_root, 'Annotations')
cls.image_set_folder = osp.join(data_root, 'ImageSets')
os.mkdir(cls.image_set_folder)
os.mkdir(cls.image_folder)
os.mkdir(cls.ann_folder)
cls.fake_img_paths = [f'{i}' for i in range(6)]
cls.fake_labels = [[
np.random.randint(10) for _ in range(np.random.randint(1, 4))
] for _ in range(6)]
cls.fake_classes = [f'C_{i}' for i in range(10)]
train_list = [i for i in range(0, 4)]
test_list = [i for i in range(4, 6)]
with open(osp.join(cls.image_set_folder, 'train.txt'), 'w') as f:
for train_item in train_list:
f.write(str(train_item) + '\n')
with open(osp.join(cls.image_set_folder, 'test.txt'), 'w') as f:
for test_item in test_list:
f.write(str(test_item) + '\n')
with open(osp.join(cls.image_set_folder, 'full_path_test.txt'),
'w') as f:
for test_item in test_list:
f.write(osp.join(cls.image_folder, str(test_item)) + '\n')
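        # Write minimal Pascal VOC style XML annotations; each training file
        # also repeats its objects with difficult=1 so that difficult labels
        # are exercised as well (see the `add difficult label` block below).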
for train_item in train_list:
with open(osp.join(cls.ann_folder, f'{train_item}.xml'), 'w') as f:
temple = '<object><name>C_{}</name>{}</object>'
ann_data = ''.join([
temple.format(label, '<difficult>0</difficult>')
for label in cls.fake_labels[train_item]
])
# add difficult label
ann_data += ''.join([
temple.format(label, '<difficult>1</difficult>')
for label in cls.fake_labels[train_item]
])
xml_ann_data = f'<annotation>{ann_data}</annotation>'
f.write(xml_ann_data + '\n')
for test_item in test_list:
with open(osp.join(cls.ann_folder, f'{test_item}.xml'), 'w') as f:
temple = '<object><name>C_{}</name>{}</object>'
ann_data = ''.join([
temple.format(label, '<difficult>0</difficult>')
for label in cls.fake_labels[test_item]
])
xml_ann_data = f'<annotation>{ann_data}</annotation>'
f.write(xml_ann_data + '\n')
def test_initialize(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test overriding metainfo by `classes` argument
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
# Test overriding CLASSES by classes file
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
dataset = dataset_class(**cfg)
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
# Test invalid classes
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
dataset_class(**cfg)
def test_get_cat_ids(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {'classes': self.fake_classes, **self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
cat_ids = dataset.get_cat_ids(0)
self.assertIsInstance(cat_ids, list)
self.assertIsInstance(cat_ids[0], int)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 4)
self.assertEqual(len(dataset.CLASSES), 20)
cfg = {
'classes': self.fake_classes,
'lazy_init': True,
**self.DEFAULT_ARGS
}
dataset = dataset_class(**cfg)
self.assertIn('Haven\'t been initialized', repr(dataset))
dataset.full_init()
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[0])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[0]))
# Test with test_mode=True
cfg['image_set_path'] = 'ImageSets/test.txt'
cfg['test_mode'] = True
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[4]))
# Test with test_mode=True and ann_path = None
cfg['image_set_path'] = 'ImageSets/test.txt'
cfg['test_mode'] = True
cfg['data_prefix'] = 'JPEGImages'
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
self.assertEqual(data_info['gt_label'], None)
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_root': 's3://openmmlab/voc'
}
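        # An s3:// data_root routes file access through the Petrel backend,
        # which imports `petrel_client`, so a mock is injected into
        # sys.modules before the dataset is constructed.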
petrel_mock = MagicMock()
sys.modules['petrel_client'] = petrel_mock
dataset = dataset_class(**cfg)
petrel_mock.client.Client.assert_called()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Path of image set: \t{dataset.image_set_path}',
repr(dataset))
self.assertIn(f'Prefix of dataset: \t{dataset.data_root}',
repr(dataset))
self.assertIn(f'Prefix of annotations: \t{dataset.ann_prefix}',
repr(dataset))
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestMNIST(TestBaseDataset):
DATASET_TYPE = 'MNIST'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
data_prefix = tmpdir.name
cls.DEFAULT_ARGS = dict(
data_prefix=data_prefix, pipeline=[], test_mode=False)
dataset_class = DATASETS.get(cls.DATASET_TYPE)
def rm_suffix(s):
return s[:s.rfind('.')]
train_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[0][0]))
train_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[1][0]))
test_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[0][0]))
test_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[1][0]))
cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)
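        # Write minimal IDX (ubyte) files: a 4-byte magic number whose third
        # byte is the dtype code (0x08 = unsigned byte) and fourth byte the
        # number of dimensions, followed by big-endian dimension sizes and
        # the raw data.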
for file in [train_image_file, test_image_file]:
magic = b'\x00\x00\x08\x03' # num_dims = 3, type = uint8
head = b'\x00\x00\x00\x01' + b'\x00\x00\x00\x1c' * 2 # (1, 28, 28)
data = magic + head + cls.fake_img.flatten().tobytes()
with open(file, 'wb') as f:
f.write(data)
for file in [train_label_file, test_label_file]:
            magic = b'\x00\x00\x08\x01'  # num_dims = 1, type = uint8
head = b'\x00\x00\x00\x01' # (1, )
data = magic + head + cls.fake_label.tobytes()
with open(file, 'wb') as f:
f.write(data)
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 1)
self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test with test_mode=True
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
        # Test automatic download
with patch(
'mmcls.datasets.mnist.download_and_extract_archive') as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
dataset = dataset_class(**cfg)
dataset.train_list = [['invalid_train_file', None]]
dataset.test_list = [['invalid_test_file', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
calls = [
call(
osp.join(dataset.url_prefix, dataset.train_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.train_list[0][0],
md5=None),
call(
osp.join(dataset.url_prefix, dataset.test_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.test_list[0][0],
md5=None)
]
mock.assert_has_calls(calls)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'test_mode': True,
'download': False
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/mnist'
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f"Prefix of data: \t{dataset.data_prefix['root']}",
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class FashionMNIST(TestMNIST):
DATASET_TYPE = 'FashionMNIST'
class TestCUB(TestBaseDataset):
DATASET_TYPE = 'CUB'
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
cls.root = tmpdir.name
cls.ann_file = 'ann_file.txt'
cls.image_folder = 'images'
cls.image_class_labels_file = 'classes.txt'
cls.train_test_split_file = 'split.txt'
cls.train_test_split_file2 = 'split2.txt'
cls.DEFAULT_ARGS = dict(
data_root=cls.root,
test_mode=False,
data_prefix=cls.image_folder,
pipeline=[],
ann_file=cls.ann_file,
image_class_labels_file=cls.image_class_labels_file,
train_test_split_file=cls.train_test_split_file)
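        # CUB style metadata: ann_file maps a sample id to an image path,
        # image_class_labels_file maps the id to a 1-based class index, and
        # train_test_split_file marks each id as train (1) or test (0);
        # split2.txt is deliberately one line short to trigger the
        # sample-id mismatch assertion tested below.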
with open(osp.join(cls.root, cls.ann_file), 'w') as f:
f.write('\n'.join([
'1 1.txt',
'2 2.txt',
'3 3.txt',
]))
with open(osp.join(cls.root, cls.image_class_labels_file), 'w') as f:
f.write('\n'.join([
'1 2',
'2 3',
'3 1',
]))
with open(osp.join(cls.root, cls.train_test_split_file), 'w') as f:
f.write('\n'.join([
'1 0',
'2 1',
'3 1',
]))
with open(osp.join(cls.root, cls.train_test_split_file2), 'w') as f:
f.write('\n'.join([
'1 0',
'2 1',
]))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '2.txt'))
self.assertEqual(data_info['gt_label'], 3 - 1)
# Test with test_mode=True
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
osp.join(self.root, self.image_folder, '1.txt'))
self.assertEqual(data_info['gt_label'], 2 - 1)
        # Test when the numbers of lines do not match
cfg = {
**self.DEFAULT_ARGS, 'train_test_split_file':
self.train_test_split_file2
}
with self.assertRaisesRegex(AssertionError,
'sample_ids should be same'):
dataset_class(**cfg)
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS}
dataset = dataset_class(**cfg)
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class TestMultiTaskDataset(TestCase):
DATASET_TYPE = 'MultiTaskDataset'
DEFAULT_ARGS = dict(
data_root=ASSETS_ROOT,
ann_file=osp.join(ASSETS_ROOT, 'multi-task.json'),
pipeline=[])
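    # The `multi-task.json` asset pairs each image with a dict of per-task
    # labels (a rough sketch of the expected format is given in the header
    # notes at the top of this file).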
def test_metainfo(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
metainfo = {'tasks': ['gender', 'wear']}
self.assertDictEqual(dataset.metainfo, metainfo)
self.assertFalse(dataset.test_mode)
def test_parse_data_info(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.parse_data_info({
'img_path': 'a.jpg',
'gt_label': {
'gender': 0
}
})
self.assertDictContainsSubset(
{
'img_path': os.path.join(ASSETS_ROOT, 'a.jpg'),
'gt_label': {
'gender': 0
}
}, data)
np.testing.assert_equal(data['gt_label']['gender'], 0)
# Test missing path
with self.assertRaisesRegex(AssertionError, 'have `img_path` field'):
dataset.parse_data_info(
{'gt_label': {
'gender': 0,
'wear': [1, 0, 1, 0]
}})
def test_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
dataset = dataset_class(**self.DEFAULT_ARGS)
task_doc = ('For 2 tasks\n gender \n wear ')
self.assertIn(task_doc, repr(dataset))
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
self.assertIsInstance(data, list)
np.testing.assert_equal(len(data), 3)
np.testing.assert_equal(data[0]['gt_label'], {'gender': 0})
np.testing.assert_equal(data[1]['gt_label'], {
'gender': 0,
'wear': [1, 0, 1, 0]
})