diff --git a/data_manager.py b/data_manager.py
index cc728d6..02e8914 100644
--- a/data_manager.py
+++ b/data_manager.py
@@ -1,9 +1,11 @@
-from __future__ import absolute_import
+from __future__ import print_function, absolute_import
 import os
 import glob
 import re
 import sys
 import os.path as osp
+from scipy.io import loadmat
+import numpy as np
 
 """Dataset classes"""
 
@@ -13,10 +15,10 @@ class Market1501(object):
 
     Reference:
     Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
-    ==========================
+
     Dataset statistics:
     # identities: 1501 (+1 for background)
-    # images: 12936 (train) + 3368 (query) + 15913 (gallery) =
+    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
     """
     root = './data/market1501'
     train_dir = osp.join(root, 'bounding_box_train')
@@ -56,7 +58,6 @@ class Market1501(object):
         self.num_gallery_pids = num_gallery_pids
 
     def _process_dir(self, dir_path, relabel=False):
-        print("Processing directory '{}'".format(dir_path))
         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
         pattern = re.compile(r'([-\d]+)_c(\d)')
 
@@ -86,10 +87,128 @@ class Market1501(object):
             print("Error: '{}' is not available.".format(dir_path))
             sys.exit()
 
+class Mars(object):
+    """
+    MARS
+
+    Reference:
+    Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016.
+
+    Dataset statistics:
+    # identities: 1261
+    # tracklets: 8298 (train) + 1980 (query) + 9330 (gallery)
+
+    Args:
+        min_seq_len (int): tracklet with length shorter than this value will be discarded.
+    """
+    root = './data/mars'
+    train_name_path = osp.join(root, 'info/train_name.txt')
+    test_name_path = osp.join(root, 'info/test_name.txt')
+    track_train_info_path = osp.join(root, 'info/tracks_train_info.mat')
+    track_test_info_path = osp.join(root, 'info/tracks_test_info.mat')
+    query_IDX_path = osp.join(root, 'info/query_IDX.mat')
+
+    def __init__(self, min_seq_len=0):
+        # prepare meta data
+        train_names = self._get_names(self.train_name_path)
+        test_names = self._get_names(self.test_name_path)
+        track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4)
+        track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4)
+        query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,)
+        query_IDX -= 1 # index from 0
+        track_query = track_test[query_IDX,:]
+        gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX]
+        track_gallery = track_test[gallery_IDX,:]
+
+        train, num_train_tracklets, num_train_pids, num_train_imgs = \
+          self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len)
+
+        query, num_query_tracklets, num_query_pids, num_query_imgs = \
+          self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
+
+        gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \
+          self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len)
+
+        num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs
+        min_num = np.min(num_imgs_per_tracklet)
+        max_num = np.max(num_imgs_per_tracklet)
+        avg_num = np.mean(num_imgs_per_tracklet)
+
+        num_total_pids = num_train_pids + num_query_pids
+        num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets
+
+        print("=> MARS loaded")
+        print("Dataset statistics:")
+        print("  ------------------------------")
+        print("  subset   | # ids | # tracklets")
+        print("  ------------------------------")
+        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
+        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
+        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
+        print("  ------------------------------")
+        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
+        print("  number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
+        print("  ------------------------------")
+
+        self.train = train
+        self.query = query
+        self.gallery = gallery
+
+        self.num_train_pids = num_train_pids
+        self.num_query_pids = num_query_pids
+        self.num_gallery_pids = num_gallery_pids
+
+    def _get_names(self, fpath):
+        names = []
+        with open(fpath, 'r') as f:
+            for line in f:
+                new_line = line.rstrip()
+                names.append(new_line)
+        return names
+
+    def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0):
+        assert home_dir in ['bbox_train', 'bbox_test']
+        num_tracklets = meta_data.shape[0]
+        pid_list = list(set(meta_data[:,2].tolist()))
+        num_pids = len(pid_list)
+
+        if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)}
+        tracklets = []
+        num_imgs_per_tracklet = []
+
+        for tracklet_idx in range(num_tracklets):
+            data = meta_data[tracklet_idx,...]
+            start_index, end_index, pid, camid = data
+            if pid == -1: continue # junk images are just ignored
+            assert 1 <= camid <= 6
+            if relabel: pid = pid2label[pid]
+            camid -= 1 # index starts from 0
+            img_names = names[start_index-1:end_index]
+
+            # make sure image names correspond to the same person
+            pnames = [img_name[:4] for img_name in img_names]
+            assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images"
+
+            # make sure all images are captured under the same camera
+            camnames = [img_name[5] for img_name in img_names]
+            assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
+
+            # append image names with directory information
+            img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names]
+            if len(img_paths) >= min_seq_len:
+                img_paths = tuple(img_paths)
+                tracklets.append((img_paths, pid, camid))
+                num_imgs_per_tracklet.append(len(img_paths))
+
+        num_tracklets = len(tracklets)
+
+        return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet
+
 """Create dataset"""
 
 __factory = {
     'market1501': Market1501,
+    'mars': Mars,
 }
 
 def get_names():
@@ -101,4 +220,12 @@ def init_dataset(name, *args, **kwargs):
     return __factory[name](*args, **kwargs)
 
 if __name__ == '__main__':
-    dataset = Market1501()
\ No newline at end of file
+    #dataset = Market1501()
+    dataset = Mars()
+
+
+
+
+
+
+
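A minimal usage sketch of the Mars loader and the 'mars' factory entry added by this patch (illustrative only, not part of the diff above): it assumes the MARS annotations and frames are already unpacked under ./data/mars/ in the layout the class-level paths expect (info/train_name.txt, info/test_name.txt, info/tracks_train_info.mat, info/tracks_test_info.mat, info/query_IDX.mat, bbox_train/, bbox_test/), and that this module is importable as data_manager; scipy and numpy must be installed for the .mat loading to work.

# Usage sketch: exercises the 'mars' entry registered in __factory by this patch.
# Assumes ./data/mars/ is populated as described in the Mars docstring.
from data_manager import get_names, init_dataset

print(get_names())  # should list both 'market1501' and 'mars'

# min_seq_len is forwarded to Mars.__init__; tracklets shorter than
# 4 frames are discarded here (the default of 0 keeps everything).
dataset = init_dataset('mars', min_seq_len=4)

print(dataset.num_train_pids)              # number of distinct training identities
img_paths, pid, camid = dataset.train[0]   # each tracklet is (tuple of image paths, pid, camid)
print(len(img_paths), pid, camid)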