92 lines
3.7 KiB
Python
92 lines
3.7 KiB
Python
|
from __future__ import absolute_import
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
class BaseDataset(object):
    """
    Base class of reid dataset.

    Provides helpers that summarize a data split given as an iterable of
    3-tuples ``(item, pid, camid)``; subclasses implement
    ``print_dataset_statistics``.
    """

    def get_imagedata_info(self, data):
        """Count distinct person ids, records and distinct camera ids in a split.

        Args:
            data: iterable of ``(item, pid, camid)`` 3-tuples. The first element
                is ignored (presumably the image path -- confirm with callers).
                Any iterable is accepted, including one-shot iterators, because
                the records are counted during the single pass instead of
                calling ``len(data)`` afterwards.

        Returns:
            tuple: ``(num_pids, num_imgs, num_cams)``.
        """
        pids, cams = set(), set()
        num_imgs = 0
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
            num_imgs += 1
        return len(pids), num_imgs, len(cams)

    def get_videodata_info(self, data, return_tracklet_info=False):
        """Count distinct person ids, tracklets and distinct camera ids in a split.

        Args:
            data: iterable of ``(img_paths, pid, camid)`` 3-tuples, where
                ``img_paths`` is a sized collection (its ``len`` is taken as the
                tracklet length). One-shot iterators are accepted.
            return_tracklet_info: if True, also return the list of per-tracklet
                lengths, in input order.

        Returns:
            tuple: ``(num_pids, num_tracklets, num_cams)``, or
            ``(num_pids, num_tracklets, num_cams, tracklet_info)`` when
            ``return_tracklet_info`` is True.
        """
        pids, cams = set(), set()
        tracklet_info = []
        for img_paths, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
            tracklet_info.append(len(img_paths))
        # One length entry per record, so this equals the number of tracklets
        # without requiring ``data`` itself to support len().
        num_tracklets = len(tracklet_info)
        if return_tracklet_info:
            return len(pids), num_tracklets, len(cams), tracklet_info
        return len(pids), num_tracklets, len(cams)

    def print_dataset_statistics(self, *args, **kwargs):
        """Print a summary of the dataset splits; must be overridden.

        Accepts arbitrary arguments so that calling it with the split arguments
        that subclasses take (e.g. ``train, query, gallery``) raises
        NotImplementedError rather than a signature TypeError.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset.

    Prints a per-split table of identity, image and camera counts.
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print id / image / camera counts for the train, query and gallery splits.

        Args:
            train, query, gallery: splits in the form accepted by
                ``self.get_imagedata_info``.
        """
        # All three summaries are computed before anything is printed.
        summaries = [
            ("train", self.get_imagedata_info(train)),
            ("query", self.get_imagedata_info(query)),
            ("gallery", self.get_imagedata_info(gallery)),
        ]
        separator = " ----------------------------------------"
        row_template = " {} | {:5d} | {:8d} | {:9d}"

        print("Dataset statistics:")
        print(separator)
        print(" subset | # ids | # images | # cameras")
        print(separator)
        for subset_name, (pid_count, img_count, cam_count) in summaries:
            print(row_template.format(subset_name, pid_count, img_count, cam_count))
        print(separator)
|
||
|
|
||
|
|
||
|
class BaseVideoDataset(BaseDataset):
    """
    Base class of video reid dataset.

    Prints a per-split table of identity, tracklet and camera counts, followed
    by the min / max / mean number of images per tracklet over all splits.
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print id / tracklet / camera counts and tracklet-length statistics.

        Args:
            train, query, gallery: splits in the form accepted by
                ``self.get_videodata_info``.

        When every split is empty, the tracklet-length line reports zeros
        instead of failing inside numpy.
        """
        num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_info = \
            self.get_videodata_info(train, return_tracklet_info=True)

        num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_info = \
            self.get_videodata_info(query, return_tracklet_info=True)

        num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_info = \
            self.get_videodata_info(gallery, return_tracklet_info=True)

        tracklet_info = train_tracklet_info + query_tracklet_info + gallery_tracklet_info
        if tracklet_info:
            min_num = np.min(tracklet_info)
            max_num = np.max(tracklet_info)
            avg_num = np.mean(tracklet_info)
        else:
            # np.min / np.max raise ValueError on an empty sequence (a min/max
            # reduction has no identity), so fall back to zeros.
            min_num, max_num, avg_num = 0, 0, 0.0

        print("Dataset statistics:")
        print(" -------------------------------------------")
        print(" subset | # ids | # tracklets | # cameras")
        print(" -------------------------------------------")
        print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
        print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
        print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
        print(" -------------------------------------------")
        print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
        print(" -------------------------------------------")
|