update datasets

This commit is contained in:
KaiyangZhou 2018-04-24 10:55:59 +01:00
parent 19efb50255
commit 1f78afaef8

View File

@ -36,7 +36,7 @@ class Market1501(object):
query_dir = osp.join(root, 'query')
gallery_dir = osp.join(root, 'bounding_box_test')
def __init__(self):
def __init__(self, **kwargs):
self._check_before_run()
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
@ -136,7 +136,7 @@ class CUHK03(object):
split_new_det_mat_path = osp.join(root, 'cuhk03_new_protocol_config_detected.mat')
split_new_lab_mat_path = osp.join(root, 'cuhk03_new_protocol_config_labeled.mat')
def __init__(self, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False):
def __init__(self, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
self._check_before_run()
self._preprocess()
@ -384,7 +384,7 @@ class Mars(object):
track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
query_IDX_path = osp.join(root, 'info/query_IDX.mat')
def __init__(self, min_seq_len=0):
def __init__(self, min_seq_len=0, **kwargs):
self._check_before_run()
# prepare meta data
@ -523,7 +523,7 @@ class iLIDSVID(object):
cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1')
cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2')
def __init__(self, split_id=0):
def __init__(self, split_id=0, **kwargs):
self._download_data()
self._check_before_run()
@ -691,7 +691,7 @@ class PRID(object):
cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')
def __init__(self, split_id=0, min_seq_len=0):
def __init__(self, split_id=0, min_seq_len=0, **kwargs):
self._check_before_run()
splits = read_json(self.split_path)
if split_id >= len(splits):
@ -783,10 +783,10 @@ __factory = {
def get_names():
return __factory.keys()
def init_dataset(name, *args, **kwargs):
def init_dataset(name, **kwargs):
if name not in __factory.keys():
raise KeyError("Unknown dataset: {}".format(name))
return __factory[name](*args, **kwargs)
return __factory[name](**kwargs)
if __name__ == '__main__':
pass