[Feature] Support InShop Dataset (Image Retrieval). (#1019)
* rebase * feat: add inshop dataset (retrieval) * update fileIO * update unit tests * fix windows ci * fix windows ci * fix lint * update unit tests * update docs * update docs Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>pull/1314/merge
parent
353886eaca
commit
7e4502b0ac
|
@ -49,6 +49,11 @@ CUB
|
||||||
|
|
||||||
.. autoclass:: CUB
|
.. autoclass:: CUB
|
||||||
|
|
||||||
|
Retrieval
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. autoclass:: InShop
|
||||||
|
|
||||||
Base classes
|
Base classes
|
||||||
------------
|
------------
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from .cub import CUB
|
||||||
from .custom import CustomDataset
|
from .custom import CustomDataset
|
||||||
from .dataset_wrappers import KFoldDataset
|
from .dataset_wrappers import KFoldDataset
|
||||||
from .imagenet import ImageNet, ImageNet21k
|
from .imagenet import ImageNet, ImageNet21k
|
||||||
|
from .inshop import InShop
|
||||||
from .mnist import MNIST, FashionMNIST
|
from .mnist import MNIST, FashionMNIST
|
||||||
from .multi_label import MultiLabelDataset
|
from .multi_label import MultiLabelDataset
|
||||||
from .multi_task import MultiTaskDataset
|
from .multi_task import MultiTaskDataset
|
||||||
|
@ -16,5 +17,5 @@ from .voc import VOC
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
|
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
|
||||||
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
|
'VOC', 'build_dataset', 'ImageNet21k', 'KFoldDataset', 'CUB',
|
||||||
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset'
|
'CustomDataset', 'MultiLabelDataset', 'MultiTaskDataset', 'InShop'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmengine import get_file_backend, list_from_file
|
||||||
|
|
||||||
|
from mmcls.registry import DATASETS
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class InShop(BaseDataset):
|
||||||
|
"""InShop Dataset for Image Retrieval.
|
||||||
|
|
||||||
|
Please download the images from the homepage
|
||||||
|
'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
|
||||||
|
(In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
|
||||||
|
Eval/list_eval_partition.txt), and organize them as follows way: ::
|
||||||
|
|
||||||
|
In-shop Clothes Retrieval Benchmark (data_root)/
|
||||||
|
├── Eval /
|
||||||
|
│ └── list_eval_partition.txt (ann_file)
|
||||||
|
├── Img
|
||||||
|
│ └── img/ (img_prefix)
|
||||||
|
├── README.txt
|
||||||
|
└── .....
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root (str): The root directory for dataset.
|
||||||
|
split (str): Choose from 'train', 'query' and 'gallery'.
|
||||||
|
Defaults to 'train'.
|
||||||
|
data_prefix (str | dict): Prefix for training data.
|
||||||
|
Defaults to 'Img/img'.
|
||||||
|
ann_file (str): Annotation file path, path relative to
|
||||||
|
``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
|
||||||
|
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mmcls.datasets import InShop
|
||||||
|
>>>
|
||||||
|
>>> # build train InShop dataset
|
||||||
|
>>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
|
||||||
|
>>> inshop_train = InShop(**inshop_train_cfg)
|
||||||
|
>>> inshop_train
|
||||||
|
Dataset InShop
|
||||||
|
Number of samples: 25882
|
||||||
|
The `CLASSES` meta info is not set.
|
||||||
|
Root of dataset: data/inshop
|
||||||
|
>>>
|
||||||
|
>>> # build query InShop dataset
|
||||||
|
>>> inshop_query_cfg = dict(data_root='data/inshop', split='query')
|
||||||
|
>>> inshop_query = InShop(**inshop_query_cfg)
|
||||||
|
>>> inshop_query
|
||||||
|
Dataset InShop
|
||||||
|
Number of samples: 14218
|
||||||
|
The `CLASSES` meta info is not set.
|
||||||
|
Root of dataset: data/inshop
|
||||||
|
>>>
|
||||||
|
>>> # build gallery InShop dataset
|
||||||
|
>>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
|
||||||
|
>>> inshop_gallery = InShop(**inshop_gallery_cfg)
|
||||||
|
>>> inshop_gallery
|
||||||
|
Dataset InShop
|
||||||
|
Number of samples: 12612
|
||||||
|
The `CLASSES` meta info is not set.
|
||||||
|
Root of dataset: data/inshop
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
data_root: str,
|
||||||
|
split: str = 'train',
|
||||||
|
data_prefix: str = 'Img/img',
|
||||||
|
ann_file: str = 'Eval/list_eval_partition.txt',
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
|
||||||
|
f" must be one of ['train', 'query', 'gallery'], bu get '{split}'"
|
||||||
|
self.backend = get_file_backend(data_root, enable_singleton=True)
|
||||||
|
self.split = split
|
||||||
|
super().__init__(
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=data_prefix,
|
||||||
|
ann_file=ann_file,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def _process_annotations(self):
|
||||||
|
lines = list_from_file(self.ann_file)
|
||||||
|
|
||||||
|
anno_train = dict(metainfo=dict(), data_list=list())
|
||||||
|
anno_gallery = dict(metainfo=dict(), data_list=list())
|
||||||
|
|
||||||
|
# item_id to label, each item corresponds to one class label
|
||||||
|
class_num = 0
|
||||||
|
gt_label_train = {}
|
||||||
|
|
||||||
|
# item_id to label, each label corresponds to several items
|
||||||
|
gallery_num = 0
|
||||||
|
gt_label_gallery = {}
|
||||||
|
|
||||||
|
# (lines[0], lines[1]) is the image number and the field name;
|
||||||
|
# Each line format as 'image_name, item_id, evaluation_status'
|
||||||
|
for line in lines[2:]:
|
||||||
|
img_name, item_id, status = line.split()
|
||||||
|
img_path = self.backend.join_path(self.img_prefix, img_name)
|
||||||
|
if status == 'train':
|
||||||
|
if item_id not in gt_label_train:
|
||||||
|
gt_label_train[item_id] = class_num
|
||||||
|
class_num += 1
|
||||||
|
# item_id to class_id (for the training set)
|
||||||
|
anno_train['data_list'].append(
|
||||||
|
dict(img_path=img_path, gt_label=gt_label_train[item_id]))
|
||||||
|
elif status == 'gallery':
|
||||||
|
if item_id not in gt_label_gallery:
|
||||||
|
gt_label_gallery[item_id] = []
|
||||||
|
# Since there are multiple images for each item,
|
||||||
|
# record the corresponding item for each image.
|
||||||
|
gt_label_gallery[item_id].append(gallery_num)
|
||||||
|
anno_gallery['data_list'].append(
|
||||||
|
dict(img_path=img_path, sample_idx=gallery_num))
|
||||||
|
gallery_num += 1
|
||||||
|
|
||||||
|
if self.split == 'train':
|
||||||
|
anno_train['metainfo']['class_number'] = class_num
|
||||||
|
anno_train['metainfo']['sample_number'] = \
|
||||||
|
len(anno_train['data_list'])
|
||||||
|
return anno_train
|
||||||
|
elif self.split == 'gallery':
|
||||||
|
anno_gallery['metainfo']['sample_number'] = gallery_num
|
||||||
|
return anno_gallery
|
||||||
|
|
||||||
|
# Generate the label for the query(val) set
|
||||||
|
anno_query = dict(metainfo=dict(), data_list=list())
|
||||||
|
query_num = 0
|
||||||
|
for line in lines[2:]:
|
||||||
|
img_name, item_id, status = line.split()
|
||||||
|
img_path = self.backend.join_path(self.img_prefix, img_name)
|
||||||
|
if status == 'query':
|
||||||
|
anno_query['data_list'].append(
|
||||||
|
dict(
|
||||||
|
img_path=img_path, gt_label=gt_label_gallery[item_id]))
|
||||||
|
query_num += 1
|
||||||
|
|
||||||
|
anno_query['metainfo']['sample_number'] = query_num
|
||||||
|
return anno_query
|
||||||
|
|
||||||
|
def load_data_list(self):
|
||||||
|
"""load data list.
|
||||||
|
|
||||||
|
For the train set, return image and ground truth label. For the query
|
||||||
|
set, return image and ids of images in gallery. For the gallery set,
|
||||||
|
return image and its id.
|
||||||
|
"""
|
||||||
|
data_info = self._process_annotations()
|
||||||
|
data_list = data_info['data_list']
|
||||||
|
for data in data_list:
|
||||||
|
data['img_path'] = self.backend.join_path(self.data_root,
|
||||||
|
data['img_path'])
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
"""The extra repr information of the dataset."""
|
||||||
|
body = [f'Root of dataset: \t{self.data_root}']
|
||||||
|
return body
|
|
@ -929,7 +929,6 @@ class TestMultiTaskDataset(TestCase):
|
||||||
|
|
||||||
# Test default behavior
|
# Test default behavior
|
||||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||||
|
|
||||||
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
|
data = dataset.load_data_list(self.DEFAULT_ARGS['ann_file'])
|
||||||
self.assertIsInstance(data, list)
|
self.assertIsInstance(data, list)
|
||||||
np.testing.assert_equal(len(data), 3)
|
np.testing.assert_equal(len(data), 3)
|
||||||
|
@ -938,3 +937,103 @@ class TestMultiTaskDataset(TestCase):
|
||||||
'gender': 0,
|
'gender': 0,
|
||||||
'wear': [1, 0, 1, 0]
|
'wear': [1, 0, 1, 0]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class TestInShop(TestBaseDataset):
|
||||||
|
DATASET_TYPE = 'InShop'
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls) -> None:
|
||||||
|
super().setUpClass()
|
||||||
|
|
||||||
|
tmpdir = tempfile.TemporaryDirectory()
|
||||||
|
cls.tmpdir = tmpdir
|
||||||
|
cls.root = tmpdir.name
|
||||||
|
cls.list_eval_partition = 'Eval/list_eval_partition.txt'
|
||||||
|
cls.DEFAULT_ARGS = dict(data_root=cls.root, split='train')
|
||||||
|
cls.ann_file = osp.join(cls.root, cls.list_eval_partition)
|
||||||
|
os.makedirs(osp.join(cls.root, 'Eval'))
|
||||||
|
with open(cls.ann_file, 'w') as f:
|
||||||
|
f.write('\n'.join([
|
||||||
|
'8',
|
||||||
|
'image_name item_id evaluation_status',
|
||||||
|
'02_1_front.jpg id_00000002 train',
|
||||||
|
'02_2_side.jpg id_00000002 train',
|
||||||
|
'12_3_back.jpg id_00007982 gallery',
|
||||||
|
'12_7_additional.jpg id_00007982 gallery',
|
||||||
|
'13_1_front.jpg id_00007982 query',
|
||||||
|
'13_2_side.jpg id_00007983 gallery',
|
||||||
|
'13_3_back.jpg id_00007983 query ',
|
||||||
|
'13_7_additional.jpg id_00007983 query',
|
||||||
|
]))
|
||||||
|
|
||||||
|
def test_initialize(self):
|
||||||
|
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||||
|
|
||||||
|
# Test with mode=train
|
||||||
|
cfg = {**self.DEFAULT_ARGS}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(dataset.split, 'train')
|
||||||
|
self.assertEqual(dataset.data_root, self.root)
|
||||||
|
self.assertEqual(dataset.ann_file, self.ann_file)
|
||||||
|
|
||||||
|
# Test with mode=query
|
||||||
|
cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(dataset.split, 'query')
|
||||||
|
self.assertEqual(dataset.data_root, self.root)
|
||||||
|
self.assertEqual(dataset.ann_file, self.ann_file)
|
||||||
|
|
||||||
|
# Test with mode=gallery
|
||||||
|
cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(dataset.split, 'gallery')
|
||||||
|
self.assertEqual(dataset.data_root, self.root)
|
||||||
|
self.assertEqual(dataset.ann_file, self.ann_file)
|
||||||
|
|
||||||
|
# Test with mode=other
|
||||||
|
cfg = {**self.DEFAULT_ARGS, 'split': 'other'}
|
||||||
|
with self.assertRaisesRegex(AssertionError, "'split' of `InS"):
|
||||||
|
dataset_class(**cfg)
|
||||||
|
|
||||||
|
def test_load_data_list(self):
|
||||||
|
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||||
|
|
||||||
|
# Test with mode=train
|
||||||
|
cfg = {**self.DEFAULT_ARGS}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(len(dataset), 2)
|
||||||
|
data_info = dataset[0]
|
||||||
|
self.assertEqual(data_info['img_path'],
|
||||||
|
os.path.join(self.root, 'Img/img', '02_1_front.jpg'))
|
||||||
|
self.assertEqual(data_info['gt_label'], 0)
|
||||||
|
|
||||||
|
# Test with mode=query
|
||||||
|
cfg = {**self.DEFAULT_ARGS, 'split': 'query'}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(len(dataset), 3)
|
||||||
|
data_info = dataset[0]
|
||||||
|
self.assertEqual(data_info['img_path'],
|
||||||
|
os.path.join(self.root, 'Img/img', '13_1_front.jpg'))
|
||||||
|
self.assertEqual(data_info['gt_label'], [0, 1])
|
||||||
|
|
||||||
|
# Test with mode=gallery
|
||||||
|
cfg = {**self.DEFAULT_ARGS, 'split': 'gallery'}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
self.assertEqual(len(dataset), 3)
|
||||||
|
data_info = dataset[0]
|
||||||
|
self.assertEqual(
|
||||||
|
data_info['img_path'],
|
||||||
|
os.path.join(self.root, self.root, 'Img/img', '12_3_back.jpg'))
|
||||||
|
self.assertEqual(data_info['sample_idx'], 0)
|
||||||
|
|
||||||
|
def test_extra_repr(self):
|
||||||
|
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||||
|
cfg = {**self.DEFAULT_ARGS}
|
||||||
|
dataset = dataset_class(**cfg)
|
||||||
|
|
||||||
|
self.assertIn(f'Root of dataset: \t{dataset.data_root}', repr(dataset))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
cls.tmpdir.cleanup()
|
||||||
|
|
Loading…
Reference in New Issue