add init_attributes to BaseDataset
parent
e2b1001165
commit
3cbeab42c2
|
@ -46,7 +46,9 @@ class BaseDataManager(object):
|
||||||
self.num_instances = num_instances
|
self.num_instances = num_instances
|
||||||
|
|
||||||
transform_train, transform_test = build_transforms(
|
transform_train, transform_test = build_transforms(
|
||||||
self.height, self.width, random_erase=self.random_erase, color_jitter=self.color_jitter,
|
self.height, self.width,
|
||||||
|
random_erase=self.random_erase,
|
||||||
|
color_jitter=self.color_jitter,
|
||||||
color_aug=self.color_aug
|
color_aug=self.color_aug
|
||||||
)
|
)
|
||||||
self.transform_train = transform_train
|
self.transform_train = transform_train
|
||||||
|
@ -61,22 +63,15 @@ class BaseDataManager(object):
|
||||||
return self._num_train_cams
|
return self._num_train_cams
|
||||||
|
|
||||||
def return_dataloaders(self):
|
def return_dataloaders(self):
|
||||||
"""
|
"""Return trainloader and testloader dictionary"""
|
||||||
Return trainloader and testloader dictionary
|
|
||||||
"""
|
|
||||||
return self.trainloader, self.testloader_dict
|
return self.trainloader, self.testloader_dict
|
||||||
|
|
||||||
def return_testdataset_by_name(self, name):
|
def return_testdataset_by_name(self, name):
|
||||||
"""
|
"""Return query and gallery, each containing a list of (img_path, pid, camid)"""
|
||||||
Return query and gallery, each containing a list of (img_path, pid, camid).
|
|
||||||
"""
|
|
||||||
return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
|
return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
|
||||||
|
|
||||||
|
|
||||||
class ImageDataManager(BaseDataManager):
|
class ImageDataManager(BaseDataManager):
|
||||||
"""
|
|
||||||
Image-ReID data manager
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
use_gpu,
|
use_gpu,
|
||||||
|
@ -92,15 +87,19 @@ class ImageDataManager(BaseDataManager):
|
||||||
self.cuhk03_classic_split = cuhk03_classic_split
|
self.cuhk03_classic_split = cuhk03_classic_split
|
||||||
self.market1501_500k = market1501_500k
|
self.market1501_500k = market1501_500k
|
||||||
|
|
||||||
print('=> Initializing TRAIN (source) datasets')
|
print('=> Initializing train (source) datasets')
|
||||||
train = []
|
train = []
|
||||||
self._num_train_pids = 0
|
self._num_train_pids = 0
|
||||||
self._num_train_cams = 0
|
self._num_train_cams = 0
|
||||||
|
|
||||||
for name in self.source_names:
|
for name in self.source_names:
|
||||||
dataset = init_imgreid_dataset(
|
dataset = init_imgreid_dataset(
|
||||||
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
|
root=self.root,
|
||||||
cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
|
name=name,
|
||||||
|
split_id=self.split_id,
|
||||||
|
cuhk03_labeled=self.cuhk03_labeled,
|
||||||
|
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||||
|
market1501_500k=self.market1501_500k
|
||||||
)
|
)
|
||||||
|
|
||||||
for img_path, pid, camid in dataset.train:
|
for img_path, pid, camid in dataset.train:
|
||||||
|
@ -112,37 +111,52 @@ class ImageDataManager(BaseDataManager):
|
||||||
self._num_train_cams += dataset.num_train_cams
|
self._num_train_cams += dataset.num_train_cams
|
||||||
|
|
||||||
self.train_sampler = build_train_sampler(
|
self.train_sampler = build_train_sampler(
|
||||||
train, self.train_sampler,
|
train,
|
||||||
|
self.train_sampler,
|
||||||
train_batch_size=self.train_batch_size,
|
train_batch_size=self.train_batch_size,
|
||||||
num_instances=self.num_instances,
|
num_instances=self.num_instances,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.trainloader = DataLoader(
|
self.trainloader = DataLoader(
|
||||||
ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
|
ImageDataset(train, transform=self.transform_train),
|
||||||
batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
|
sampler=self.train_sampler,
|
||||||
pin_memory=self.use_gpu, drop_last=True
|
batch_size=self.train_batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=True
|
||||||
)
|
)
|
||||||
|
|
||||||
print('=> Initializing TEST (target) datasets')
|
print('=> Initializing test (target) datasets')
|
||||||
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||||
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||||
|
|
||||||
for name in self.target_names:
|
for name in self.target_names:
|
||||||
dataset = init_imgreid_dataset(
|
dataset = init_imgreid_dataset(
|
||||||
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
|
root=self.root,
|
||||||
cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
|
name=name,
|
||||||
|
split_id=self.split_id,
|
||||||
|
cuhk03_labeled=self.cuhk03_labeled,
|
||||||
|
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||||
|
market1501_500k=self.market1501_500k
|
||||||
)
|
)
|
||||||
|
|
||||||
self.testloader_dict[name]['query'] = DataLoader(
|
self.testloader_dict[name]['query'] = DataLoader(
|
||||||
ImageDataset(dataset.query, transform=self.transform_test),
|
ImageDataset(dataset.query, transform=self.transform_test),
|
||||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
batch_size=self.test_batch_size,
|
||||||
pin_memory=self.use_gpu, drop_last=False
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.testloader_dict[name]['gallery'] = DataLoader(
|
self.testloader_dict[name]['gallery'] = DataLoader(
|
||||||
ImageDataset(dataset.gallery, transform=self.transform_test),
|
ImageDataset(dataset.gallery, transform=self.transform_test),
|
||||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
batch_size=self.test_batch_size,
|
||||||
pin_memory=self.use_gpu, drop_last=False
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.testdataset_dict[name]['query'] = dataset.query
|
self.testdataset_dict[name]['query'] = dataset.query
|
||||||
|
@ -161,9 +175,6 @@ class ImageDataManager(BaseDataManager):
|
||||||
|
|
||||||
|
|
||||||
class VideoDataManager(BaseDataManager):
|
class VideoDataManager(BaseDataManager):
|
||||||
"""
|
|
||||||
Video-ReID data manager
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
use_gpu,
|
use_gpu,
|
||||||
|
@ -179,7 +190,7 @@ class VideoDataManager(BaseDataManager):
|
||||||
self.sample_method = sample_method
|
self.sample_method = sample_method
|
||||||
self.image_training = image_training
|
self.image_training = image_training
|
||||||
|
|
||||||
print('=> Initializing TRAIN (source) datasets')
|
print('=> Initializing train (source) datasets')
|
||||||
train = []
|
train = []
|
||||||
self._num_train_pids = 0
|
self._num_train_pids = 0
|
||||||
self._num_train_cams = 0
|
self._num_train_cams = 0
|
||||||
|
@ -209,21 +220,33 @@ class VideoDataManager(BaseDataManager):
|
||||||
if image_training:
|
if image_training:
|
||||||
# each batch has image data of shape (batch, channel, height, width)
|
# each batch has image data of shape (batch, channel, height, width)
|
||||||
self.trainloader = DataLoader(
|
self.trainloader = DataLoader(
|
||||||
ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
|
ImageDataset(train, transform=self.transform_train),
|
||||||
batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
|
sampler=self.train_sampler,
|
||||||
pin_memory=self.use_gpu, drop_last=True
|
batch_size=self.train_batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=True
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# each batch has image data of shape (batch, seq_len, channel, height, width)
|
# each batch has image data of shape (batch, seq_len, channel, height, width)
|
||||||
# note: this requires new training scripts
|
# note: this requires new training scripts
|
||||||
self.trainloader = DataLoader(
|
self.trainloader = DataLoader(
|
||||||
VideoDataset(train, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_train),
|
VideoDataset(
|
||||||
batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
|
train,
|
||||||
pin_memory=self.use_gpu, drop_last=True
|
seq_len=self.seq_len,
|
||||||
|
sample_method=self.sample_method,
|
||||||
|
transform=self.transform_train
|
||||||
|
),
|
||||||
|
batch_size=self.train_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=True
|
||||||
)
|
)
|
||||||
|
|
||||||
print('=> Initializing TEST (target) datasets')
|
print('=> Initializing test (target) datasets')
|
||||||
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||||
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||||
|
|
||||||
|
@ -231,15 +254,31 @@ class VideoDataManager(BaseDataManager):
|
||||||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
|
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
|
||||||
|
|
||||||
self.testloader_dict[name]['query'] = DataLoader(
|
self.testloader_dict[name]['query'] = DataLoader(
|
||||||
VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
|
VideoDataset(
|
||||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
dataset.query,
|
||||||
pin_memory=self.use_gpu, drop_last=False,
|
seq_len=self.seq_len,
|
||||||
|
sample_method=self.sample_method,
|
||||||
|
transform=self.transform_test
|
||||||
|
),
|
||||||
|
batch_size=self.test_batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.testloader_dict[name]['gallery'] = DataLoader(
|
self.testloader_dict[name]['gallery'] = DataLoader(
|
||||||
VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
|
VideoDataset(
|
||||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
dataset.gallery,
|
||||||
pin_memory=self.use_gpu, drop_last=False,
|
seq_len=self.seq_len,
|
||||||
|
sample_method=self.sample_method,
|
||||||
|
transform=self.transform_test
|
||||||
|
),
|
||||||
|
batch_size=self.test_batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.workers,
|
||||||
|
pin_memory=self.use_gpu,
|
||||||
|
drop_last=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.testdataset_dict[name]['query'] = dataset.query
|
self.testdataset_dict[name]['query'] = dataset.query
|
||||||
|
|
|
@ -31,7 +31,7 @@ __imgreid_factory = {
|
||||||
'prid450s': PRID450S,
|
'prid450s': PRID450S,
|
||||||
'ilids': iLIDS,
|
'ilids': iLIDS,
|
||||||
'sensereid': SenseReID,
|
'sensereid': SenseReID,
|
||||||
'prid': PRID,
|
'prid': PRID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,17 +39,19 @@ __vidreid_factory = {
|
||||||
'mars': Mars,
|
'mars': Mars,
|
||||||
'ilidsvid': iLIDSVID,
|
'ilidsvid': iLIDSVID,
|
||||||
'prid2011': PRID2011,
|
'prid2011': PRID2011,
|
||||||
'dukemtmcvidreid': DukeMTMCVidReID,
|
'dukemtmcvidreid': DukeMTMCVidReID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def init_imgreid_dataset(name, **kwargs):
|
def init_imgreid_dataset(name, **kwargs):
|
||||||
if name not in list(__imgreid_factory.keys()):
|
avai_datasets = list(__imgreid_factory.keys())
|
||||||
raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__imgreid_factory.keys())))
|
if name not in avai_datasets:
|
||||||
|
raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
|
||||||
return __imgreid_factory[name](**kwargs)
|
return __imgreid_factory[name](**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def init_vidreid_dataset(name, **kwargs):
|
def init_vidreid_dataset(name, **kwargs):
|
||||||
if name not in list(__vidreid_factory.keys()):
|
avai_datasets = list(__vidreid_factory.keys())
|
||||||
raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__vidreid_factory.keys())))
|
if name not in avai_datasets:
|
||||||
|
raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
|
||||||
return __vidreid_factory[name](**kwargs)
|
return __vidreid_factory[name](**kwargs)
|
|
@ -18,7 +18,59 @@ class BaseDataset(object):
|
||||||
if not osp.exists(f):
|
if not osp.exists(f):
|
||||||
raise RuntimeError('"{}" is not found'.format(f))
|
raise RuntimeError('"{}" is not found'.format(f))
|
||||||
|
|
||||||
def get_imagedata_info(self, data):
|
def extract_data_info(self, data):
|
||||||
|
"""Extract info from data list
|
||||||
|
|
||||||
|
Return: num_pids, num_imgs, num_cams, optional
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_num_pids(self, data):
|
||||||
|
return self.extract_data_info(data)[0]
|
||||||
|
|
||||||
|
def get_num_cams(self, data):
|
||||||
|
return self.extract_data_info(data)[2]
|
||||||
|
|
||||||
|
def init_attributes(self, train, query, gallery):
|
||||||
|
self._train = train
|
||||||
|
self._query = query
|
||||||
|
self._gallery = gallery
|
||||||
|
self._num_train_pids = self.get_num_pids(train)
|
||||||
|
self._num_train_cams = self.get_num_cams(train)
|
||||||
|
|
||||||
|
def print_dataset_statistics(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train(self):
|
||||||
|
# train list containing (img_path, pid, camid)
|
||||||
|
return self._train
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self):
|
||||||
|
# query list containing (img_path, pid, camid)
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gallery(self):
|
||||||
|
# gallery list containing (img_path, pid, camid)
|
||||||
|
return self._gallery
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_train_pids(self):
|
||||||
|
# number of train identities
|
||||||
|
return self._num_train_pids
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_train_cams(self):
|
||||||
|
# number of train camera views
|
||||||
|
return self._num_train_cams
|
||||||
|
|
||||||
|
|
||||||
|
class BaseImageDataset(BaseDataset):
|
||||||
|
"""Base class of image-reid dataset"""
|
||||||
|
|
||||||
|
def extract_data_info(self, data):
|
||||||
pids, cams = [], []
|
pids, cams = [], []
|
||||||
for _, pid, camid in data:
|
for _, pid, camid in data:
|
||||||
pids += [pid]
|
pids += [pid]
|
||||||
|
@ -30,32 +82,10 @@ class BaseDataset(object):
|
||||||
num_imgs = len(data)
|
num_imgs = len(data)
|
||||||
return num_pids, num_imgs, num_cams
|
return num_pids, num_imgs, num_cams
|
||||||
|
|
||||||
def get_videodata_info(self, data, return_tracklet_stats=False):
|
|
||||||
pids, cams, tracklet_stats = [], [], []
|
|
||||||
for img_paths, pid, camid in data:
|
|
||||||
pids += [pid]
|
|
||||||
cams += [camid]
|
|
||||||
tracklet_stats += [len(img_paths)]
|
|
||||||
pids = set(pids)
|
|
||||||
cams = set(cams)
|
|
||||||
num_pids = len(pids)
|
|
||||||
num_cams = len(cams)
|
|
||||||
num_tracklets = len(data)
|
|
||||||
if return_tracklet_stats:
|
|
||||||
return num_pids, num_tracklets, num_cams, tracklet_stats
|
|
||||||
return num_pids, num_tracklets, num_cams
|
|
||||||
|
|
||||||
def print_dataset_statistics(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class BaseImageDataset(BaseDataset):
|
|
||||||
"""Base class of image-reid dataset"""
|
|
||||||
|
|
||||||
def print_dataset_statistics(self, train, query, gallery):
|
def print_dataset_statistics(self, train, query, gallery):
|
||||||
num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
|
num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train)
|
||||||
num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
|
num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query)
|
||||||
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
|
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.extract_data_info(gallery)
|
||||||
|
|
||||||
print('=> Loaded {}'.format(self.__class__.__name__))
|
print('=> Loaded {}'.format(self.__class__.__name__))
|
||||||
print(' ----------------------------------------')
|
print(' ----------------------------------------')
|
||||||
|
@ -70,15 +100,30 @@ class BaseImageDataset(BaseDataset):
|
||||||
class BaseVideoDataset(BaseDataset):
|
class BaseVideoDataset(BaseDataset):
|
||||||
"""Base class of video-reid dataset"""
|
"""Base class of video-reid dataset"""
|
||||||
|
|
||||||
|
def extract_data_info(self, data, return_tracklet_stats=False):
|
||||||
|
pids, cams, tracklet_stats = [], [], []
|
||||||
|
for img_paths, pid, camid in data:
|
||||||
|
pids += [pid]
|
||||||
|
cams += [camid]
|
||||||
|
tracklet_stats += [len(img_paths)]
|
||||||
|
pids = set(pids)
|
||||||
|
cams = set(cams)
|
||||||
|
num_pids = len(pids)
|
||||||
|
num_cams = len(cams)
|
||||||
|
num_tracklets = len(data)
|
||||||
|
if return_tracklet_stats:
|
||||||
|
return num_pids, num_tracklets, num_cams, tracklet_stats
|
||||||
|
return num_pids, num_tracklets, num_cams
|
||||||
|
|
||||||
def print_dataset_statistics(self, train, query, gallery):
|
def print_dataset_statistics(self, train, query, gallery):
|
||||||
num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
|
num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
|
||||||
self.get_videodata_info(train, return_tracklet_stats=True)
|
self.extract_data_info(train, return_tracklet_stats=True)
|
||||||
|
|
||||||
num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
|
num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
|
||||||
self.get_videodata_info(query, return_tracklet_stats=True)
|
self.extract_data_info(query, return_tracklet_stats=True)
|
||||||
|
|
||||||
num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
|
num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
|
||||||
self.get_videodata_info(gallery, return_tracklet_stats=True)
|
self.extract_data_info(gallery, return_tracklet_stats=True)
|
||||||
|
|
||||||
tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
|
tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
|
||||||
min_num = np.min(tracklet_stats)
|
min_num = np.min(tracklet_stats)
|
||||||
|
|
|
@ -63,17 +63,11 @@ class CUHK01(BaseImageDataset):
|
||||||
query = [tuple(item) for item in query]
|
query = [tuple(item) for item in query]
|
||||||
gallery = [tuple(item) for item in gallery]
|
gallery = [tuple(item) for item in gallery]
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def extract_file(self):
|
def extract_file(self):
|
||||||
if not osp.exists(self.campus_dir):
|
if not osp.exists(self.campus_dir):
|
||||||
print('Extracting files')
|
print('Extracting files')
|
||||||
|
|
|
@ -81,17 +81,11 @@ class CUHK03(BaseImageDataset):
|
||||||
query = split['query']
|
query = split['query']
|
||||||
gallery = split['gallery']
|
gallery = split['gallery']
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def preprocess_split(self):
|
def preprocess_split(self):
|
||||||
"""
|
"""
|
||||||
This function is a bit complex and ugly, what it does is
|
This function is a bit complex and ugly, what it does is
|
||||||
|
|
|
@ -57,17 +57,11 @@ class DukeMTMCreID(BaseImageDataset):
|
||||||
query = self.process_dir(self.query_dir, relabel=False)
|
query = self.process_dir(self.query_dir, relabel=False)
|
||||||
gallery = self.process_dir(self.gallery_dir, relabel=False)
|
gallery = self.process_dir(self.gallery_dir, relabel=False)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -60,17 +60,11 @@ class DukeMTMCVidReID(BaseVideoDataset):
|
||||||
query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
|
query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
|
||||||
gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
|
gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
|
||||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
|
||||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -67,17 +67,11 @@ class GRID(BaseImageDataset):
|
||||||
query = [tuple(item) for item in query]
|
query = [tuple(item) for item in query]
|
||||||
gallery = [tuple(item) for item in gallery]
|
gallery = [tuple(item) for item in gallery]
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -58,17 +58,11 @@ class iLIDS(BaseImageDataset):
|
||||||
|
|
||||||
train, query, gallery = self.process_split(split)
|
train, query, gallery = self.process_split(split)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -65,17 +65,11 @@ class iLIDSVID(BaseVideoDataset):
|
||||||
query = self.process_data(test_dirs, cam1=True, cam2=False)
|
query = self.process_data(test_dirs, cam1=True, cam2=False)
|
||||||
gallery = self.process_data(test_dirs, cam1=False, cam2=True)
|
gallery = self.process_data(test_dirs, cam1=False, cam2=True)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
|
||||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
|
||||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -57,17 +57,11 @@ class Market1501(BaseImageDataset):
|
||||||
if self.market1501_500k:
|
if self.market1501_500k:
|
||||||
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
|
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def process_dir(self, dir_path, relabel=False):
|
def process_dir(self, dir_path, relabel=False):
|
||||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||||
pattern = re.compile(r'([-\d]+)_c(\d)')
|
pattern = re.compile(r'([-\d]+)_c(\d)')
|
||||||
|
|
|
@ -66,17 +66,11 @@ class Mars(BaseVideoDataset):
|
||||||
query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||||
gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
|
||||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
|
||||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
|
||||||
|
|
||||||
def get_names(self, fpath):
|
def get_names(self, fpath):
|
||||||
names = []
|
names = []
|
||||||
with open(fpath, 'r') as f:
|
with open(fpath, 'r') as f:
|
||||||
|
|
|
@ -84,17 +84,10 @@ class MSMT17(BaseImageDataset):
|
||||||
#train += val
|
#train += val
|
||||||
#num_train_imgs += num_val_imgs
|
#num_train_imgs += num_val_imgs
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def process_dir(self, dir_path, list_path):
|
def process_dir(self, dir_path, list_path):
|
||||||
with open(list_path, 'r') as txt:
|
with open(list_path, 'r') as txt:
|
||||||
lines = txt.readlines()
|
lines = txt.readlines()
|
||||||
|
|
|
@ -60,17 +60,11 @@ class PRID(BaseImageDataset):
|
||||||
|
|
||||||
train, query, gallery = self.process_split(split)
|
train, query, gallery = self.process_split(split)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def prepare_split(self):
|
def prepare_split(self):
|
||||||
if not osp.exists(self.split_path):
|
if not osp.exists(self.split_path):
|
||||||
print('Creating splits ...')
|
print('Creating splits ...')
|
||||||
|
|
|
@ -58,17 +58,11 @@ class PRID2011(BaseVideoDataset):
|
||||||
query = self.process_dir(test_dirs, cam1=True, cam2=False)
|
query = self.process_dir(test_dirs, cam1=True, cam2=False)
|
||||||
gallery = self.process_dir(test_dirs, cam1=False, cam2=True)
|
gallery = self.process_dir(test_dirs, cam1=False, cam2=True)
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
|
||||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
|
||||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
|
||||||
|
|
||||||
def process_dir(self, dirnames, cam1=True, cam2=True):
|
def process_dir(self, dirnames, cam1=True, cam2=True):
|
||||||
tracklets = []
|
tracklets = []
|
||||||
dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
|
dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
|
||||||
|
|
|
@ -65,17 +65,11 @@ class PRID450S(BaseImageDataset):
|
||||||
query = [tuple(item) for item in query]
|
query = [tuple(item) for item in query]
|
||||||
gallery = [tuple(item) for item in gallery]
|
gallery = [tuple(item) for item in gallery]
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
|
@ -52,17 +52,12 @@ class SenseReID(BaseImageDataset):
|
||||||
|
|
||||||
query = self.process_dir(self.query_dir)
|
query = self.process_dir(self.query_dir)
|
||||||
gallery = self.process_dir(self.gallery_dir)
|
gallery = self.process_dir(self.gallery_dir)
|
||||||
|
train = copy.deepcopy(query) # dummy variable
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(query, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = copy.deepcopy(query) # only used to initialize trainloader
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def process_dir(self, dir_path):
|
def process_dir(self, dir_path):
|
||||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||||
|
|
|
@ -64,17 +64,11 @@ class VIPeR(BaseImageDataset):
|
||||||
query = [tuple(item) for item in query]
|
query = [tuple(item) for item in query]
|
||||||
gallery = [tuple(item) for item in gallery]
|
gallery = [tuple(item) for item in gallery]
|
||||||
|
|
||||||
|
self.init_attributes(train, query, gallery)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
self.print_dataset_statistics(train, query, gallery)
|
self.print_dataset_statistics(train, query, gallery)
|
||||||
|
|
||||||
self.train = train
|
|
||||||
self.query = query
|
|
||||||
self.gallery = gallery
|
|
||||||
|
|
||||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
|
||||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
|
||||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
|
||||||
|
|
||||||
def download_data(self):
|
def download_data(self):
|
||||||
if osp.exists(self.dataset_dir):
|
if osp.exists(self.dataset_dir):
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue