add args.root

pull/17/head
KaiyangZhou 2018-05-02 15:59:06 +01:00
parent e324ff921b
commit 44bc7b25ff
6 changed files with 105 additions and 91 deletions

View File

@ -14,12 +14,6 @@ from scipy.misc import imsave
from utils import mkdir_if_missing, write_json, read_json
"""Dataset classes
Each class has a 'root' variable pointing to the './data/specific-dataset' directory.
If you store dataset in custom paths, please change 'root' accordingly.
"""
"""Image ReID"""
class Market1501(object):
@ -35,12 +29,14 @@ class Market1501(object):
# identities: 1501 (+1 for background)
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
"""
root = './data/market1501'
train_dir = osp.join(root, 'bounding_box_train')
query_dir = osp.join(root, 'query')
gallery_dir = osp.join(root, 'bounding_box_test')
dataset_dir = 'market1501'
def __init__(self, root='data', **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
def __init__(self, **kwargs):
self._check_before_run()
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
@ -71,8 +67,8 @@ class Market1501(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("'{}' is not available".format(self.train_dir))
if not osp.exists(self.query_dir):
@ -124,23 +120,25 @@ class CUHK03(object):
split_id (int): split index (default: 0)
cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
"""
root = './data/cuhk03'
data_dir = osp.join(root, 'cuhk03_release')
raw_mat_path = osp.join(data_dir, 'cuhk-03.mat')
imgs_detected_dir = osp.join(root, 'images_detected')
imgs_labeled_dir = osp.join(root, 'images_labeled')
split_classic_det_json_path = osp.join(root, 'splits_classic_detected.json')
split_classic_lab_json_path = osp.join(root, 'splits_classic_labeled.json')
split_new_det_json_path = osp.join(root, 'splits_new_detected.json')
split_new_lab_json_path = osp.join(root, 'splits_new_labeled.json')
split_new_det_mat_path = osp.join(root, 'cuhk03_new_protocol_config_detected.mat')
split_new_lab_mat_path = osp.join(root, 'cuhk03_new_protocol_config_labeled.mat')
dataset_dir = 'cuhk03'
def __init__(self, root='data', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')
self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')
self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')
self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
def __init__(self, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
self._check_before_run()
self._preprocess()
@ -192,16 +190,16 @@ class CUHK03(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.data_dir):
raise RuntimeError("'{}' is not available".format(self.root))
raise RuntimeError("'{}' is not available".format(self.data_dir))
if not osp.exists(self.raw_mat_path):
raise RuntimeError("'{}' is not available".format(self.root))
raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
if not osp.exists(self.split_new_det_mat_path):
raise RuntimeError("'{}' is not available".format(self.root))
raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
if not osp.exists(self.split_new_lab_mat_path):
raise RuntimeError("'{}' is not available".format(self.root))
raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
def _preprocess(self):
"""
@ -377,12 +375,14 @@ class DukeMTMCreID(object):
# images:16522 (train) + 2228 (query) + 17661 (gallery)
# cameras: 8
"""
root = './data/dukemtmc-reid'
train_dir = osp.join(root, 'DukeMTMC-reID/bounding_box_train')
query_dir = osp.join(root, 'DukeMTMC-reID/query')
gallery_dir = osp.join(root, 'DukeMTMC-reID/bounding_box_test')
dataset_dir = 'dukemtmc-reid'
def __init__(self, root='data', **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
def __init__(self, **kwargs):
self._check_before_run()
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
@ -413,8 +413,8 @@ class DukeMTMCreID(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("'{}' is not available".format(self.train_dir))
if not osp.exists(self.query_dir):
@ -458,15 +458,17 @@ class MSMT17(object):
# images: 32621 (train) + 11659 (query) + 82161 (gallery)
# cameras: 15
"""
root = './data/msmt17'
train_dir = osp.join(root, 'MSMT17_V1/train')
test_dir = osp.join(root, 'MSMT17_V1/test')
list_train_path = osp.join(root, 'MSMT17_V1/list_train.txt')
list_val_path = osp.join(root, 'MSMT17_V1/list_val.txt')
list_query_path = osp.join(root, 'MSMT17_V1/list_query.txt')
list_gallery_path = osp.join(root, 'MSMT17_V1/list_gallery.txt')
dataset_dir = 'msmt17'
def __init__(self, root='data', **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V1/train')
self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V1/test')
self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_train.txt')
self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_val.txt')
self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_query.txt')
self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V1/list_gallery.txt')
def __init__(self, **kwargs):
self._check_before_run()
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, self.list_train_path)
val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path)
@ -501,8 +503,8 @@ class MSMT17(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.train_dir):
raise RuntimeError("'{}' is not available".format(self.train_dir))
if not osp.exists(self.test_dir):
@ -542,14 +544,16 @@ class Mars(object):
Args:
min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
"""
root = './data/mars'
train_name_path = osp.join(root, 'info/train_name.txt')
test_name_path = osp.join(root, 'info/test_name.txt')
track_train_info_path = osp.join(root, 'info/tracks_train_info.mat')
track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
query_IDX_path = osp.join(root, 'info/query_IDX.mat')
dataset_dir = 'mars'
def __init__(self, root='data', min_seq_len=0, **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt')
self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt')
self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat')
self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat')
self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat')
def __init__(self, min_seq_len=0, **kwargs):
self._check_before_run()
# prepare meta data
@ -603,8 +607,8 @@ class Mars(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.train_name_path):
raise RuntimeError("'{}' is not available".format(self.train_name_path))
if not osp.exists(self.test_name_path):
@ -652,7 +656,7 @@ class Mars(object):
assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
# append image names with directory information
img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names]
img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names]
if len(img_paths) >= min_seq_len:
img_paths = tuple(img_paths)
tracklets.append((img_paths, pid, camid))
@ -679,16 +683,18 @@ class iLIDSVID(object):
Args:
split_id (int): indicates which split to use. There are totally 10 splits.
"""
root = './data/ilids-vid'
dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'
data_dir = osp.join(root, 'i-LIDS-VID')
split_dir = osp.join(root, 'train-test people splits')
split_mat_path = osp.join(split_dir, 'train_test_splits_ilidsvid.mat')
split_path = osp.join(root, 'splits.json')
cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1')
cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2')
dataset_dir = 'ilids-vid'
def __init__(self, root='data', split_id=0, **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'
self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID')
self.split_dir = osp.join(self.dataset_dir, 'train-test people splits')
# The split .mat lives inside split_dir ('train-test people splits'), not directly under dataset_dir.
self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat')
self.split_path = osp.join(self.dataset_dir, 'splits.json')
self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1')
self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2')
def __init__(self, split_id=0, **kwargs):
self._download_data()
self._check_before_run()
@ -737,12 +743,12 @@ class iLIDSVID(object):
self.num_gallery_pids = num_gallery_pids
def _download_data(self):
if osp.exists(self.root):
if osp.exists(self.dataset_dir):
print("This dataset has been downloaded.")
return
mkdir_if_missing(self.root)
fpath = osp.join(self.root, osp.basename(self.dataset_url))
mkdir_if_missing(self.dataset_dir)
fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
print("Downloading iLIDS-VID dataset")
url_opener = urllib.URLopener()
@ -750,13 +756,13 @@ class iLIDSVID(object):
print("Extracting files")
tar = tarfile.open(fpath)
tar.extractall(path=self.root)
tar.extractall(path=self.dataset_dir)
tar.close()
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.data_dir):
raise RuntimeError("'{}' is not available".format(self.data_dir))
if not osp.exists(self.split_dir):
@ -850,13 +856,15 @@ class PRID(object):
split_id (int): indicates which split to use. There are totally 10 splits.
min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0).
"""
root = './data/prid2011'
dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1'
split_path = osp.join(root, 'splits_prid2011.json')
cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')
dataset_dir = 'prid2011'
def __init__(self, root='data', split_id=0, min_seq_len=0, **kwargs):
self.dataset_dir = osp.join(root, self.dataset_dir)
self.dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1'
self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json')
self.cam_a_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a')
self.cam_b_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b')
def __init__(self, split_id=0, min_seq_len=0, **kwargs):
self._check_before_run()
splits = read_json(self.split_path)
if split_id >= len(splits):
@ -903,8 +911,8 @@ class PRID(object):
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.root):
raise RuntimeError("'{}' is not available".format(self.root))
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
def _process_data(self, dirnames, cam1=True, cam2=True):
tracklets = []
@ -941,6 +949,7 @@ __factory = {
'market1501': Market1501,
'cuhk03': CUHK03,
'dukemtmcreid': DukeMTMCreID,
'msmt17': MSMT17,
'mars': Mars,
'ilidsvid': iLIDSVID,
'prid': PRID,
@ -955,4 +964,4 @@ def init_dataset(name, **kwargs):
return __factory[name](**kwargs)
if __name__ == '__main__':
pass
d = PRID()

View File

@ -23,6 +23,7 @@ from eval_metrics import evaluate
parser = argparse.ArgumentParser(description='Train image model with center loss')
# Datasets
parser.add_argument('--root', type=str, default='data', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='market1501',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
@ -94,7 +95,7 @@ def main():
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(
name=args.dataset, split_id=args.split_id,
root=args.root, name=args.dataset, split_id=args.split_id,
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
)

View File

@ -24,6 +24,7 @@ from optimizers import init_optim
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
# Datasets
parser.add_argument('--root', type=str, default='data', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='market1501',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
@ -93,7 +94,7 @@ def main():
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(
name=args.dataset, split_id=args.split_id,
root=args.root, name=args.dataset, split_id=args.split_id,
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
)

View File

@ -25,6 +25,7 @@ from optimizers import init_optim
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss and hard triplet loss')
# Datasets
parser.add_argument('--root', type=str, default='data', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='market1501',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
@ -99,7 +100,7 @@ def main():
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(
name=args.dataset, split_id=args.split_id,
root=args.root, name=args.dataset, split_id=args.split_id,
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
)

View File

@ -24,6 +24,7 @@ from optimizers import init_optim
parser = argparse.ArgumentParser(description='Train video model with cross entropy loss')
# Datasets
parser.add_argument('--root', type=str, default='data', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='mars',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
@ -86,7 +87,7 @@ def main():
print("Currently using CPU (GPU is highly recommended)")
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(name=args.dataset)
dataset = data_manager.init_dataset(root=args.root, name=args.dataset)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),

View File

@ -25,6 +25,7 @@ from optimizers import init_optim
parser = argparse.ArgumentParser(description='Train video model with cross entropy loss')
# Datasets
parser.add_argument('--root', type=str, default='data', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='mars',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
@ -92,7 +93,7 @@ def main():
print("Currently using CPU (GPU is highly recommended)")
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(name=args.dataset)
dataset = data_manager.init_dataset(root=args.root, name=args.dataset)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),