deep-person-reid/torchreid/datasets/prid450s.py

137 lines
4.6 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2018-07-02 17:17:14 +08:00
import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
from scipy.misc import imsave
2018-08-15 16:48:17 +08:00
from torchreid.utils.iotools import mkdir_if_missing, write_json, read_json
2018-11-05 06:59:46 +08:00
from .bases import BaseImageDataset
2018-07-02 17:17:14 +08:00
2018-11-05 06:59:46 +08:00
class PRID450S(BaseImageDataset):
    """PRID450S

    Reference:
        Roth et al. Mahalanobis Distance Learning for Person Re-Identification. PR 2014.

    URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/prid450s/

    Dataset statistics:
        - identities: 450.
        - images: 900.
        - cameras: 2.

    On first use the archive is downloaded and extracted into ``<root>/prid450s``,
    and ten random train/test splits are written to ``splits.json`` inside that
    directory. Later runs reuse both the extracted data and the split file.
    """
    dataset_dir = 'prid450s'

    def __init__(self, root='data', split_id=0, min_seq_len=0, verbose=True, **kwargs):
        """
        Args:
            root (str): directory that contains (or will contain) the ``prid450s`` folder.
            split_id (int): index of the split to load from ``splits.json``; must lie in
                ``[0, number of splits - 1]``.
            min_seq_len (int): not used by this dataset; accepted so the constructor
                signature stays unchanged for callers.
            verbose (bool): if True, print dataset statistics after loading.
            **kwargs: forwarded to ``self.init_attributes``.

        Raises:
            ValueError: if ``split_id`` is outside the range of available splits.
        """
        super(PRID450S, self).__init__(root)
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.dataset_url = 'https://files.icg.tugraz.at/f/8c709245bb/?raw=1'
        self.split_path = osp.join(self.dataset_dir, 'splits.json')
        self.cam_a_dir = osp.join(self.dataset_dir, 'cam_a')
        self.cam_b_dir = osp.join(self.dataset_dir, 'cam_b')

        self.download_data()

        required_files = [
            self.dataset_dir,
            self.cam_a_dir,
            self.cam_b_dir
        ]
        self.check_before_run(required_files)

        self.prepare_split()

        splits = read_json(self.split_path)
        # Reject negative ids as well: a negative index would silently pick a split
        # counted from the end instead of failing as the message promises.
        if not 0 <= split_id < len(splits):
            raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1))
        split = splits[split_id]

        # JSON stores the (img_path, pid, camid) records as lists; restore tuples.
        train = [tuple(item) for item in split['train']]
        query = [tuple(item) for item in split['query']]
        gallery = [tuple(item) for item in split['gallery']]

        self.init_attributes(train, query, gallery, **kwargs)

        if verbose:
            self.print_dataset_statistics(self.train, self.query, self.gallery)

    def download_data(self):
        """Download and extract the PRID450S archive into ``self.dataset_dir``.

        Skipped entirely when ``self.dataset_dir`` already exists. The directory is
        created before downloading, so if the download fails, delete the directory
        before retrying; otherwise the next run will skip the download.
        """
        if osp.exists(self.dataset_dir):
            return

        # ``urllib.urlretrieve`` only exists on Python 2; on Python 3 the function
        # lives in ``urllib.request``.
        try:
            from urllib.request import urlretrieve
        except ImportError:
            from urllib import urlretrieve

        print('Creating directory {}'.format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, 'prid_450s.zip')

        print('Downloading PRID450S dataset')
        urlretrieve(self.dataset_url, fpath)

        print('Extracting files')
        # The context manager closes the archive even if extraction raises.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    @staticmethod
    def _parse_img_idx(img_path):
        """Return the integer index encoded in an image file name such as ``img_0001.png``."""
        img_name = osp.basename(img_path)
        return int(img_name.split('_')[1].split('.')[0])

    def prepare_split(self):
        """Generate ten random identity splits and save them to ``self.split_path``.

        Does nothing if the split file already exists. Half of the identities
        (rounded down) go to training, and the rest go to testing. Each split is a dict
        with ``train``/``query``/``gallery`` lists of ``(img_path, pid, camid)`` records
        (camid 0 for ``cam_a``, 1 for ``cam_b``) plus per-subset identity counts.
        Training pids are relabeled to ``0..num_train_pids-1``. Test records keep the
        index parsed from the file name, and query and gallery share the same list.

        Raises:
            RuntimeError: if no images are found, or if the two camera folders do not
                contain the same identities.
        """
        if osp.exists(self.split_path):
            return

        cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, 'img_*.png')))
        cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, 'img_*.png')))

        # Parse identities once rather than once per split; index 0 -> cam_a, 1 -> cam_b.
        cam_records = [
            [(img_path, self._parse_img_idx(img_path)) for img_path in cam_a_imgs],
            [(img_path, self._parse_img_idx(img_path)) for img_path in cam_b_imgs],
        ]
        cam_a_pids = sorted(set(pid for _, pid in cam_records[0]))
        cam_b_pids = sorted(set(pid for _, pid in cam_records[1]))

        if not cam_a_pids:
            raise RuntimeError('No images found in {}'.format(self.cam_a_dir))
        if len(cam_a_imgs) != len(cam_b_imgs) or cam_a_pids != cam_b_pids:
            raise RuntimeError(
                'cam_a and cam_b must contain the same identities, '
                'found {} and {} images'.format(len(cam_a_imgs), len(cam_b_imgs)))

        all_pids = cam_a_pids
        num_pids = len(all_pids)
        num_train_pids = num_pids // 2

        splits = []
        for _ in range(10):
            # Shuffle positions 0..num_pids-1 and map them onto the identity indices
            # actually found in the file names, which need not start at 0 or be
            # contiguous. Matching raw positions against file-name indices would
            # drop identities when the names are 1-based.
            order = np.arange(num_pids)
            np.random.shuffle(order)
            train_pids = sorted(all_pids[i] for i in order[:num_train_pids])
            # A dict gives O(1) membership tests and the relabeling in one structure.
            pid2label = {pid: label for label, pid in enumerate(train_pids)}

            train, test = [], []
            for camid, records in enumerate(cam_records):
                for img_path, pid in records:
                    if pid in pid2label:
                        train.append((img_path, pid2label[pid], camid))
                    else:
                        test.append((img_path, pid, camid))

            split = {
                'train': train,
                'query': test,
                'gallery': test,
                'num_train_pids': num_train_pids,
                'num_query_pids': num_pids - num_train_pids,
                'num_gallery_pids': num_pids - num_train_pids
            }
            splits.append(split)

        print('Totally {} splits are created'.format(len(splits)))
        write_json(splits, self.split_path)
        print('Split file saved to {}'.format(self.split_path))