diff --git a/torchreid/data_manager.py b/torchreid/data_manager.py index dc2e127..aab9df4 100644 --- a/torchreid/data_manager.py +++ b/torchreid/data_manager.py @@ -46,7 +46,9 @@ class BaseDataManager(object): self.num_instances = num_instances transform_train, transform_test = build_transforms( - self.height, self.width, random_erase=self.random_erase, color_jitter=self.color_jitter, + self.height, self.width, + random_erase=self.random_erase, + color_jitter=self.color_jitter, color_aug=self.color_aug ) self.transform_train = transform_train @@ -61,22 +63,15 @@ class BaseDataManager(object): return self._num_train_cams def return_dataloaders(self): - """ - Return trainloader and testloader dictionary - """ + """Return trainloader and testloader dictionary""" return self.trainloader, self.testloader_dict def return_testdataset_by_name(self, name): - """ - Return query and gallery, each containing a list of (img_path, pid, camid). - """ + """Return query and gallery, each containing a list of (img_path, pid, camid)""" return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery'] class ImageDataManager(BaseDataManager): - """ - Image-ReID data manager - """ def __init__(self, use_gpu, @@ -92,15 +87,19 @@ class ImageDataManager(BaseDataManager): self.cuhk03_classic_split = cuhk03_classic_split self.market1501_500k = market1501_500k - print('=> Initializing TRAIN (source) datasets') + print('=> Initializing train (source) datasets') train = [] self._num_train_pids = 0 self._num_train_cams = 0 for name in self.source_names: dataset = init_imgreid_dataset( - root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, - cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k + root=self.root, + name=name, + split_id=self.split_id, + cuhk03_labeled=self.cuhk03_labeled, + cuhk03_classic_split=self.cuhk03_classic_split, + market1501_500k=self.market1501_500k ) for img_path, pid, camid in dataset.train: @@ -112,37 +111,52 @@ class ImageDataManager(BaseDataManager): self._num_train_cams += dataset.num_train_cams self.train_sampler = build_train_sampler( - train, self.train_sampler, + train, + self.train_sampler, train_batch_size=self.train_batch_size, num_instances=self.num_instances, ) self.trainloader = DataLoader( - ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler, - batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=True + ImageDataset(train, transform=self.transform_train), + sampler=self.train_sampler, + batch_size=self.train_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=True ) - print('=> Initializing TEST (target) datasets') + print('=> Initializing test (target) datasets') self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names} self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names} for name in self.target_names: dataset = init_imgreid_dataset( - root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, - cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k + root=self.root, + name=name, + split_id=self.split_id, + cuhk03_labeled=self.cuhk03_labeled, + cuhk03_classic_split=self.cuhk03_classic_split, + market1501_500k=self.market1501_500k ) self.testloader_dict[name]['query'] = DataLoader( ImageDataset(dataset.query, transform=self.transform_test), - batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=False + batch_size=self.test_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=False ) self.testloader_dict[name]['gallery'] = DataLoader( ImageDataset(dataset.gallery, transform=self.transform_test), - batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=False + batch_size=self.test_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=False ) self.testdataset_dict[name]['query'] = dataset.query @@ -161,9 +175,6 @@ class ImageDataManager(BaseDataManager): class VideoDataManager(BaseDataManager): - """ - Video-ReID data manager - """ def __init__(self, use_gpu, @@ -179,7 +190,7 @@ class VideoDataManager(BaseDataManager): self.sample_method = sample_method self.image_training = image_training - print('=> Initializing TRAIN (source) datasets') + print('=> Initializing train (source) datasets') train = [] self._num_train_pids = 0 self._num_train_cams = 0 @@ -209,21 +220,33 @@ class VideoDataManager(BaseDataManager): if image_training: # each batch has image data of shape (batch, channel, height, width) self.trainloader = DataLoader( - ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler, - batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=True + ImageDataset(train, transform=self.transform_train), + sampler=self.train_sampler, + batch_size=self.train_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=True ) else: # each batch has image data of shape (batch, seq_len, channel, height, width) # note: this requires new training scripts self.trainloader = DataLoader( - VideoDataset(train, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_train), - batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=True + VideoDataset( + train, + seq_len=self.seq_len, + sample_method=self.sample_method, + transform=self.transform_train + ), + batch_size=self.train_batch_size, + shuffle=True, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=True ) - print('=> Initializing TEST (target) datasets') + print('=> Initializing test (target) datasets') self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names} self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names} @@ -231,15 +254,31 @@ class VideoDataManager(BaseDataManager): dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id) self.testloader_dict[name]['query'] = DataLoader( - VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test), - batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=False, + VideoDataset( + dataset.query, + seq_len=self.seq_len, + sample_method=self.sample_method, + transform=self.transform_test + ), + batch_size=self.test_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=False ) self.testloader_dict[name]['gallery'] = DataLoader( - VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test), - batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, - pin_memory=self.use_gpu, drop_last=False, + VideoDataset( + dataset.gallery, + seq_len=self.seq_len, + sample_method=self.sample_method, + transform=self.transform_test + ), + batch_size=self.test_batch_size, + shuffle=False, + num_workers=self.workers, + pin_memory=self.use_gpu, + drop_last=False ) self.testdataset_dict[name]['query'] = dataset.query diff --git a/torchreid/datasets/__init__.py b/torchreid/datasets/__init__.py index 6d00028..b4f6fa5 100644 --- a/torchreid/datasets/__init__.py +++ b/torchreid/datasets/__init__.py @@ -31,7 +31,7 @@ __imgreid_factory = { 'prid450s': PRID450S, 'ilids': iLIDS, 'sensereid': SenseReID, - 'prid': PRID, + 'prid': PRID } @@ -39,17 +39,19 @@ __vidreid_factory = { 'mars': Mars, 'ilidsvid': iLIDSVID, 'prid2011': PRID2011, - 'dukemtmcvidreid': DukeMTMCVidReID, + 'dukemtmcvidreid': DukeMTMCVidReID } def init_imgreid_dataset(name, **kwargs): - if name not in list(__imgreid_factory.keys()): - raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__imgreid_factory.keys()))) + avai_datasets = list(__imgreid_factory.keys()) + if name not in avai_datasets: + raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets)) return __imgreid_factory[name](**kwargs) def init_vidreid_dataset(name, **kwargs): - if name not in list(__vidreid_factory.keys()): - raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__vidreid_factory.keys()))) + avai_datasets = list(__vidreid_factory.keys()) + if name not in avai_datasets: + raise RuntimeError('Invalid dataset name. Received "{}", but expected to be one of {}'.format(name, avai_datasets)) return __vidreid_factory[name](**kwargs) \ No newline at end of file diff --git a/torchreid/datasets/bases.py b/torchreid/datasets/bases.py index 9dfb63c..c0b6ddb 100644 --- a/torchreid/datasets/bases.py +++ b/torchreid/datasets/bases.py @@ -18,7 +18,59 @@ class BaseDataset(object): if not osp.exists(f): raise RuntimeError('"{}" is not found'.format(f)) - def get_imagedata_info(self, data): + def extract_data_info(self, data): + """Extract info from data list + + Return: num_pids, num_imgs, num_cams, optional + """ + raise NotImplementedError + + def get_num_pids(self, data): + return self.extract_data_info(data)[0] + + def get_num_cams(self, data): + return self.extract_data_info(data)[2] + + def init_attributes(self, train, query, gallery): + self._train = train + self._query = query + self._gallery = gallery + self._num_train_pids = self.get_num_pids(train) + self._num_train_cams = self.get_num_cams(train) + + def print_dataset_statistics(self): + raise NotImplementedError + + @property + def train(self): + # train list containing (img_path, pid, camid) + return self._train + + @property + def query(self): + # query list containing (img_path, pid, camid) + return self._query + + @property + def gallery(self): + # gallery list containing (img_path, pid, camid) + return self._gallery + + @property + def num_train_pids(self): + # number of train identities + return self._num_train_pids + + @property + def num_train_cams(self): + # number of train camera views + return self._num_train_cams + + +class BaseImageDataset(BaseDataset): + """Base class of image-reid dataset""" + + def extract_data_info(self, data): pids, cams = [], [] for _, pid, camid in data: pids += [pid] @@ -30,32 +82,10 @@ class BaseDataset(object): num_imgs = len(data) return num_pids, num_imgs, num_cams - def get_videodata_info(self, data, return_tracklet_stats=False): - pids, cams, tracklet_stats = [], [], [] - for img_paths, pid, camid in data: - pids += [pid] - cams += [camid] - tracklet_stats += [len(img_paths)] - pids = set(pids) - cams = set(cams) - num_pids = len(pids) - num_cams = len(cams) - num_tracklets = len(data) - if return_tracklet_stats: - return num_pids, num_tracklets, num_cams, tracklet_stats - return num_pids, num_tracklets, num_cams - - def print_dataset_statistics(self): - raise NotImplementedError - - -class BaseImageDataset(BaseDataset): - """Base class of image-reid dataset""" - def print_dataset_statistics(self, train, query, gallery): - num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) - num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) - num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) + num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train) + num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query) + num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.extract_data_info(gallery) print('=> Loaded {}'.format(self.__class__.__name__)) print(' ----------------------------------------') @@ -70,15 +100,30 @@ class BaseImageDataset(BaseDataset): class BaseVideoDataset(BaseDataset): """Base class of video-reid dataset""" + def extract_data_info(self, data, return_tracklet_stats=False): + pids, cams, tracklet_stats = [], [], [] + for img_paths, pid, camid in data: + pids += [pid] + cams += [camid] + tracklet_stats += [len(img_paths)] + pids = set(pids) + cams = set(cams) + num_pids = len(pids) + num_cams = len(cams) + num_tracklets = len(data) + if return_tracklet_stats: + return num_pids, num_tracklets, num_cams, tracklet_stats + return num_pids, num_tracklets, num_cams + def print_dataset_statistics(self, train, query, gallery): num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ - self.get_videodata_info(train, return_tracklet_stats=True) + self.extract_data_info(train, return_tracklet_stats=True) num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ - self.get_videodata_info(query, return_tracklet_stats=True) + self.extract_data_info(query, return_tracklet_stats=True) num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ - self.get_videodata_info(gallery, return_tracklet_stats=True) + self.extract_data_info(gallery, return_tracklet_stats=True) tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats min_num = np.min(tracklet_stats) diff --git a/torchreid/datasets/cuhk01.py b/torchreid/datasets/cuhk01.py index 63f5635..c530e94 100644 --- a/torchreid/datasets/cuhk01.py +++ b/torchreid/datasets/cuhk01.py @@ -63,17 +63,11 @@ class CUHK01(BaseImageDataset): query = [tuple(item) for item in query] gallery = [tuple(item) for item in gallery] + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def extract_file(self): if not osp.exists(self.campus_dir): print('Extracting files') diff --git a/torchreid/datasets/cuhk03.py b/torchreid/datasets/cuhk03.py index 71a0e3f..4c8c93f 100644 --- a/torchreid/datasets/cuhk03.py +++ b/torchreid/datasets/cuhk03.py @@ -81,17 +81,11 @@ class CUHK03(BaseImageDataset): query = split['query'] gallery = split['gallery'] + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def preprocess_split(self): """ This function is a bit complex and ugly, what it does is diff --git a/torchreid/datasets/dukemtmcreid.py b/torchreid/datasets/dukemtmcreid.py index 5ed3aec..0f613c1 100644 --- a/torchreid/datasets/dukemtmcreid.py +++ b/torchreid/datasets/dukemtmcreid.py @@ -57,17 +57,11 @@ class DukeMTMCreID(BaseImageDataset): query = self.process_dir(self.query_dir, relabel=False) gallery = self.process_dir(self.gallery_dir, relabel=False) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/dukemtmcvidreid.py b/torchreid/datasets/dukemtmcvidreid.py index 6ca43de..857e2d2 100644 --- a/torchreid/datasets/dukemtmcvidreid.py +++ b/torchreid/datasets/dukemtmcvidreid.py @@ -60,17 +60,11 @@ class DukeMTMCVidReID(BaseVideoDataset): query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False) gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train) - self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) - self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/grid.py b/torchreid/datasets/grid.py index 68cd458..0bea9d7 100644 --- a/torchreid/datasets/grid.py +++ b/torchreid/datasets/grid.py @@ -67,17 +67,11 @@ class GRID(BaseImageDataset): query = [tuple(item) for item in query] gallery = [tuple(item) for item in gallery] + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/ilids.py b/torchreid/datasets/ilids.py index 21f67f6..73eed12 100644 --- a/torchreid/datasets/ilids.py +++ b/torchreid/datasets/ilids.py @@ -58,17 +58,11 @@ class iLIDS(BaseImageDataset): train, query, gallery = self.process_split(split) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/ilidsvid.py b/torchreid/datasets/ilidsvid.py index bcd6197..936751c 100644 --- a/torchreid/datasets/ilidsvid.py +++ b/torchreid/datasets/ilidsvid.py @@ -65,17 +65,11 @@ class iLIDSVID(BaseVideoDataset): query = self.process_data(test_dirs, cam1=True, cam2=False) gallery = self.process_data(test_dirs, cam1=False, cam2=True) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train) - self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) - self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/market1501.py b/torchreid/datasets/market1501.py index 25bfb72..669b2a6 100644 --- a/torchreid/datasets/market1501.py +++ b/torchreid/datasets/market1501.py @@ -57,17 +57,11 @@ class Market1501(BaseImageDataset): if self.market1501_500k: gallery += self.process_dir(self.extra_gallery_dir, relabel=False) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = re.compile(r'([-\d]+)_c(\d)') diff --git a/torchreid/datasets/mars.py b/torchreid/datasets/mars.py index 1af5d25..12e00d6 100644 --- a/torchreid/datasets/mars.py +++ b/torchreid/datasets/mars.py @@ -66,17 +66,11 @@ class Mars(BaseVideoDataset): query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train) - self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) - self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def get_names(self, fpath): names = [] with open(fpath, 'r') as f: diff --git a/torchreid/datasets/msmt17.py b/torchreid/datasets/msmt17.py index 6007376..ce629aa 100644 --- a/torchreid/datasets/msmt17.py +++ b/torchreid/datasets/msmt17.py @@ -84,17 +84,10 @@ class MSMT17(BaseImageDataset): #train += val #num_train_imgs += num_val_imgs + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def process_dir(self, dir_path, list_path): with open(list_path, 'r') as txt: lines = txt.readlines() diff --git a/torchreid/datasets/prid.py b/torchreid/datasets/prid.py index 151873c..13538c9 100644 --- a/torchreid/datasets/prid.py +++ b/torchreid/datasets/prid.py @@ -60,17 +60,11 @@ class PRID(BaseImageDataset): train, query, gallery = self.process_split(split) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def prepare_split(self): if not osp.exists(self.split_path): print('Creating splits ...') diff --git a/torchreid/datasets/prid2011.py b/torchreid/datasets/prid2011.py index d5b3050..6b95317 100644 --- a/torchreid/datasets/prid2011.py +++ b/torchreid/datasets/prid2011.py @@ -58,17 +58,11 @@ class PRID2011(BaseVideoDataset): query = self.process_dir(test_dirs, cam1=True, cam2=False) gallery = self.process_dir(test_dirs, cam1=False, cam2=True) + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train) - self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) - self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def process_dir(self, dirnames, cam1=True, cam2=True): tracklets = [] dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} diff --git a/torchreid/datasets/prid450s.py b/torchreid/datasets/prid450s.py index 437a3aa..60e1f6b 100644 --- a/torchreid/datasets/prid450s.py +++ b/torchreid/datasets/prid450s.py @@ -65,17 +65,11 @@ class PRID450S(BaseImageDataset): query = [tuple(item) for item in query] gallery = [tuple(item) for item in gallery] + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return diff --git a/torchreid/datasets/sensereid.py b/torchreid/datasets/sensereid.py index 337d3c0..7d38487 100644 --- a/torchreid/datasets/sensereid.py +++ b/torchreid/datasets/sensereid.py @@ -52,17 +52,12 @@ class SenseReID(BaseImageDataset): query = self.process_dir(self.query_dir) gallery = self.process_dir(self.gallery_dir) + train = copy.deepcopy(query) # dummy variable + + self.init_attributes(train, query, gallery) if verbose: - self.print_dataset_statistics(query, query, gallery) - - self.train = copy.deepcopy(query) # only used to initialize trainloader - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) + self.print_dataset_statistics(train, query, gallery) def process_dir(self, dir_path): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) diff --git a/torchreid/datasets/viper.py b/torchreid/datasets/viper.py index 05a777a..fbc0598 100755 --- a/torchreid/datasets/viper.py +++ b/torchreid/datasets/viper.py @@ -64,17 +64,11 @@ class VIPeR(BaseImageDataset): query = [tuple(item) for item in query] gallery = [tuple(item) for item in gallery] + self.init_attributes(train, query, gallery) + if verbose: self.print_dataset_statistics(train, query, gallery) - self.train = train - self.query = query - self.gallery = gallery - - self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) - self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) - self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def download_data(self): if osp.exists(self.dataset_dir): return