deep-person-reid/torchreid/datasets/mars.py

129 lines
5.5 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
from scipy.misc import imsave
from .bases import BaseVideoDataset
class Mars(BaseVideoDataset):
    """
    MARS

    Reference:
    Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.

    URL: http://www.liangzheng.com.cn/Project/project_mars.html

    Dataset statistics:
    # identities: 1261
    # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery)
    # cameras: 6

    Args:
        root (str): directory that contains the ``mars`` dataset folder.
        min_seq_len (int): tracklets with fewer images than this are dropped.
        verbose (bool): if True, print a summary once the splits are built.
        **kwargs: accepted for interface compatibility; not used here.

    Attributes set by ``__init__``:
        train, query, gallery: lists of ``(img_paths, pid, camid)`` tuples,
            where ``img_paths`` is a tuple of image file paths.
        num_{train,query,gallery}_pids / num_{train,query,gallery}_cams:
            values returned by ``self.get_videodata_info`` for each split.
    """
    dataset_dir = 'mars'

    def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt')
        self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt')
        self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat')
        self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat')
        self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat')

        self._check_before_run()

        # prepare meta data
        train_names = self._get_names(self.train_name_path)
        test_names = self._get_names(self.test_name_path)
        track_train = loadmat(self.track_train_info_path)['track_train_info']  # numpy.ndarray (8298, 4)
        track_test = loadmat(self.track_test_info_path)['track_test_info']  # numpy.ndarray (12180, 4)
        query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze()  # numpy.ndarray (1980,)
        query_IDX -= 1  # index from 0
        track_query = track_test[query_IDX, :]

        # Gallery = every test tracklet that is not a query tracklet.
        # Build the set once: testing membership against the ndarray directly
        # rescans the whole array for every candidate row.
        query_set = set(query_IDX.ravel().tolist())
        gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_set]
        track_gallery = track_test[gallery_IDX, :]

        train = self._process_data(
            train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len
        )
        query = self._process_data(
            test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len
        )
        gallery = self._process_data(
            test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len
        )

        if verbose:
            print("=> MARS loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
        self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
        self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper.

        Raises:
            RuntimeError: if the dataset directory or any annotation file is
                missing. Paths are checked in the order listed below and the
                first missing one is reported.
        """
        required_paths = (
            self.dataset_dir,
            self.train_name_path,
            self.test_name_path,
            self.track_train_info_path,
            self.track_test_info_path,
            self.query_IDX_path,
        )
        for path in required_paths:
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _get_names(self, fpath):
        """Return the lines of ``fpath`` as a list, trailing whitespace stripped."""
        with open(fpath, 'r') as f:
            return [line.rstrip() for line in f]

    def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
        """Turn tracklet meta data into a list of ``(img_paths, pid, camid)``.

        Args:
            names (list): image file names; each meta row selects a slice of it.
            meta_data (numpy.ndarray): 2-D array with one row per tracklet,
                unpacked as ``(start_index, end_index, pid, camid)``. The
                start/end indices are used as a 1-based inclusive range into
                ``names``; ``camid`` must lie in ``[1, 6]``.
            home_dir (str): ``'bbox_train'`` or ``'bbox_test'``; the
                sub-directory of the dataset that holds the images.
            relabel (bool): if True, map the pids found in ``meta_data`` onto
                consecutive labels starting at 0.
            min_seq_len (int): tracklets with fewer images are skipped.

        Returns:
            list: one ``(img_paths, pid, camid)`` tuple per kept tracklet,
            with ``img_paths`` a tuple of paths and ``camid`` 0-based.

        Raises:
            AssertionError: on an unknown ``home_dir``, an out-of-range camera
                id, or a tracklet whose file names mix persons or cameras.
        """
        assert home_dir in ['bbox_train', 'bbox_test']

        if relabel:
            pid_list = list(set(meta_data[:, 2].tolist()))
            pid2label = {pid: label for label, pid in enumerate(pid_list)}

        tracklets = []
        for tracklet_idx in range(meta_data.shape[0]):
            start_index, end_index, pid, camid = meta_data[tracklet_idx, ...]
            if pid == -1:
                continue  # junk images are just ignored
            assert 1 <= camid <= 6
            if relabel:
                pid = pid2label[pid]
            camid -= 1  # index starts from 0
            img_names = names[start_index - 1:end_index]

            # make sure image names correspond to the same person
            # (the first four characters of a file name are compared)
            pnames = [img_name[:4] for img_name in img_names]
            assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"

            # make sure all images are captured under the same camera
            # (the character at position 5 of a file name is compared)
            camnames = [img_name[5] for img_name in img_names]
            assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"

            # append image names with directory information
            img_paths = [
                osp.join(self.dataset_dir, home_dir, img_name[:4], img_name)
                for img_name in img_names
            ]
            if len(img_paths) >= min_seq_len:
                tracklets.append((tuple(img_paths), pid, camid))

        return tracklets