deep-person-reid/torchreid/data/datasets/video/prid2011.py

81 lines
2.8 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-21 20:59:54 +08:00
import glob
2019-12-01 10:35:44 +08:00
import os.path as osp
2019-03-21 20:59:54 +08:00
from torchreid.utils import read_json
2019-12-01 11:31:32 +08:00
from ..dataset import VideoDataset
2019-03-21 20:59:54 +08:00
class PRID2011(VideoDataset):
    """PRID2011.

    Reference:
        Hirzer et al. Person Re-Identification by Descriptive and
        Discriminative Classification. SCIA 2011.

    URL: `<https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/>`_

    Dataset statistics:
        - identities: 200.
        - tracklets: 400.
        - cameras: 2.

    Expected layout under ``<root>/prid2011``:
        - ``splits_prid2011.json``: a list of splits, each a dict with
          ``'train'`` and ``'test'`` lists of person directory names.
        - ``prid_2011/multi_shot/cam_a/<person_dir>/*.png``
        - ``prid_2011/multi_shot/cam_b/<person_dir>/*.png``

    Args:
        root (str): directory containing the ``prid2011`` folder.
        split_id (int): index of the split to use; must satisfy
            ``0 <= split_id < number of splits``.
        **kwargs: forwarded to ``VideoDataset.__init__``.
    """
    dataset_dir = 'prid2011'
    dataset_url = None

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json')
        self.cam_a_dir = osp.join(
            self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a'
        )
        self.cam_b_dir = osp.join(
            self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b'
        )

        # Include the split file so a missing file is reported up front
        # rather than failing inside read_json.
        required_files = [
            self.dataset_dir, self.cam_a_dir, self.cam_b_dir, self.split_path
        ]
        self.check_before_run(required_files)

        splits = read_json(self.split_path)
        # Reject negative ids as well: Python would otherwise silently index
        # from the end of the list, contradicting the documented range.
        if not 0 <= split_id < len(splits):
            raise ValueError(
                'split_id exceeds range, received {}, but expected between 0 and {}'
                .format(split_id,
                        len(splits) - 1)
            )
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']

        # Train uses both cameras. Query uses cam_a only and gallery uses
        # cam_b only, both drawn from the test identities.
        train = self.process_dir(train_dirs, cam1=True, cam2=True)
        query = self.process_dir(test_dirs, cam1=True, cam2=False)
        gallery = self.process_dir(test_dirs, cam1=False, cam2=True)

        super(PRID2011, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dirnames, cam1=True, cam2=True):
        """Build tracklets for the given person directories.

        Args:
            dirnames (list): person directory names. A name's position in
                this list becomes its pid, so pids are relabelled 0..N-1 for
                each call.
            cam1 (bool): include tracklets from ``cam_a`` (camid 0).
            cam2 (bool): include tracklets from ``cam_b`` (camid 1).

        Returns:
            list: ``(img_paths, pid, camid)`` tuples, where ``img_paths`` is a
            tuple of ``.png`` paths sorted by path. For each directory, the
            cam_a tracklet (if requested) precedes the cam_b tracklet.

        Raises:
            RuntimeError: if a requested person directory has no ``.png``
                images.
        """
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        cameras = []
        if cam1:
            cameras.append((self.cam_a_dir, 0))
        if cam2:
            cameras.append((self.cam_b_dir, 1))

        tracklets = []
        for dirname in dirnames:
            pid = dirname2pid[dirname]
            for cam_dir, camid in cameras:
                person_dir = osp.join(cam_dir, dirname)
                # glob.glob order is filesystem-dependent, so sort for a
                # deterministic frame order. NOTE(review): assumes frame
                # filenames are zero-padded so that lexicographic order is
                # temporal order; confirm against the dataset.
                img_names = sorted(glob.glob(osp.join(person_dir, '*.png')))
                if not img_names:
                    # Explicit check: an assert would vanish under `python -O`.
                    raise RuntimeError(
                        'No .png images found in "{}"'.format(person_dir)
                    )
                tracklets.append((tuple(img_names), pid, camid))

        return tracklets