support msmt_v2

pull/119/head
KaiyangZhou 2019-01-22 18:18:59 +00:00
parent 4dcc84f5c7
commit 825fce92eb
1 changed files with 32 additions and 6 deletions

View File

@ -18,6 +18,23 @@ from scipy.misc import imsave
from .bases import BaseImageDataset
# To adapt to different versions
# Log:
# 22.01.2019: v1 and v2 only differ in dir names
_TRAIN_DIR_KEY = 'train_dir'
_TEST_DIR_KEY = 'train_dir'
_VERSION = {
'MSMT17_V1': {
_TRAIN_DIR_KEY: 'train',
_TEST_DIR_KEY: 'test',
},
'MSMT17_V2': {
_TRAIN_DIR_KEY: 'mask_train_v2',
_TEST_DIR_KEY: 'mask_test_v2',
}
}
class MSMT17(BaseImageDataset):
"""
MSMT17
@ -37,12 +54,20 @@ class MSMT17(BaseImageDataset):
def __init__(self, root='data', verbose=True, **kwargs):
super(MSMT17, self).__init__(root)
self.dataset_dir = osp.join(self.root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V1/train')
self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V1/test')
self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_train.txt')
self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_val.txt')
self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_query.txt')
self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_gallery.txt')
has_main_dir = False
for main_dir in _VERSION:
if osp.exists(osp.join(self.dataset_dir, main_dir)):
train_dir = _VERSION[main_dir][TRAIN_DIR_KEY]
test_dir = _VERSION[main_dir][TEST_DIR_KEY]
has_main_dir = True
break
assert has_main_dir, "Dataset folder not found"
self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')
self._check_before_run()
train = self._process_dir(self.train_dir, self.list_train_path)
@ -50,6 +75,7 @@ class MSMT17(BaseImageDataset):
query = self._process_dir(self.test_dir, self.list_query_path)
gallery = self._process_dir(self.test_dir, self.list_gallery_path)
# To fairly compare with published methods, don't use val images for training
#train += val
#num_train_imgs += num_val_imgs