This commit is contained in:
KaiyangZhou 2018-04-23 11:02:10 +01:00
parent d244c7044a
commit a6e274d35d

View File

@ -133,10 +133,10 @@ class CUHK03(object):
self._preprocess() self._preprocess()
if cuhk03_labeled: if cuhk03_labeled:
print("Use 'labeled' images") image_type = 'labeled'
split_path = self.split_labeled_path split_path = self.split_labeled_path
else: else:
print("Use 'detected' images") image_type = 'detected'
split_path = self.split_detected_path split_path = self.split_detected_path
splits = read_json(split_path) splits = read_json(split_path)
@ -158,7 +158,7 @@ class CUHK03(object):
num_gallery_imgs = split['num_gallery_imgs'] num_gallery_imgs = split['num_gallery_imgs']
num_total_imgs = num_train_imgs + num_query_imgs num_total_imgs = num_train_imgs + num_query_imgs
print("=> CUHK03 loaded") print("=> CUHK03 ({}) loaded".format(image_type))
print("Dataset statistics:") print("Dataset statistics:")
print(" ------------------------------") print(" ------------------------------")
print(" subset | # ids | # images") print(" subset | # ids | # images")
@ -244,10 +244,10 @@ class CUHK03(object):
for pid in range(num_pids): for pid in range(num_pids):
img_paths_v0 = _process_images(camp[pid,:5], campid, pid, 0, imgs_dir) img_paths_v0 = _process_images(camp[pid,:5], campid, pid, 0, imgs_dir)
img_paths_v1 = _process_images(camp[pid,5:], campid, pid, 1, imgs_dir) img_paths_v1 = _process_images(camp[pid,5:], campid, pid, 1, imgs_dir)
img_paths_both = img_paths_v0 + img_paths_v1 img_paths_both_views = img_paths_v0 + img_paths_v1
assert len(img_paths_both) > 0, "campid{}-pid{} have no images".format(campid, pid) assert len(img_paths_both_views) > 0, "campid{}-pid{} has no images".format(campid, pid)
meta_data.append((campid, pid, img_paths_both)) meta_data.append((campid, pid, img_paths_both_views))
print("done camera pair {}".format(campid+1)) print("done camera pair {} with {} identities".format(campid+1, num_pids))
return meta_data return meta_data
meta_detected = _extract_img('detected') meta_detected = _extract_img('detected')
@ -261,13 +261,13 @@ class CUHK03(object):
if [campid+1, pid+1] in test_split: if [campid+1, pid+1] in test_split:
for img_path in img_paths: for img_path in img_paths:
camid = int(img_path.split('_')[2]) camid = int(osp.basename(img_path).split('_')[2])
test.append((img_path, num_test_pids, camid)) test.append((img_path, num_test_pids, camid))
num_test_pids += 1 num_test_pids += 1
num_test_imgs += len(img_paths) num_test_imgs += len(img_paths)
else: else:
for img_path in img_paths: for img_path in img_paths:
camid = int(img_path.split('_')[2]) camid = int(osp.basename(img_path).split('_')[2])
train.append((img_path, num_train_pids, camid)) train.append((img_path, num_train_pids, camid))
num_train_pids += 1 num_train_pids += 1
num_train_imgs += len(img_paths) num_train_imgs += len(img_paths)
@ -738,4 +738,4 @@ def init_dataset(name, *args, **kwargs):
return __factory[name](*args, **kwargs) return __factory[name](*args, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
pass dataset = CUHK03()