mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
update datasets
This commit is contained in:
parent
19efb50255
commit
1f78afaef8
@ -36,7 +36,7 @@ class Market1501(object):
|
||||
query_dir = osp.join(root, 'query')
|
||||
gallery_dir = osp.join(root, 'bounding_box_test')
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
self._check_before_run()
|
||||
|
||||
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
|
||||
@ -136,7 +136,7 @@ class CUHK03(object):
|
||||
split_new_det_mat_path = osp.join(root, 'cuhk03_new_protocol_config_detected.mat')
|
||||
split_new_lab_mat_path = osp.join(root, 'cuhk03_new_protocol_config_labeled.mat')
|
||||
|
||||
def __init__(self, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False):
|
||||
def __init__(self, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
|
||||
self._check_before_run()
|
||||
self._preprocess()
|
||||
|
||||
@ -384,7 +384,7 @@ class Mars(object):
|
||||
track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
|
||||
query_IDX_path = osp.join(root, 'info/query_IDX.mat')
|
||||
|
||||
def __init__(self, min_seq_len=0):
|
||||
def __init__(self, min_seq_len=0, **kwargs):
|
||||
self._check_before_run()
|
||||
|
||||
# prepare meta data
|
||||
@ -523,7 +523,7 @@ class iLIDSVID(object):
|
||||
cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1')
|
||||
cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2')
|
||||
|
||||
def __init__(self, split_id=0):
|
||||
def __init__(self, split_id=0, **kwargs):
|
||||
self._download_data()
|
||||
self._check_before_run()
|
||||
|
||||
@ -691,7 +691,7 @@ class PRID(object):
|
||||
cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
|
||||
cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')
|
||||
|
||||
def __init__(self, split_id=0, min_seq_len=0):
|
||||
def __init__(self, split_id=0, min_seq_len=0, **kwargs):
|
||||
self._check_before_run()
|
||||
splits = read_json(self.split_path)
|
||||
if split_id >= len(splits):
|
||||
@ -783,10 +783,10 @@ __factory = {
|
||||
def get_names():
|
||||
return __factory.keys()
|
||||
|
||||
def init_dataset(name, *args, **kwargs):
|
||||
def init_dataset(name, **kwargs):
|
||||
if name not in __factory.keys():
|
||||
raise KeyError("Unknown dataset: {}".format(name))
|
||||
return __factory[name](*args, **kwargs)
|
||||
return __factory[name](**kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user