"""Base classes that summarise re-identification (reid) datasets."""
from __future__ import absolute_import
from __future__ import print_function

import numpy as np


class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        """Count the distinct ids, records and cameras in an image subset.

        Args:
            data: sized iterable of three-item records. The second item is
                used as the identity id and the third as the camera id; the
                first item is ignored here.

        Returns:
            Tuple ``(num_pids, num_imgs, num_cams)`` where ``num_imgs`` is
            ``len(data)``.
        """
        pids, cams = set(), set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(data), len(cams)

    def get_videodata_info(self, data, return_tracklet_info=False):
        """Count the distinct ids, tracklets and cameras in a video subset.

        Args:
            data: sized iterable of ``(img_paths, pid, camid)`` records, where
                ``img_paths`` is a sized collection (only its length is used).
            return_tracklet_info: when true, also return the per-record
                lengths of ``img_paths``.

        Returns:
            ``(num_pids, num_tracklets, num_cams)``, followed by the list of
            tracklet lengths (in input order) when ``return_tracklet_info``
            is true.
        """
        pids, cams, tracklet_info = set(), set(), []
        for img_paths, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
            tracklet_info.append(len(img_paths))

        num_pids, num_tracklets, num_cams = len(pids), len(data), len(cams)
        if return_tracklet_info:
            return num_pids, num_tracklets, num_cams, tracklet_info
        return num_pids, num_tracklets, num_cams

    def print_dataset_statistics(self, train=None, query=None, gallery=None):
        """Print a summary of the three subsets; subclasses must override.

        The parameters mirror the overrides in this module. They default to
        ``None`` so that a call without arguments still reaches the
        ``NotImplementedError`` below.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError


class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print id / image / camera counts for the three subsets.

        Args:
            train, query, gallery: subsets in the record format accepted by
                ``get_imagedata_info``.
        """
        num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
        num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)

        rule = " ----------------------------------------"
        print("Dataset statistics:")
        print(rule)
        print(" subset | # ids | # images | # cameras")
        print(rule)
        print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
        print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
        print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print(rule)


class BaseVideoDataset(BaseDataset):
    """
    Base class of video reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print id / tracklet / camera counts and tracklet-length statistics.

        Args:
            train, query, gallery: subsets in the record format accepted by
                ``get_videodata_info``.
        """
        num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_info = \
            self.get_videodata_info(train, return_tracklet_info=True)
        num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_info = \
            self.get_videodata_info(query, return_tracklet_info=True)
        num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_info = \
            self.get_videodata_info(gallery, return_tracklet_info=True)

        tracklet_info = train_tracklet_info + query_tracklet_info + gallery_tracklet_info
        if tracklet_info:
            min_num = np.min(tracklet_info)
            max_num = np.max(tracklet_info)
            avg_num = np.mean(tracklet_info)
        else:
            # np.min / np.max raise ValueError on an empty sequence; report
            # zeros instead of crashing when every subset is empty.
            min_num, max_num, avg_num = 0, 0, 0.0

        rule = " -------------------------------------------"
        print("Dataset statistics:")
        print(rule)
        print(" subset | # ids | # tracklets | # cameras")
        print(rule)
        print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
        print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
        print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
        print(rule)
        print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
        print(rule)