deep-person-reid/torchreid/datasets/prid2011.py

94 lines
3.4 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2018-07-02 17:17:14 +08:00
import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
from scipy.misc import imsave
2018-08-15 16:48:17 +08:00
from torchreid.utils.iotools import mkdir_if_missing, write_json, read_json
2018-11-05 06:59:46 +08:00
from .bases import BaseVideoDataset
2018-07-02 17:17:14 +08:00
2018-11-05 06:59:46 +08:00
class PRID2011(BaseVideoDataset):
    """
    PRID2011

    Reference:
    Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011.

    URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/

    Dataset statistics:
    # identities: 200
    # tracklets: 400
    # cameras: 2

    Expected layout under ``root``::

        prid2011/splits_prid2011.json
        prid2011/prid_2011/multi_shot/cam_a/<person_dir>/*.png
        prid2011/prid_2011/multi_shot/cam_b/<person_dir>/*.png

    Args:
        root (str): directory containing the ``prid2011`` dataset folder.
        split_id (int): index into the list stored in ``splits_prid2011.json``;
            must satisfy ``0 <= split_id < len(splits)``.
        min_seq_len (int): accepted for interface compatibility; not used by
            this class.
        verbose (bool): if True, print split sizes and dataset statistics.
        **kwargs: accepted and ignored.

    Each of ``self.train`` / ``self.query`` / ``self.gallery`` is a list of
    ``(img_paths, pid, camid)`` tuples, where ``img_paths`` is a tuple of
    image file paths and ``camid`` is 0 for cam_a and 1 for cam_b.
    """
    dataset_dir = 'prid2011'

    def __init__(self, root='data', split_id=0, min_seq_len=0, verbose=True, **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json')
        self.cam_a_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a')
        self.cam_b_path = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b')
        self._check_before_run()

        splits = read_json(self.split_path)
        # Reject negative indices too: Python indexing would otherwise silently
        # wrap them around to a split counted from the end.
        if split_id < 0 or split_id >= len(splits):
            raise ValueError(
                "split_id exceeds range, received {}, but expected between 0 and {}".format(
                    split_id, len(splits) - 1))
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']
        if verbose:
            print("# train identities: {}, # test identities {}".format(len(train_dirs), len(test_dirs)))

        # Training uses both camera views of each train identity; for evaluation,
        # cam_a tracklets of the test identities are queries and cam_b tracklets
        # are the gallery.
        train = self._process_data(train_dirs, cam1=True, cam2=True)
        query = self._process_data(test_dirs, cam1=True, cam2=False)
        gallery = self._process_data(test_dirs, cam1=False, cam2=True)

        if verbose:
            print("=> PRID2011 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
        self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
        self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)

    def _check_before_run(self):
        """Check that every required file/directory exists before going deeper.

        Raises:
            RuntimeError: if the dataset dir, the split JSON, or either camera
                directory is missing.
        """
        required = (self.dataset_dir, self.split_path, self.cam_a_path, self.cam_b_path)
        for path in required:
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_data(self, dirnames, cam1=True, cam2=True):
        """Build ``(img_paths, pid, camid)`` tracklets for the given person dirs.

        Args:
            dirnames (list): person directory names (same name under cam_a and cam_b).
            cam1 (bool): include the cam_a tracklet (camid 0) of each person.
            cam2 (bool): include the cam_b tracklet (camid 1) of each person.

        Returns:
            list: tracklets ordered by ``dirnames``; per person, cam_a precedes cam_b.

        Raises:
            RuntimeError: if a requested person directory contains no ``*.png`` files.
        """
        # pids are relabeled to 0..len(dirnames)-1 from the list position, so two
        # calls with the same dirnames (query/gallery) produce matching pids.
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        cameras = []
        if cam1:
            cameras.append((self.cam_a_path, 0))
        if cam2:
            cameras.append((self.cam_b_path, 1))

        tracklets = []
        for dirname in dirnames:
            pid = dirname2pid[dirname]
            for cam_path, camid in cameras:
                person_dir = osp.join(cam_path, dirname)
                # glob.glob returns files in arbitrary order; sort so the frame
                # order inside a tracklet is deterministic (lexicographic by path;
                # presumably zero-padded frame numbers — TODO confirm).
                img_names = sorted(glob.glob(osp.join(person_dir, '*.png')))
                # Explicit check instead of `assert`, which is stripped under -O.
                if not img_names:
                    raise RuntimeError("No '*.png' images found in '{}'".format(person_dir))
                tracklets.append((tuple(img_names), pid, camid))
        return tracklets