[Feature] Support InShop Dataset (Image Retrieval). (#1019)

* rebase

* feat: add inshop dataset (retrieval)

* update fileIO

* update unit tests

* fix windows ci

* fix windows ci

* fix lint

* update unit tests

* update docs

* update docs

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
pull/1314/merge
zzc98 2023-01-18 17:16:54 +08:00 committed by GitHub
parent 353886eaca
commit 7e4502b0ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 267 additions and 2 deletions

View File

@ -49,6 +49,11 @@ CUB
.. autoclass:: CUB
Retrieval
---------
.. autoclass:: InShop
Base classes
------------

View File

@ -6,6 +6,7 @@ from .cub import CUB
from .custom import CustomDataset
from .dataset_wrappers import KFoldDataset
from .imagenet import ImageNet, ImageNet21k
from .inshop import InShop
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .multi_task import MultiTaskDataset
@ -16,5 +17,5 @@ from .voc import VOC
__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset', 'InShop'
]

View File

@ -0,0 +1,160 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import get_file_backend, list_from_file
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class InShop(BaseDataset):
    """InShop Dataset for Image Retrieval.

    Please download the images from the homepage
    'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
    (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
    Eval/list_eval_partition.txt), and organize them as the following way: ::

        In-shop Clothes Retrieval Benchmark (data_root)/
           Eval /
               list_eval_partition.txt (ann_file)
           Img
               img/ (img_prefix)
           README.txt
           .....

    Args:
        data_root (str): The root directory for dataset.
        split (str): Choose from 'train', 'query' and 'gallery'.
            Defaults to 'train'.
        data_prefix (str | dict): Prefix for training data.
            Defaults to 'Img/img'.
        ann_file (str): Annotation file path, path relative to
            ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmcls.datasets import InShop
        >>>
        >>> # build train InShop dataset
        >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
        >>> inshop_train = InShop(**inshop_train_cfg)
        >>> inshop_train
        Dataset InShop
            Number of samples:  25882
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build query InShop dataset
        >>> inshop_query_cfg =  dict(data_root='data/inshop', split='query')
        >>> inshop_query = InShop(**inshop_query_cfg)
        >>> inshop_query
        Dataset InShop
            Number of samples:  14218
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build gallery InShop dataset
        >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
        >>> inshop_gallery = InShop(**inshop_gallery_cfg)
        >>> inshop_gallery
        Dataset InShop
            Number of samples:  12612
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
    """

    def __init__(self,
                 data_root: str,
                 split: str = 'train',
                 data_prefix: str = 'Img/img',
                 ann_file: str = 'Eval/list_eval_partition.txt',
                 **kwargs):
        # NOTE: the unit tests pin ``AssertionError`` for an invalid split,
        # so this stays an assert rather than a ``ValueError``.
        assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
            f" must be one of ['train', 'query', 'gallery'], but get '{split}'"
        self.backend = get_file_backend(data_root, enable_singleton=True)
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs)

    def _process_annotations(self):
        """Parse ``ann_file`` and build the annotation dict for this split.

        Returns:
            dict: ``dict(metainfo=..., data_list=...)`` for the configured
            split ('train', 'gallery' or 'query').
        """
        lines = list_from_file(self.ann_file)

        anno_train = dict(metainfo=dict(), data_list=list())
        anno_gallery = dict(metainfo=dict(), data_list=list())

        # item_id to label, each item corresponds to one class label
        class_num = 0
        gt_label_train = {}

        # item_id to label, each label corresponds to several items
        gallery_num = 0
        gt_label_gallery = {}

        # (lines[0], lines[1]) is the image number and the field name;
        # Each line format as 'image_name, item_id, evaluation_status'
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'train':
                if item_id not in gt_label_train:
                    gt_label_train[item_id] = class_num
                    class_num += 1
                # item_id to class_id (for the training set)
                anno_train['data_list'].append(
                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
            elif status == 'gallery':
                if item_id not in gt_label_gallery:
                    gt_label_gallery[item_id] = []
                # Since there are multiple images for each item,
                # record the corresponding item for each image.
                gt_label_gallery[item_id].append(gallery_num)
                anno_gallery['data_list'].append(
                    dict(img_path=img_path, sample_idx=gallery_num))
                gallery_num += 1

        if self.split == 'train':
            anno_train['metainfo']['class_number'] = class_num
            anno_train['metainfo']['sample_number'] = \
                len(anno_train['data_list'])
            return anno_train
        elif self.split == 'gallery':
            anno_gallery['metainfo']['sample_number'] = gallery_num
            return anno_gallery

        # Generate the label for the query(val) set.
        # NOTE(review): assumes every query item_id also appears in the
        # gallery split; otherwise this raises ``KeyError`` — confirm
        # against the official annotation file.
        anno_query = dict(metainfo=dict(), data_list=list())
        query_num = 0
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'query':
                anno_query['data_list'].append(
                    dict(
                        img_path=img_path, gt_label=gt_label_gallery[item_id]))
                query_num += 1

        anno_query['metainfo']['sample_number'] = query_num
        return anno_query

    def load_data_list(self):
        """Load the data list.

        For the train set, return image and ground truth label. For the query
        set, return image and ids of images in gallery. For the gallery set,
        return image and its id.

        Returns:
            list[dict]: The data information of the dataset split.
        """
        data_info = self._process_annotations()
        data_list = data_info['data_list']
        for data in data_list:
            # NOTE(review): ``img_path`` already contains ``img_prefix``
            # (joined with ``data_root`` by ``BaseDataset``); this extra join
            # relies on ``join_path`` discarding the first component when the
            # second is absolute — confirm for non-local file backends.
            data['img_path'] = self.backend.join_path(self.data_root,
                                                      data['img_path'])
        return data_list

    def extra_repr(self):
        """The extra repr information of the dataset."""
        body = [f'Root of dataset: \t{self.data_root}']
        return body

View File

@ -929,7 +929,6 @@ class TestMultiTaskDataset(TestCase):
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
self.assertIsInstance(data, list)
np.testing.assert_equal(len(data), 3)
@ -938,3 +937,103 @@ class TestMultiTaskDataset(TestCase):
'gender': 0,
'wear': [1, 0, 1, 0]
})
class TestInShop(TestBaseDataset):
    """Unit tests for the ``InShop`` retrieval dataset."""
    DATASET_TYPE = 'InShop'

    @classmethod
    def setUpClass(cls) -> None:
        """Create a temporary data root with a minimal annotation file."""
        super().setUpClass()

        tmpdir = tempfile.TemporaryDirectory()
        cls.tmpdir = tmpdir
        cls.root = tmpdir.name
        cls.list_eval_partition = 'Eval/list_eval_partition.txt'
        cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
        cls.ann_file = osp.join(cls.root, cls.list_eval_partition)
        os.makedirs(osp.join(cls.root, 'Eval'))

        # 2 train / 3 gallery / 3 query samples; the first two lines are the
        # image count and the field names and must be skipped by the dataset.
        with open(cls.ann_file, 'w') as f:
            f.write('\n'.join([
                '8',
                'image_name item_id evaluation_status',
                '02_1_front.jpg id_00000002 train',
                '02_2_side.jpg id_00000002 train',
                '12_3_back.jpg id_00007982 gallery',
                '12_7_additional.jpg id_00007982 gallery',
                '13_1_front.jpg id_00007982 query',
                '13_2_side.jpg id_00007983 gallery',
                '13_3_back.jpg id_00007983 query ',
                '13_7_additional.jpg id_00007983 query',
            ]))

    def test_initialize(self):
        """Each valid split initializes; an unknown split asserts."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test with mode=train
        cfg = {**self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.split, 'train')
        self.assertEqual(dataset.data_root, self.root)
        self.assertEqual(dataset.ann_file, self.ann_file)

        # Test with mode=query
        cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.split, 'query')
        self.assertEqual(dataset.data_root, self.root)
        self.assertEqual(dataset.ann_file, self.ann_file)

        # Test with mode=gallery
        cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
        dataset = dataset_class(**cfg)
        self.assertEqual(dataset.split, 'gallery')
        self.assertEqual(dataset.data_root, self.root)
        self.assertEqual(dataset.ann_file, self.ann_file)

        # Test with mode=other
        cfg = {**self.DEFAULT_ARGS, 'split': 'other'}
        with self.assertRaisesRegex(AssertionError, "'split' of `InS"):
            dataset_class(**cfg)

    def test_load_data_list(self):
        """Check sample counts and the first sample of every split."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)

        # Test with mode=train
        cfg = {**self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 2)
        data_info = dataset[0]
        self.assertEqual(data_info['img_path'],
                         os.path.join(self.root, 'Img/img', '02_1_front.jpg'))
        self.assertEqual(data_info['gt_label'], 0)

        # Test with mode=query
        cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 3)
        data_info = dataset[0]
        self.assertEqual(data_info['img_path'],
                         os.path.join(self.root, 'Img/img', '13_1_front.jpg'))
        self.assertEqual(data_info['gt_label'], [0, 1])

        # Test with mode=gallery
        cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
        dataset = dataset_class(**cfg)
        self.assertEqual(len(dataset), 3)
        data_info = dataset[0]
        # Fixed: the expected path previously joined ``self.root`` twice,
        # which only passed because ``os.path.join`` drops components before
        # an absolute path.
        self.assertEqual(data_info['img_path'],
                         os.path.join(self.root, 'Img/img', '12_3_back.jpg'))
        self.assertEqual(data_info['sample_idx'], 0)

    def test_extra_repr(self):
        """``repr`` should contain the dataset root."""
        dataset_class = DATASETS.get(self.DATASET_TYPE)
        cfg = {**self.DEFAULT_ARGS}
        dataset = dataset_class(**cfg)
        self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary data root."""
        cls.tmpdir.cleanup()