add init_attributes to BaseDataset
parent
e2b1001165
commit
3cbeab42c2
|
@ -46,7 +46,9 @@ class BaseDataManager(object):
|
|||
self.num_instances = num_instances
|
||||
|
||||
transform_train, transform_test = build_transforms(
|
||||
self.height, self.width, random_erase=self.random_erase, color_jitter=self.color_jitter,
|
||||
self.height, self.width,
|
||||
random_erase=self.random_erase,
|
||||
color_jitter=self.color_jitter,
|
||||
color_aug=self.color_aug
|
||||
)
|
||||
self.transform_train = transform_train
|
||||
|
@ -61,22 +63,15 @@ class BaseDataManager(object):
|
|||
return self._num_train_cams
|
||||
|
||||
def return_dataloaders(self):
|
||||
"""
|
||||
Return trainloader and testloader dictionary
|
||||
"""
|
||||
"""Return trainloader and testloader dictionary"""
|
||||
return self.trainloader, self.testloader_dict
|
||||
|
||||
def return_testdataset_by_name(self, name):
|
||||
"""
|
||||
Return query and gallery, each containing a list of (img_path, pid, camid).
|
||||
"""
|
||||
"""Return query and gallery, each containing a list of (img_path, pid, camid)"""
|
||||
return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
|
||||
|
||||
|
||||
class ImageDataManager(BaseDataManager):
|
||||
"""
|
||||
Image-ReID data manager
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_gpu,
|
||||
|
@ -92,15 +87,19 @@ class ImageDataManager(BaseDataManager):
|
|||
self.cuhk03_classic_split = cuhk03_classic_split
|
||||
self.market1501_500k = market1501_500k
|
||||
|
||||
print('=> Initializing TRAIN (source) datasets')
|
||||
print('=> Initializing train (source) datasets')
|
||||
train = []
|
||||
self._num_train_pids = 0
|
||||
self._num_train_cams = 0
|
||||
|
||||
for name in self.source_names:
|
||||
dataset = init_imgreid_dataset(
|
||||
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
|
||||
root=self.root,
|
||||
name=name,
|
||||
split_id=self.split_id,
|
||||
cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||
market1501_500k=self.market1501_500k
|
||||
)
|
||||
|
||||
for img_path, pid, camid in dataset.train:
|
||||
|
@ -112,37 +111,52 @@ class ImageDataManager(BaseDataManager):
|
|||
self._num_train_cams += dataset.num_train_cams
|
||||
|
||||
self.train_sampler = build_train_sampler(
|
||||
train, self.train_sampler,
|
||||
train,
|
||||
self.train_sampler,
|
||||
train_batch_size=self.train_batch_size,
|
||||
num_instances=self.num_instances,
|
||||
)
|
||||
|
||||
self.trainloader = DataLoader(
|
||||
ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
|
||||
batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=True
|
||||
ImageDataset(train, transform=self.transform_train),
|
||||
sampler=self.train_sampler,
|
||||
batch_size=self.train_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
print('=> Initializing TEST (target) datasets')
|
||||
print('=> Initializing test (target) datasets')
|
||||
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||
|
||||
for name in self.target_names:
|
||||
dataset = init_imgreid_dataset(
|
||||
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
|
||||
root=self.root,
|
||||
name=name,
|
||||
split_id=self.split_id,
|
||||
cuhk03_labeled=self.cuhk03_labeled,
|
||||
cuhk03_classic_split=self.cuhk03_classic_split,
|
||||
market1501_500k=self.market1501_500k
|
||||
)
|
||||
|
||||
self.testloader_dict[name]['query'] = DataLoader(
|
||||
ImageDataset(dataset.query, transform=self.transform_test),
|
||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=False
|
||||
batch_size=self.test_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
self.testloader_dict[name]['gallery'] = DataLoader(
|
||||
ImageDataset(dataset.gallery, transform=self.transform_test),
|
||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=False
|
||||
batch_size=self.test_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
self.testdataset_dict[name]['query'] = dataset.query
|
||||
|
@ -161,9 +175,6 @@ class ImageDataManager(BaseDataManager):
|
|||
|
||||
|
||||
class VideoDataManager(BaseDataManager):
|
||||
"""
|
||||
Video-ReID data manager
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_gpu,
|
||||
|
@ -179,7 +190,7 @@ class VideoDataManager(BaseDataManager):
|
|||
self.sample_method = sample_method
|
||||
self.image_training = image_training
|
||||
|
||||
print('=> Initializing TRAIN (source) datasets')
|
||||
print('=> Initializing train (source) datasets')
|
||||
train = []
|
||||
self._num_train_pids = 0
|
||||
self._num_train_cams = 0
|
||||
|
@ -209,21 +220,33 @@ class VideoDataManager(BaseDataManager):
|
|||
if image_training:
|
||||
# each batch has image data of shape (batch, channel, height, width)
|
||||
self.trainloader = DataLoader(
|
||||
ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
|
||||
batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=True
|
||||
ImageDataset(train, transform=self.transform_train),
|
||||
sampler=self.train_sampler,
|
||||
batch_size=self.train_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
else:
|
||||
# each batch has image data of shape (batch, seq_len, channel, height, width)
|
||||
# note: this requires new training scripts
|
||||
self.trainloader = DataLoader(
|
||||
VideoDataset(train, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_train),
|
||||
batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=True
|
||||
VideoDataset(
|
||||
train,
|
||||
seq_len=self.seq_len,
|
||||
sample_method=self.sample_method,
|
||||
transform=self.transform_train
|
||||
),
|
||||
batch_size=self.train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
print('=> Initializing TEST (target) datasets')
|
||||
print('=> Initializing test (target) datasets')
|
||||
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
|
||||
|
||||
|
@ -231,15 +254,31 @@ class VideoDataManager(BaseDataManager):
|
|||
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
|
||||
|
||||
self.testloader_dict[name]['query'] = DataLoader(
|
||||
VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
|
||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=False,
|
||||
VideoDataset(
|
||||
dataset.query,
|
||||
seq_len=self.seq_len,
|
||||
sample_method=self.sample_method,
|
||||
transform=self.transform_test
|
||||
),
|
||||
batch_size=self.test_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
self.testloader_dict[name]['gallery'] = DataLoader(
|
||||
VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
|
||||
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
|
||||
pin_memory=self.use_gpu, drop_last=False,
|
||||
VideoDataset(
|
||||
dataset.gallery,
|
||||
seq_len=self.seq_len,
|
||||
sample_method=self.sample_method,
|
||||
transform=self.transform_test
|
||||
),
|
||||
batch_size=self.test_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.workers,
|
||||
pin_memory=self.use_gpu,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
self.testdataset_dict[name]['query'] = dataset.query
|
||||
|
|
|
@ -31,7 +31,7 @@ __imgreid_factory = {
|
|||
'prid450s': PRID450S,
|
||||
'ilids': iLIDS,
|
||||
'sensereid': SenseReID,
|
||||
'prid': PRID,
|
||||
'prid': PRID
|
||||
}
|
||||
|
||||
|
||||
|
@ -39,17 +39,19 @@ __vidreid_factory = {
|
|||
'mars': Mars,
|
||||
'ilidsvid': iLIDSVID,
|
||||
'prid2011': PRID2011,
|
||||
'dukemtmcvidreid': DukeMTMCVidReID,
|
||||
'dukemtmcvidreid': DukeMTMCVidReID
|
||||
}
|
||||
|
||||
|
||||
def init_imgreid_dataset(name, **kwargs):
|
||||
if name not in list(__imgreid_factory.keys()):
|
||||
raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__imgreid_factory.keys())))
|
||||
avai_datasets = list(__imgreid_factory.keys())
|
||||
if name not in avai_datasets:
|
||||
raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
|
||||
return __imgreid_factory[name](**kwargs)
|
||||
|
||||
|
||||
def init_vidreid_dataset(name, **kwargs):
|
||||
if name not in list(__vidreid_factory.keys()):
|
||||
raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__vidreid_factory.keys())))
|
||||
avai_datasets = list(__vidreid_factory.keys())
|
||||
if name not in avai_datasets:
|
||||
raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets))
|
||||
return __vidreid_factory[name](**kwargs)
|
|
@ -18,7 +18,59 @@ class BaseDataset(object):
|
|||
if not osp.exists(f):
|
||||
raise RuntimeError('"{}" is not found'.format(f))
|
||||
|
||||
def get_imagedata_info(self, data):
|
||||
def extract_data_info(self, data):
|
||||
"""Extract info from data list
|
||||
|
||||
Return: num_pids, num_imgs, num_cams, optional
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_num_pids(self, data):
|
||||
return self.extract_data_info(data)[0]
|
||||
|
||||
def get_num_cams(self, data):
|
||||
return self.extract_data_info(data)[2]
|
||||
|
||||
def init_attributes(self, train, query, gallery):
|
||||
self._train = train
|
||||
self._query = query
|
||||
self._gallery = gallery
|
||||
self._num_train_pids = self.get_num_pids(train)
|
||||
self._num_train_cams = self.get_num_cams(train)
|
||||
|
||||
def print_dataset_statistics(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
# train list containing (img_path, pid, camid)
|
||||
return self._train
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
# query list containing (img_path, pid, camid)
|
||||
return self._query
|
||||
|
||||
@property
|
||||
def gallery(self):
|
||||
# gallery list containing (img_path, pid, camid)
|
||||
return self._gallery
|
||||
|
||||
@property
|
||||
def num_train_pids(self):
|
||||
# number of train identities
|
||||
return self._num_train_pids
|
||||
|
||||
@property
|
||||
def num_train_cams(self):
|
||||
# number of train camera views
|
||||
return self._num_train_cams
|
||||
|
||||
|
||||
class BaseImageDataset(BaseDataset):
|
||||
"""Base class of image-reid dataset"""
|
||||
|
||||
def extract_data_info(self, data):
|
||||
pids, cams = [], []
|
||||
for _, pid, camid in data:
|
||||
pids += [pid]
|
||||
|
@ -30,32 +82,10 @@ class BaseDataset(object):
|
|||
num_imgs = len(data)
|
||||
return num_pids, num_imgs, num_cams
|
||||
|
||||
def get_videodata_info(self, data, return_tracklet_stats=False):
|
||||
pids, cams, tracklet_stats = [], [], []
|
||||
for img_paths, pid, camid in data:
|
||||
pids += [pid]
|
||||
cams += [camid]
|
||||
tracklet_stats += [len(img_paths)]
|
||||
pids = set(pids)
|
||||
cams = set(cams)
|
||||
num_pids = len(pids)
|
||||
num_cams = len(cams)
|
||||
num_tracklets = len(data)
|
||||
if return_tracklet_stats:
|
||||
return num_pids, num_tracklets, num_cams, tracklet_stats
|
||||
return num_pids, num_tracklets, num_cams
|
||||
|
||||
def print_dataset_statistics(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseImageDataset(BaseDataset):
|
||||
"""Base class of image-reid dataset"""
|
||||
|
||||
def print_dataset_statistics(self, train, query, gallery):
|
||||
num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
|
||||
num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
|
||||
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
|
||||
num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train)
|
||||
num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query)
|
||||
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.extract_data_info(gallery)
|
||||
|
||||
print('=> Loaded {}'.format(self.__class__.__name__))
|
||||
print(' ----------------------------------------')
|
||||
|
@ -70,15 +100,30 @@ class BaseImageDataset(BaseDataset):
|
|||
class BaseVideoDataset(BaseDataset):
|
||||
"""Base class of video-reid dataset"""
|
||||
|
||||
def extract_data_info(self, data, return_tracklet_stats=False):
|
||||
pids, cams, tracklet_stats = [], [], []
|
||||
for img_paths, pid, camid in data:
|
||||
pids += [pid]
|
||||
cams += [camid]
|
||||
tracklet_stats += [len(img_paths)]
|
||||
pids = set(pids)
|
||||
cams = set(cams)
|
||||
num_pids = len(pids)
|
||||
num_cams = len(cams)
|
||||
num_tracklets = len(data)
|
||||
if return_tracklet_stats:
|
||||
return num_pids, num_tracklets, num_cams, tracklet_stats
|
||||
return num_pids, num_tracklets, num_cams
|
||||
|
||||
def print_dataset_statistics(self, train, query, gallery):
|
||||
num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
|
||||
self.get_videodata_info(train, return_tracklet_stats=True)
|
||||
self.extract_data_info(train, return_tracklet_stats=True)
|
||||
|
||||
num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
|
||||
self.get_videodata_info(query, return_tracklet_stats=True)
|
||||
self.extract_data_info(query, return_tracklet_stats=True)
|
||||
|
||||
num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
|
||||
self.get_videodata_info(gallery, return_tracklet_stats=True)
|
||||
self.extract_data_info(gallery, return_tracklet_stats=True)
|
||||
|
||||
tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
|
||||
min_num = np.min(tracklet_stats)
|
||||
|
|
|
@ -63,17 +63,11 @@ class CUHK01(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def extract_file(self):
|
||||
if not osp.exists(self.campus_dir):
|
||||
print('Extracting files')
|
||||
|
|
|
@ -81,17 +81,11 @@ class CUHK03(BaseImageDataset):
|
|||
query = split['query']
|
||||
gallery = split['gallery']
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def preprocess_split(self):
|
||||
"""
|
||||
This function is a bit complex and ugly, what it does is
|
||||
|
|
|
@ -57,17 +57,11 @@ class DukeMTMCreID(BaseImageDataset):
|
|||
query = self.process_dir(self.query_dir, relabel=False)
|
||||
gallery = self.process_dir(self.gallery_dir, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -60,17 +60,11 @@ class DukeMTMCVidReID(BaseVideoDataset):
|
|||
query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False)
|
||||
gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -67,17 +67,11 @@ class GRID(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -58,17 +58,11 @@ class iLIDS(BaseImageDataset):
|
|||
|
||||
train, query, gallery = self.process_split(split)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -65,17 +65,11 @@ class iLIDSVID(BaseVideoDataset):
|
|||
query = self.process_data(test_dirs, cam1=True, cam2=False)
|
||||
gallery = self.process_data(test_dirs, cam1=False, cam2=True)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -57,17 +57,11 @@ class Market1501(BaseImageDataset):
|
|||
if self.market1501_500k:
|
||||
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def process_dir(self, dir_path, relabel=False):
|
||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||
pattern = re.compile(r'([-\d]+)_c(\d)')
|
||||
|
|
|
@ -66,17 +66,11 @@ class Mars(BaseVideoDataset):
|
|||
query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||
gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
||||
|
||||
def get_names(self, fpath):
|
||||
names = []
|
||||
with open(fpath, 'r') as f:
|
||||
|
|
|
@ -84,17 +84,10 @@ class MSMT17(BaseImageDataset):
|
|||
#train += val
|
||||
#num_train_imgs += num_val_imgs
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def process_dir(self, dir_path, list_path):
|
||||
with open(list_path, 'r') as txt:
|
||||
lines = txt.readlines()
|
||||
|
|
|
@ -60,17 +60,11 @@ class PRID(BaseImageDataset):
|
|||
|
||||
train, query, gallery = self.process_split(split)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def prepare_split(self):
|
||||
if not osp.exists(self.split_path):
|
||||
print('Creating splits ...')
|
||||
|
|
|
@ -58,17 +58,11 @@ class PRID2011(BaseVideoDataset):
|
|||
query = self.process_dir(test_dirs, cam1=True, cam2=False)
|
||||
gallery = self.process_dir(test_dirs, cam1=False, cam2=True)
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
|
||||
self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
|
||||
self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
|
||||
|
||||
def process_dir(self, dirnames, cam1=True, cam2=True):
|
||||
tracklets = []
|
||||
dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)}
|
||||
|
|
|
@ -65,17 +65,11 @@ class PRID450S(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
|
@ -52,17 +52,12 @@ class SenseReID(BaseImageDataset):
|
|||
|
||||
query = self.process_dir(self.query_dir)
|
||||
gallery = self.process_dir(self.gallery_dir)
|
||||
train = copy.deepcopy(query) # dummy variable
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(query, query, gallery)
|
||||
|
||||
self.train = copy.deepcopy(query) # only used to initialize trainloader
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
def process_dir(self, dir_path):
|
||||
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
|
||||
|
|
|
@ -64,17 +64,11 @@ class VIPeR(BaseImageDataset):
|
|||
query = [tuple(item) for item in query]
|
||||
gallery = [tuple(item) for item in gallery]
|
||||
|
||||
self.init_attributes(train, query, gallery)
|
||||
|
||||
if verbose:
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def download_data(self):
|
||||
if osp.exists(self.dataset_dir):
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue