diff --git a/torchreid/datasets/cuhk01.py b/torchreid/datasets/cuhk01.py index 8231f28..6c0d156 100644 --- a/torchreid/datasets/cuhk01.py +++ b/torchreid/datasets/cuhk01.py @@ -42,10 +42,10 @@ class CUHK01(BaseImageDataset): self.campus_dir = osp.join(self.dataset_dir, 'campus') self.split_path = osp.join(self.dataset_dir, 'splits.json') - self._extract_file() - self._check_before_run() + self.extract_file() + self.check_before_run() - self._prepare_split() + self.prepare_split() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -71,7 +71,7 @@ class CUHK01(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _extract_file(self): + def extract_file(self): if not osp.exists(self.campus_dir): print('Extracting files') zip_ref = zipfile.ZipFile(self.zip_path, 'r') @@ -79,14 +79,14 @@ class CUHK01(BaseImageDataset): zip_ref.close() print('Files extracted') - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) if not osp.exists(self.campus_dir): raise RuntimeError('"{}" is not available'.format(self.campus_dir)) - def _prepare_split(self): + def prepare_split(self): """ Image name format: 0001001.png, where first four digits represent identity and last four digits represent cameras. Camera 1&2 are considered the same @@ -146,4 +146,4 @@ class CUHK01(BaseImageDataset): write_json(splits, self.split_path) print('Split file saved to {}'.format(self.split_path)) - print('Splits created') + print('Splits created') \ No newline at end of file diff --git a/torchreid/datasets/cuhk03.py b/torchreid/datasets/cuhk03.py index c1d75bd..5aa48dd 100644 --- a/torchreid/datasets/cuhk03.py +++ b/torchreid/datasets/cuhk03.py @@ -58,8 +58,8 @@ class CUHK03(BaseImageDataset): self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') - self._check_before_run() - self._preprocess() + self.check_before_run() + self.preprocess() if cuhk03_labeled: image_type = 'labeled' @@ -89,7 +89,7 @@ class CUHK03(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -102,7 +102,7 @@ class CUHK03(BaseImageDataset): if not osp.exists(self.split_new_lab_mat_path): raise RuntimeError('"{}" is not available'.format(self.split_new_lab_mat_path)) - def _preprocess(self): + def preprocess(self): """ This function is a bit complex and ugly, what it does is 1. Extract data from cuhk-03.mat and save as png images. diff --git a/torchreid/datasets/dukemtmcreid.py b/torchreid/datasets/dukemtmcreid.py index bbf16ce..44753fe 100644 --- a/torchreid/datasets/dukemtmcreid.py +++ b/torchreid/datasets/dukemtmcreid.py @@ -44,12 +44,12 @@ class DukeMTMCreID(BaseImageDataset): self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() - train = self._process_dir(self.train_dir, relabel=True) - query = self._process_dir(self.query_dir, relabel=False) - gallery = self._process_dir(self.gallery_dir, relabel=False) + train = self.process_dir(self.train_dir, relabel=True) + query = self.process_dir(self.query_dir, relabel=False) + gallery = self.process_dir(self.gallery_dir, relabel=False) if verbose: print('=> DukeMTMC-reID loaded') @@ -63,7 +63,7 @@ class DukeMTMCreID(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -80,7 +80,7 @@ class DukeMTMCreID(BaseImageDataset): zip_ref.extractall(self.dataset_dir) zip_ref.close() - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -91,7 +91,7 @@ class DukeMTMCreID(BaseImageDataset): if not osp.exists(self.gallery_dir): raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) - def _process_dir(self, dir_path, relabel=False): + def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = re.compile(r'([-\d]+)_c(\d)') diff --git a/torchreid/datasets/dukemtmcvidreid.py b/torchreid/datasets/dukemtmcvidreid.py index d15c979..f165670 100644 --- a/torchreid/datasets/dukemtmcvidreid.py +++ b/torchreid/datasets/dukemtmcvidreid.py @@ -47,13 +47,13 @@ class DukeMTMCVidReID(BaseVideoDataset): self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') self.min_seq_len = min_seq_len - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() print('Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)') - train = self._process_dir(self.train_dir, self.split_train_json_path, relabel=True) - query = self._process_dir(self.query_dir, self.split_query_json_path, relabel=False) - gallery = self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) + train = self.process_dir(self.train_dir, self.split_train_json_path, relabel=True) + query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False) + gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) if verbose: print('=> DukeMTMC-VideoReID loaded') @@ -67,7 +67,7 @@ class DukeMTMCVidReID(BaseVideoDataset): self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -84,7 +84,7 @@ class DukeMTMCVidReID(BaseVideoDataset): zip_ref.extractall(self.dataset_dir) zip_ref.close() - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -95,7 +95,7 @@ class DukeMTMCVidReID(BaseVideoDataset): if not osp.exists(self.gallery_dir): raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) - def _process_dir(self, dir_path, json_path, relabel): + def process_dir(self, dir_path, json_path, relabel): if osp.exists(json_path): print('=> {} generated before, awesome!'.format(json_path)) split = read_json(json_path) diff --git a/torchreid/datasets/grid.py b/torchreid/datasets/grid.py index b4441b4..3e7fe7c 100644 --- a/torchreid/datasets/grid.py +++ b/torchreid/datasets/grid.py @@ -44,10 +44,10 @@ class GRID(BaseImageDataset): self.split_mat_path = osp.join(self.dataset_dir, 'underground_reid', 'features_and_partitions.mat') self.split_path = osp.join(self.dataset_dir, 'splits.json') - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() - self._prepare_split() + self.prepare_split() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -73,7 +73,7 @@ class GRID(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -84,7 +84,7 @@ class GRID(BaseImageDataset): if not osp.exists(self.split_mat_path): raise RuntimeError('"{}" is not available'.format(self.split_mat_path)) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -101,7 +101,7 @@ class GRID(BaseImageDataset): zip_ref.extractall(self.dataset_dir) zip_ref.close() - def _prepare_split(self): + def prepare_split(self): if not osp.exists(self.split_path): print('Creating 10 random splits') split_mat = loadmat(self.split_mat_path) @@ -152,4 +152,4 @@ class GRID(BaseImageDataset): write_json(splits, self.split_path) print('Split file saved to {}'.format(self.split_path)) - print('Splits created') + print('Splits created') \ No newline at end of file diff --git a/torchreid/datasets/ilidsvid.py b/torchreid/datasets/ilidsvid.py index f7f7f98..bcbca5f 100644 --- a/torchreid/datasets/ilidsvid.py +++ b/torchreid/datasets/ilidsvid.py @@ -46,10 +46,10 @@ class iLIDSVID(BaseVideoDataset): self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1') self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2') - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() - self._prepare_split() + self.prepare_split() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -57,9 +57,9 @@ class iLIDSVID(BaseVideoDataset): train_dirs, test_dirs = split['train'], split['test'] print('# train identites: {}, # test identites {}'.format(len(train_dirs), len(test_dirs))) - train = self._process_data(train_dirs, cam1=True, cam2=True) - query = self._process_data(test_dirs, cam1=True, cam2=False) - gallery = self._process_data(test_dirs, cam1=False, cam2=True) + train = self.process_data(train_dirs, cam1=True, cam2=True) + query = self.process_data(test_dirs, cam1=True, cam2=False) + gallery = self.process_data(test_dirs, cam1=False, cam2=True) if verbose: print('=> iLIDS-VID loaded') @@ -73,7 +73,7 @@ class iLIDSVID(BaseVideoDataset): self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -89,7 +89,7 @@ class iLIDSVID(BaseVideoDataset): tar.extractall(path=self.dataset_dir) tar.close() - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -98,7 +98,7 @@ class iLIDSVID(BaseVideoDataset): if not osp.exists(self.split_dir): raise RuntimeError('"{}" is not available'.format(self.split_dir)) - def _prepare_split(self): + def prepare_split(self): if not osp.exists(self.split_path): print('Creating splits ...') mat_split_data = loadmat(self.split_mat_path)['ls_set'] @@ -142,7 +142,7 @@ class iLIDSVID(BaseVideoDataset): print('Splits created') - def _process_data(self, dirnames, cam1=True, cam2=True): + def process_data(self, dirnames, cam1=True, cam2=True): tracklets = [] dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} diff --git a/torchreid/datasets/market1501.py b/torchreid/datasets/market1501.py index 6e5c875..ed68835 100644 --- a/torchreid/datasets/market1501.py +++ b/torchreid/datasets/market1501.py @@ -42,13 +42,13 @@ class Market1501(BaseImageDataset): self.extra_gallery_dir = osp.join(self.dataset_dir, 'images') self.market1501_500k = market1501_500k - self._check_before_run() + self.check_before_run() - train = self._process_dir(self.train_dir, relabel=True) - query = self._process_dir(self.query_dir, relabel=False) - gallery = self._process_dir(self.gallery_dir, relabel=False) + train = self.process_dir(self.train_dir, relabel=True) + query = self.process_dir(self.query_dir, relabel=False) + gallery = self.process_dir(self.gallery_dir, relabel=False) if self.market1501_500k: - gallery += self._process_dir(self.extra_gallery_dir, relabel=False) + gallery += self.process_dir(self.extra_gallery_dir, relabel=False) if verbose: print('=> Market1501 loaded') @@ -62,7 +62,7 @@ class Market1501(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -75,7 +75,7 @@ class Market1501(BaseImageDataset): if self.market1501_500k and not osp.exists(self.extra_gallery_dir): raise RuntimeError('"{}" is not available'.format(self.extra_gallery_dir)) - def _process_dir(self, dir_path, relabel=False): + def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = re.compile(r'([-\d]+)_c(\d)') diff --git a/torchreid/datasets/mars.py b/torchreid/datasets/mars.py index 52f765d..446a60e 100644 --- a/torchreid/datasets/mars.py +++ b/torchreid/datasets/mars.py @@ -43,11 +43,11 @@ class Mars(BaseVideoDataset): self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') - self._check_before_run() + self.check_before_run() # prepare meta data - train_names = self._get_names(self.train_name_path) - test_names = self._get_names(self.test_name_path) + train_names = self.get_names(self.train_name_path) + test_names = self.get_names(self.test_name_path) track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) @@ -56,9 +56,9 @@ class Mars(BaseVideoDataset): gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] track_gallery = track_test[gallery_IDX,:] - train = self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) - query = self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) - gallery = self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) + train = self.process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) + query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) + gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) if verbose: print('=> MARS loaded') @@ -72,7 +72,7 @@ class Mars(BaseVideoDataset): self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -87,7 +87,7 @@ class Mars(BaseVideoDataset): if not osp.exists(self.query_IDX_path): raise RuntimeError('"{}" is not available'.format(self.query_IDX_path)) - def _get_names(self, fpath): + def get_names(self, fpath): names = [] with open(fpath, 'r') as f: for line in f: @@ -95,7 +95,7 @@ class Mars(BaseVideoDataset): names.append(new_line) return names - def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): + def process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): assert home_dir in ['bbox_train', 'bbox_test'] num_tracklets = meta_data.shape[0] pid_list = list(set(meta_data[:,2].tolist())) diff --git a/torchreid/datasets/msmt17.py b/torchreid/datasets/msmt17.py index 0f1bdc9..a6b6354 100644 --- a/torchreid/datasets/msmt17.py +++ b/torchreid/datasets/msmt17.py @@ -69,11 +69,11 @@ class MSMT17(BaseImageDataset): self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt') self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt') - self._check_before_run() - train = self._process_dir(self.train_dir, self.list_train_path) - #val = self._process_dir(self.train_dir, self.list_val_path) - query = self._process_dir(self.test_dir, self.list_query_path) - gallery = self._process_dir(self.test_dir, self.list_gallery_path) + self.check_before_run() + train = self.process_dir(self.train_dir, self.list_train_path) + #val = self.process_dir(self.train_dir, self.list_val_path) + query = self.process_dir(self.test_dir, self.list_query_path) + gallery = self.process_dir(self.test_dir, self.list_gallery_path) # To fairly compare with published methods, don't use val images for training #train += val @@ -91,7 +91,7 @@ class MSMT17(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -100,7 +100,7 @@ class MSMT17(BaseImageDataset): if not osp.exists(self.test_dir): raise RuntimeError('"{}" is not available'.format(self.test_dir)) - def _process_dir(self, dir_path, list_path): + def process_dir(self, dir_path, list_path): with open(list_path, 'r') as txt: lines = txt.readlines() dataset = [] diff --git a/torchreid/datasets/prid2011.py b/torchreid/datasets/prid2011.py index e413200..9acc718 100644 --- a/torchreid/datasets/prid2011.py +++ b/torchreid/datasets/prid2011.py @@ -42,7 +42,7 @@ class PRID2011(BaseVideoDataset): self.cam_a_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a') self.cam_b_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b') - self._check_before_run() + self.check_before_run() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -50,9 +50,9 @@ class PRID2011(BaseVideoDataset): train_dirs, test_dirs = split['train'], split['test'] print('# train identites: {}, # test identites {}'.format(len(train_dirs), len(test_dirs))) - train = self._process_data(train_dirs, cam1=True, cam2=True) - query = self._process_data(test_dirs, cam1=True, cam2=False) - gallery = self._process_data(test_dirs, cam1=False, cam2=True) + train = self.process_dir(train_dirs, cam1=True, cam2=True) + query = self.process_dir(test_dirs, cam1=True, cam2=False) + gallery = self.process_dir(test_dirs, cam1=False, cam2=True) if verbose: print('=> PRID2011 loaded') @@ -66,12 +66,12 @@ class PRID2011(BaseVideoDataset): self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query) self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) - def _process_data(self, dirnames, cam1=True, cam2=True): + def process_dir(self, dirnames, cam1=True, cam2=True): tracklets = [] dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} diff --git a/torchreid/datasets/prid450s.py b/torchreid/datasets/prid450s.py index 52dcb10..6d0258a 100644 --- a/torchreid/datasets/prid450s.py +++ b/torchreid/datasets/prid450s.py @@ -43,10 +43,10 @@ class PRID450S(BaseImageDataset): self.cam_a_path = osp.join(self.dataset_dir, 'cam_a') self.cam_b_path = osp.join(self.dataset_dir, 'cam_b') - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() - self._prepare_split() + self.prepare_split() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -72,7 +72,7 @@ class PRID450S(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -81,7 +81,7 @@ class PRID450S(BaseImageDataset): if not osp.exists(self.cam_b_path): raise RuntimeError('"{}" is not available'.format(self.cam_b_path)) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -98,7 +98,7 @@ class PRID450S(BaseImageDataset): zip_ref.extractall(self.dataset_dir) zip_ref.close() - def _prepare_split(self): + def prepare_split(self): if not osp.exists(self.split_path): cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_path, 'img_*.png'))) cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_path, 'img_*.png'))) @@ -145,4 +145,4 @@ class PRID450S(BaseImageDataset): write_json(splits, self.split_path) print('Split file saved to {}'.format(self.split_path)) - print('Splits created') + print('Splits created') \ No newline at end of file diff --git a/torchreid/datasets/sensereid.py b/torchreid/datasets/sensereid.py index 834339c..0e5b8c8 100644 --- a/torchreid/datasets/sensereid.py +++ b/torchreid/datasets/sensereid.py @@ -44,10 +44,10 @@ class SenseReID(BaseImageDataset): self.query_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_probe') self.gallery_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_gallery') - self._check_before_run() + self.check_before_run() - query = self._process_dir(self.query_dir) - gallery = self._process_dir(self.gallery_dir) + query = self.process_dir(self.query_dir) + gallery = self.process_dir(self.gallery_dir) if verbose: print('=> SenseReID loaded (test only)') @@ -61,7 +61,7 @@ class SenseReID(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -70,7 +70,7 @@ class SenseReID(BaseImageDataset): if not osp.exists(self.gallery_dir): raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) - def _process_dir(self, dir_path): + def process_dir(self, dir_path): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) dataset = [] diff --git a/torchreid/datasets/viper.py b/torchreid/datasets/viper.py index 05f1def..32e98f4 100755 --- a/torchreid/datasets/viper.py +++ b/torchreid/datasets/viper.py @@ -42,10 +42,10 @@ class VIPeR(BaseImageDataset): self.cam_b_path = osp.join(self.dataset_dir, 'VIPeR', 'cam_b') self.split_path = osp.join(self.dataset_dir, 'splits.json') - self._download_data() - self._check_before_run() + self.download_data() + self.check_before_run() - self._prepare_split() + self.prepare_split() splits = read_json(self.split_path) if split_id >= len(splits): raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) @@ -71,7 +71,7 @@ class VIPeR(BaseImageDataset): self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) - def _download_data(self): + def download_data(self): if osp.exists(self.dataset_dir): print('This dataset has been downloaded.') return @@ -88,7 +88,7 @@ class VIPeR(BaseImageDataset): zip_ref.extractall(self.dataset_dir) zip_ref.close() - def _check_before_run(self): + def check_before_run(self): """Check if all files are available before going deeper""" if not osp.exists(self.dataset_dir): raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) @@ -97,7 +97,7 @@ class VIPeR(BaseImageDataset): if not osp.exists(self.cam_b_path): raise RuntimeError('"{}" is not available'.format(self.cam_b_path)) - def _prepare_split(self): + def prepare_split(self): if not osp.exists(self.split_path): print('Creating 10 random splits of train ids and test ids') @@ -160,4 +160,4 @@ class VIPeR(BaseImageDataset): write_json(splits, self.split_path) print('Split file saved to {}'.format(self.split_path)) - print('Splits created') + print('Splits created') \ No newline at end of file