[Feature] Support InShop Dataset (Image Retrieval). (#1019)
* rebase * feat: add inshop dataset (retrieval) * update fileIO * update unit tests * fix windows ci * fix windows ci * fix lint * update unit tests * update docs * update docs Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
pull/1314/merge
parent
353886eaca
commit
7e4502b0ac
|
@ -49,6 +49,11 @@ CUB
|
|||
|
||||
.. autoclass:: CUB
|
||||
|
||||
Retrieval
|
||||
---------
|
||||
|
||||
.. autoclass:: InShop
|
||||
|
||||
Base classes
|
||||
------------
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from .cub import CUB
|
|||
from .custom import CustomDataset
|
||||
from .dataset_wrappers import KFoldDataset
|
||||
from .imagenet import ImageNet, ImageNet21k
|
||||
from .inshop import InShop
|
||||
from .mnist import MNIST, FashionMNIST
|
||||
from .multi_label import MultiLabelDataset
|
||||
from .multi_task import MultiTaskDataset
|
||||
|
@ -16,5 +17,5 @@ from .voc import VOC
|
|||
# Public API of the datasets package. Note: the old pre-diff line listing
# 'MultiTaskDataset' without a trailing comma must not coexist with the new
# line below — adjacent string literals concatenate implicitly in Python.
__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
    'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset', 'InShop'
]
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine import get_file_backend, list_from_file
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
class InShop(BaseDataset):
    """InShop Dataset for Image Retrieval.

    Please download the images from the homepage
    'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
    (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
    Eval/list_eval_partition.txt), and organize them in the following way: ::

        In-shop Clothes Retrieval Benchmark (data_root)/
           ├── Eval /
           │   └── list_eval_partition.txt (ann_file)
           ├── Img
           │   └── img/ (img_prefix)
           ├── README.txt
           └── .....

    Args:
        data_root (str): The root directory for dataset.
        split (str): Choose from 'train', 'query' and 'gallery'.
            Defaults to 'train'.
        data_prefix (str | dict): Prefix for training data.
            Defaults to 'Img/img'.
        ann_file (str): Annotation file path, path relative to
            ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmcls.datasets import InShop
        >>>
        >>> # build train InShop dataset
        >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
        >>> inshop_train = InShop(**inshop_train_cfg)
        >>> inshop_train
        Dataset InShop
            Number of samples:  25882
            The `CLASSES` meta info is not set.
        Root of dataset:        data/inshop
        >>>
        >>> # build query InShop dataset
        >>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
        >>> inshop_query = InShop(**inshop_query_cfg)
        >>> inshop_query
        Dataset InShop
            Number of samples:  14218
            The `CLASSES` meta info is not set.
        Root of dataset:        data/inshop
        >>>
        >>> # build gallery InShop dataset
        >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
        >>> inshop_gallery = InShop(**inshop_gallery_cfg)
        >>> inshop_gallery
        Dataset InShop
            Number of samples:  12612
            The `CLASSES` meta info is not set.
        Root of dataset:        data/inshop
    """

    def __init__(self,
                 data_root: str,
                 split: str = 'train',
                 data_prefix: str = 'Img/img',
                 ann_file: str = 'Eval/list_eval_partition.txt',
                 **kwargs):

        # Validate the split early so a clear error surfaces before the
        # (potentially expensive) base-class initialization runs.
        assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
            f" must be one of ['train', 'query', 'gallery'], but get '{split}'"
        self.backend = get_file_backend(data_root, enable_singleton=True)
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs)

    def _process_annotations(self):
        """Parse ``list_eval_partition.txt`` into the annotation dict for
        ``self.split``.

        Returns:
            dict: With two keys: ``metainfo`` (sample/class counts) and
            ``data_list`` (one dict per image).
        """
        lines = list_from_file(self.ann_file)

        anno_train = dict(metainfo=dict(), data_list=list())
        anno_gallery = dict(metainfo=dict(), data_list=list())

        # item_id to label, each item corresponds to one class label
        class_num = 0
        gt_label_train = {}

        # item_id to label, each label corresponds to several items
        gallery_num = 0
        gt_label_gallery = {}

        # (lines[0], lines[1]) is the image number and the field name;
        # Each line format as 'image_name, item_id, evaluation_status'
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'train':
                if item_id not in gt_label_train:
                    # First time this item is seen: assign a fresh class id.
                    gt_label_train[item_id] = class_num
                    class_num += 1
                # item_id to class_id (for the training set)
                anno_train['data_list'].append(
                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
            elif status == 'gallery':
                if item_id not in gt_label_gallery:
                    gt_label_gallery[item_id] = []
                # Since there are multiple images for each item,
                # record the corresponding item for each image.
                gt_label_gallery[item_id].append(gallery_num)
                anno_gallery['data_list'].append(
                    dict(img_path=img_path, sample_idx=gallery_num))
                gallery_num += 1

        if self.split == 'train':
            anno_train['metainfo']['class_number'] = class_num
            anno_train['metainfo']['sample_number'] = \
                len(anno_train['data_list'])
            return anno_train
        elif self.split == 'gallery':
            anno_gallery['metainfo']['sample_number'] = gallery_num
            return anno_gallery

        # Generate the label for the query(val) set: each query image is
        # labelled with the gallery sample indices sharing its item_id.
        # NOTE(review): a query item with no gallery images would raise
        # KeyError here — presumably the official partition file guarantees
        # every query item appears in the gallery; confirm for custom splits.
        anno_query = dict(metainfo=dict(), data_list=list())
        query_num = 0
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'query':
                anno_query['data_list'].append(
                    dict(
                        img_path=img_path, gt_label=gt_label_gallery[item_id]))
                query_num += 1

        anno_query['metainfo']['sample_number'] = query_num
        return anno_query

    def load_data_list(self):
        """Load data list.

        For the train set, return image and ground truth label. For the query
        set, return image and ids of images in gallery. For the gallery set,
        return image and its id.
        """
        data_info = self._process_annotations()
        data_list = data_info['data_list']
        for data in data_list:
            # NOTE(review): ``img_path`` was already joined with
            # ``self.img_prefix`` in ``_process_annotations``; joining
            # ``data_root`` again is a no-op only when the prefix is already
            # absolute — verify behavior with a relative ``data_root``.
            data['img_path'] = self.backend.join_path(self.data_root,
                                                      data['img_path'])
        return data_list

    def extra_repr(self):
        """The extra repr information of the dataset."""
        body = [f'Root of dataset: \t{self.data_root}']
        return body
|
|
@ -929,7 +929,6 @@ class TestMultiTaskDataset(TestCase):
|
|||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
|
||||
self.assertIsInstance(data, list)
|
||||
np.testing.assert_equal(len(data), 3)
|
||||
|
@ -938,3 +937,103 @@ class TestMultiTaskDataset(TestCase):
|
|||
'gender': 0,
|
||||
'wear': [1, 0, 1, 0]
|
||||
})
|
||||
|
||||
|
||||
class TestInShop(TestBaseDataset):
    """Unit tests for the ``InShop`` retrieval dataset."""
    DATASET_TYPE = 'InShop'

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        cls.tmpdir = tempfile.TemporaryDirectory()
        cls.root = cls.tmpdir.name
        cls.list_eval_partition = 'Eval/list_eval_partition.txt'
        cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
        cls.ann_file = osp.join(cls.root, cls.list_eval_partition)
        os.makedirs(osp.join(cls.root, 'Eval'))
        # Minimal partition file: 2 train / 3 gallery / 3 query rows.
        # The trailing space on the second-to-last row is deliberate.
        ann_rows = [
            '8',
            'image_name item_id evaluation_status',
            '02_1_front.jpg id_00000002 train',
            '02_2_side.jpg id_00000002 train',
            '12_3_back.jpg id_00007982 gallery',
            '12_7_additional.jpg id_00007982 gallery',
            '13_1_front.jpg id_00007982 query',
            '13_2_side.jpg id_00007983 gallery',
            '13_3_back.jpg id_00007983 query ',
            '13_7_additional.jpg id_00007983 query',
        ]
        with open(cls.ann_file, 'w') as f:
            f.write('\n'.join(ann_rows))

    def test_initialize(self):
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Each valid split builds successfully and records its arguments.
        for split in ('train', 'query', 'gallery'):
            dataset = dataset_class(**dict(self.DEFAULT_ARGS, split=split))
            self.assertEqual(dataset.split, split)
            self.assertEqual(dataset.data_root, self.root)
            self.assertEqual(dataset.ann_file, self.ann_file)

        # An unknown split is rejected with an assertion error.
        with self.assertRaisesRegex(AssertionError, "'split' of `InS"):
            dataset_class(**dict(self.DEFAULT_ARGS, split='other'))

    def test_load_data_list(self):
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # split='train': each sample carries a per-item class label.
        dataset = dataset_class(**dict(self.DEFAULT_ARGS))
        self.assertEqual(len(dataset), 2)
        sample = dataset[0]
        self.assertEqual(sample['img_path'],
                         os.path.join(self.root, 'Img/img', '02_1_front.jpg'))
        self.assertEqual(sample['gt_label'], 0)

        # split='query': gt_label lists the matching gallery sample ids.
        dataset = dataset_class(**dict(self.DEFAULT_ARGS, split='query'))
        self.assertEqual(len(dataset), 3)
        sample = dataset[0]
        self.assertEqual(sample['img_path'],
                         os.path.join(self.root, 'Img/img', '13_1_front.jpg'))
        self.assertEqual(sample['gt_label'], [0, 1])

        # split='gallery': each sample carries its own running index.
        # ``self.root`` is absolute, so the repeated join collapses to the
        # same absolute path the dataset produces.
        dataset = dataset_class(**dict(self.DEFAULT_ARGS, split='gallery'))
        self.assertEqual(len(dataset), 3)
        sample = dataset[0]
        self.assertEqual(
            sample['img_path'],
            os.path.join(self.root, self.root, 'Img/img', '12_3_back.jpg'))
        self.assertEqual(sample['sample_idx'], 0)

    def test_extra_repr(self):
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        dataset = dataset_class(**dict(self.DEFAULT_ARGS))

        self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))

    @classmethod
    def tearDownClass(cls):
        cls.tmpdir.cleanup()
|
||||
|
|
Loading…
Reference in New Issue