diff --git a/docs/en/api/datasets.rst b/docs/en/api/datasets.rst
index f72dca0ce..1e5239340 100644
--- a/docs/en/api/datasets.rst
+++ b/docs/en/api/datasets.rst
@@ -49,6 +49,11 @@ CUB
 
 .. autoclass:: CUB
 
+Retrieval
+---------
+
+.. autoclass:: InShop
+
 Base classes
 ------------
 
diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py
index 22abdadcc..93059ae4e 100644
--- a/mmcls/datasets/__init__.py
+++ b/mmcls/datasets/__init__.py
@@ -6,6 +6,7 @@ from .cub import CUB
 from .custom import CustomDataset
 from .dataset_wrappers import KFoldDataset
 from .imagenet import ImageNet, ImageNet21k
+from .inshop import InShop
 from .mnist import MNIST, FashionMNIST
 from .multi_label import MultiLabelDataset
 from .multi_task import MultiTaskDataset
@@ -16,5 +17,5 @@ from .voc import VOC
 __all__ = [
     'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
     'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
-    'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
+    'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset', 'InShop'
 ]
diff --git a/mmcls/datasets/inshop.py b/mmcls/datasets/inshop.py
new file mode 100644
index 000000000..544e13bdb
--- /dev/null
+++ b/mmcls/datasets/inshop.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine import get_file_backend, list_from_file
+
+from mmcls.registry import DATASETS
+from .base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class InShop(BaseDataset):
+    """InShop Dataset for Image Retrieval.
+
+    Please download the images from the homepage
+    'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
+    (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
+    Eval/list_eval_partition.txt), and organize them as follows: ::
+
+        In-shop Clothes Retrieval Benchmark (data_root)/
+        ├── Eval/
+        │   └── list_eval_partition.txt (ann_file)
+        ├── Img
+        │   └── img/ (img_prefix)
+        ├── README.txt
+        └── .....
+
+    Args:
+        data_root (str): The root directory for the dataset.
+        split (str): Choose from 'train', 'query' and 'gallery'.
+            Defaults to 'train'.
+        data_prefix (str | dict): Prefix for training data.
+            Defaults to 'Img/img'.
+        ann_file (str): Annotation file path, path relative to
+            ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+
+    Examples:
+        >>> from mmcls.datasets import InShop
+        >>>
+        >>> # build train InShop dataset
+        >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
+        >>> inshop_train = InShop(**inshop_train_cfg)
+        >>> inshop_train
+        Dataset InShop
+        Number of samples: 25882
+        The `CLASSES` meta info is not set.
+        Root of dataset: data/inshop
+        >>>
+        >>> # build query InShop dataset
+        >>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
+        >>> inshop_query = InShop(**inshop_query_cfg)
+        >>> inshop_query
+        Dataset InShop
+        Number of samples: 14218
+        The `CLASSES` meta info is not set.
+        Root of dataset: data/inshop
+        >>>
+        >>> # build gallery InShop dataset
+        >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
+        >>> inshop_gallery = InShop(**inshop_gallery_cfg)
+        >>> inshop_gallery
+        Dataset InShop
+        Number of samples: 12612
+        The `CLASSES` meta info is not set.
+        Root of dataset: data/inshop
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 split: str = 'train',
+                 data_prefix: str = 'Img/img',
+                 ann_file: str = 'Eval/list_eval_partition.txt',
+                 **kwargs):
+
+        assert split in ('train', 'query', 'gallery'), \
+            "'split' of `InShop` must be one of ['train', 'query', " \
+            f"'gallery'], but got '{split}'"
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        self.split = split
+        super().__init__(
+            data_root=data_root,
+            data_prefix=data_prefix,
+            ann_file=ann_file,
+            **kwargs)
+
+    def _process_annotations(self):
+        lines = list_from_file(self.ann_file)
+
+        anno_train = dict(metainfo=dict(), data_list=list())
+        anno_gallery = dict(metainfo=dict(), data_list=list())
+
+        # item_id to class label, each item corresponds to one class label
+        class_num = 0
+        gt_label_train = {}
+
+        # item_id to gallery indices, each item corresponds to several images
+        gallery_num = 0
+        gt_label_gallery = {}
+
+        # lines[0] and lines[1] are the image number and the field names;
+        # each following line is formatted as
+        # 'image_name item_id evaluation_status'.
+        for line in lines[2:]:
+            img_name, item_id, status = line.split()
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            if status == 'train':
+                if item_id not in gt_label_train:
+                    gt_label_train[item_id] = class_num
+                    class_num += 1
+                # item_id to class_id (for the training set)
+                anno_train['data_list'].append(
+                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
+            elif status == 'gallery':
+                if item_id not in gt_label_gallery:
+                    gt_label_gallery[item_id] = []
+                # Since there are multiple images for each item,
+                # record the corresponding item for each image.
+                gt_label_gallery[item_id].append(gallery_num)
+                anno_gallery['data_list'].append(
+                    dict(img_path=img_path, sample_idx=gallery_num))
+                gallery_num += 1
+
+        if self.split == 'train':
+            anno_train['metainfo']['class_number'] = class_num
+            anno_train['metainfo']['sample_number'] = \
+                len(anno_train['data_list'])
+            return anno_train
+        elif self.split == 'gallery':
+            anno_gallery['metainfo']['sample_number'] = gallery_num
+            return anno_gallery
+
+        # Generate the label for the query(val) set
+        anno_query = dict(metainfo=dict(), data_list=list())
+        query_num = 0
+        for line in lines[2:]:
+            img_name, item_id, status = line.split()
+            img_path = self.backend.join_path(self.img_prefix, img_name)
+            if status == 'query':
+                anno_query['data_list'].append(
+                    dict(
+                        img_path=img_path,
+                        gt_label=gt_label_gallery[item_id]))
+                query_num += 1
+
+        anno_query['metainfo']['sample_number'] = query_num
+        return anno_query
+
+    def load_data_list(self):
+        """Load the data list.
+
+        For the train set, return image and ground truth label. For the query
+        set, return image and ids of images in gallery. For the gallery set,
+        return image and its id.
+ """ + data_info = self._process_annotations() + data_list = data_info['data_list'] + for data in data_list: + data['img_path'] = self.backend.join_path(self.data_root, + data['img_path']) + return data_list + + def extra_repr(self): + """The extra repr information of the dataset.""" + body = [f'Root of dataset: \t{self.data_root}'] + return body diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py index 0af3e9055..b284b50bd 100644 --- a/tests/test_datasets/test_datasets.py +++ b/tests/test_datasets/test_datasets.py @@ -929,7 +929,6 @@ class TestMultiTaskDataset(TestCase): # Test default behavior dataset = dataset_class(**self.DEFAULT_ARGS) - data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file']) self.assertIsInstance(data, list) np.testing.assert_equal(len(data), 3) @@ -938,3 +937,103 @@ class TestMultiTaskDataset(TestCase): 'gender': 0, 'wear': [1, 0, 1, 0] }) + + +class TestInShop(TestBaseDataset): + DATASET_TYPE = 'InShop' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + cls.root = tmpdir.name + cls.list_eval_partition = 'Eval/list_eval_partition.txt' + cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train') + cls.ann_file = osp.join(cls.root, cls.list_eval_partition) + os.makedirs(osp.join(cls.root, 'Eval')) + with open(cls.ann_file, 'w') as f: + f.write('\n'.join([ + '8', + 'image_name item_id evaluation_status', + '02_1_front.jpg id_00000002 train', + '02_2_side.jpg id_00000002 train', + '12_3_back.jpg id_00007982 gallery', + '12_7_additional.jpg id_00007982 gallery', + '13_1_front.jpg id_00007982 query', + '13_2_side.jpg id_00007983 gallery', + '13_3_back.jpg id_00007983 query ', + '13_7_additional.jpg id_00007983 query', + ])) + + def test_initialize(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test with mode=train + cfg = {**self.DEFAULT_ARGS} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.split, 'train') + self.assertEqual(dataset.data_root, self.root) + self.assertEqual(dataset.ann_file, self.ann_file) + + # Test with mode=query + cfg = {**self.DEFAULT_ARGS, 'split': 'query'} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.split, 'query') + self.assertEqual(dataset.data_root, self.root) + self.assertEqual(dataset.ann_file, self.ann_file) + + # Test with mode=gallery + cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.split, 'gallery') + self.assertEqual(dataset.data_root, self.root) + self.assertEqual(dataset.ann_file, self.ann_file) + + # Test with mode=other + cfg = {**self.DEFAULT_ARGS, 'split': 'other'} + with self.assertRaisesRegex(AssertionError, "'split' of `InS"): + dataset_class(**cfg) + + def test_load_data_list(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test with mode=train + cfg = {**self.DEFAULT_ARGS} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 2) + data_info = dataset[0] + self.assertEqual(data_info['img_path'], + os.path.join(self.root, 'Img/img', '02_1_front.jpg')) + self.assertEqual(data_info['gt_label'], 0) + + # Test with mode=query + cfg = {**self.DEFAULT_ARGS, 'split': 'query'} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + data_info = dataset[0] + self.assertEqual(data_info['img_path'], + os.path.join(self.root, 'Img/img', '13_1_front.jpg')) + self.assertEqual(data_info['gt_label'], [0, 1]) + + # Test with mode=gallery + cfg = {**self.DEFAULT_ARGS, 'split': 
+        cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
+        dataset = dataset_class(**cfg)
+        self.assertEqual(len(dataset), 3)
+        data_info = dataset[0]
+        self.assertEqual(data_info['img_path'],
+                         os.path.join(self.root, 'Img/img', '12_3_back.jpg'))
+        self.assertEqual(data_info['sample_idx'], 0)
+
+    def test_extra_repr(self):
+        dataset_class = DATASETS.get(self.DATASET_TYPE)
+        cfg = {**self.DEFAULT_ARGS}
+        dataset = dataset_class(**cfg)
+
+        self.assertIn(f'Root of dataset: \t{dataset.data_root}',
+                      repr(dataset))
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.tmpdir.cleanup()
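
Reviewer note: the sketch below is not part of the patch. It shows one way to exercise the three splits end to end through the registry, assuming a local copy of the In-shop Clothes Retrieval Benchmark is unpacked under 'data/inshop' (the placeholder path used in the docstring examples); only APIs already referenced in this diff (`mmcls.registry.DATASETS`, the `InShop` arguments) are used.

# Usage sketch (assumption: dataset unpacked under 'data/inshop').
from mmcls.registry import DATASETS

data_root = 'data/inshop'
train_set = DATASETS.build(
    dict(type='InShop', data_root=data_root, split='train'))
query_set = DATASETS.build(
    dict(type='InShop', data_root=data_root, split='query'))
gallery_set = DATASETS.build(
    dict(type='InShop', data_root=data_root, split='gallery'))

# Train samples carry one class id per item_id, query samples carry the
# gallery indices of images sharing their item_id, and gallery samples
# carry their own sample_idx, mirroring the assertions in TestInShop.
print(len(train_set), len(query_set), len(gallery_set))
print(query_set[0]['gt_label'], gallery_set[0]['sample_idx'])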