192 lines
7.0 KiB
Python
192 lines
7.0 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import os.path as osp
|
|
import numpy as np
|
|
import copy
|
|
|
|
|
|
class BaseDataset(object):
    """Base class of reid dataset.

    Subclasses build ``train``/``query``/``gallery`` lists and pass them to
    :meth:`init_attributes`. They must implement :meth:`extract_data_info`
    and :meth:`print_dataset_statistics`, and :meth:`combine_all` when
    ``combineall=True`` is to be supported.
    """

    def __init__(self, root):
        # Expand a leading "~" so user-relative dataset roots work.
        self.root = osp.expanduser(root)

    def check_before_run(self, required_files):
        """Check if required files exist before going deeper.

        Args:
            required_files (str or iterable of str): path(s) that must exist.

        Raises:
            RuntimeError: if any of the paths does not exist.
        """
        # A single path given as a bare string would otherwise be iterated
        # character by character, producing a misleading error.
        if isinstance(required_files, str):
            required_files = [required_files]

        for f in required_files:
            if not osp.exists(f):
                raise RuntimeError('"{}" is not found'.format(f))

    def extract_data_info(self, data):
        """Extract info from data list.

        Args:
            data (list): contains a list of (img_path, pid, camid)

        Returns:
            A sequence whose element 0 is the number of identities and whose
            element 2 is the number of cameras (see :meth:`get_num_pids` and
            :meth:`get_num_cams`).
        """
        raise NotImplementedError

    def get_num_pids(self, data):
        """Return the number of identities reported by extract_data_info."""
        return self.extract_data_info(data)[0]

    def get_num_cams(self, data):
        """Return the number of cameras reported by extract_data_info."""
        return self.extract_data_info(data)[2]

    def init_attributes(self, train, query, gallery, combineall=False, **kwargs):
        """Initialize class attributes.

        Args:
            train (list): contains a list of (img_path, pid, camid)
            query (list): contains a list of (img_path, pid, camid)
            gallery (list): contains a list of (img_path, pid, camid)
            combineall (bool): if set to True, combine all data for training, default is False

        Notes:
            This method has to be called (at the end) in each dataset class.
        """
        self._train = train
        self._query = query
        self._gallery = gallery
        # These must be set before combine_all() runs: implementations may
        # read self.num_train_pids to offset the labels they append.
        self._num_train_pids = self.get_num_pids(train)
        self._num_train_cams = self.get_num_cams(train)

        if combineall:
            self._train = self.combine_all(train, query, gallery)
            # Refresh both counters so they describe the combined train set.
            self._num_train_pids = self.get_num_pids(self._train)
            self._num_train_cams = self.get_num_cams(self._train)

    def combine_all(self, train, query, gallery):
        """Combine all data for training.

        Notes:
            1. In general, we assume that all query identities appear in gallery set.
            2. All pids in train have been relabeled (starts from 0)
            3. pid=0 (background) and pid=-1 (junk) are discarded.
            4. Camera views remain the same across train, query and gallery.
        """
        raise NotImplementedError

    @property
    def train(self):
        # train list containing (img_path, pid, camid)
        return self._train

    @property
    def query(self):
        # query list containing (img_path, pid, camid)
        return self._query

    @property
    def gallery(self):
        # gallery list containing (img_path, pid, camid)
        return self._gallery

    @property
    def num_train_pids(self):
        # number of train identities
        return self._num_train_pids

    @property
    def num_train_cams(self):
        # number of train camera views
        return self._num_train_cams

    def print_dataset_statistics(self, train=None, query=None, gallery=None):
        """Print a summary of the dataset; must be overridden.

        The optional arguments mirror the (train, query, gallery) signature
        used by the subclass overrides, so the call shape is the same on
        every dataset class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class BaseImageDataset(BaseDataset):
    """Base class of image-reid dataset"""

    def extract_data_info(self, data):
        """Count distinct identities, images and distinct cameras.

        Args:
            data (list): contains a list of (img_path, pid, camid)

        Returns:
            tuple: (num_pids, num_imgs, num_cams)
        """
        pids = set()
        cams = set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(data), len(cams)

    def combine_all(self, train, query, gallery):
        """Return a copy of train extended with relabeled query and gallery items.

        Entries with pid 0 or -1 are skipped. Every remaining query/gallery
        pid is mapped to a new label placed after the existing train labels,
        i.e. ``new_label = index + self.num_train_pids``.

        Args:
            train (list): contains a list of (img_path, pid, camid)
            query (list): contains a list of (img_path, pid, camid)
            gallery (list): contains a list of (img_path, pid, camid)

        Returns:
            list: new list of (img_path, pid, camid); ``train`` is not modified.
        """
        skipped_pids = (0, -1)
        combined = copy.deepcopy(train)

        # Build the label map from BOTH query and gallery so that a query
        # identity missing from the gallery does not raise a KeyError.
        eval_pids = set()
        for subset in (query, gallery):
            for _, pid, _ in subset:
                if pid not in skipped_pids:
                    eval_pids.add(pid)

        # Sort so the pid -> label assignment is reproducible.
        pid2label = {pid: i for i, pid in enumerate(sorted(eval_pids))}

        # New labels start right after the labels already used by train.
        offset = self.num_train_pids
        for subset in (query, gallery):
            for img_path, pid, camid in subset:
                if pid in skipped_pids:
                    continue
                combined.append((img_path, pid2label[pid] + offset, camid))

        return combined

    def print_dataset_statistics(self, train, query, gallery):
        """Print a table of identity/image/camera counts for each subset."""
        num_train_pids, num_train_imgs, num_train_cams = self.extract_data_info(train)
        num_query_pids, num_query_imgs, num_query_cams = self.extract_data_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.extract_data_info(gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print(' ----------------------------------------')
        print(' subset | # ids | # images | # cameras')
        print(' ----------------------------------------')
        print(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, num_train_imgs, num_train_cams))
        print(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, num_query_imgs, num_query_cams))
        print(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print(' ----------------------------------------')
|
|
|
|
|
|
class BaseVideoDataset(BaseDataset):
    """Base class of video-reid dataset"""

    def extract_data_info(self, data, return_tracklet_stats=False):
        """Count distinct identities, tracklets and distinct cameras.

        Args:
            data (list): contains a list of (img_paths, pid, camid), where
                ``len(img_paths)`` is taken as the tracklet length.
            return_tracklet_stats (bool): if True, also return the list of
                per-tracklet lengths.

        Returns:
            tuple: (num_pids, num_tracklets, num_cams), with the list of
            tracklet lengths appended as a fourth element when
            ``return_tracklet_stats`` is True.
        """
        pids = set()
        cams = set()
        tracklet_stats = []
        for img_paths, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
            tracklet_stats.append(len(img_paths))

        num_pids = len(pids)
        num_cams = len(cams)
        num_tracklets = len(data)
        if return_tracklet_stats:
            return num_pids, num_tracklets, num_cams, tracklet_stats
        return num_pids, num_tracklets, num_cams

    def print_dataset_statistics(self, train, query, gallery):
        """Print per-subset counts plus min/max/mean images per tracklet."""
        num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
            self.extract_data_info(train, return_tracklet_stats=True)

        num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
            self.extract_data_info(query, return_tracklet_stats=True)

        num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
            self.extract_data_info(gallery, return_tracklet_stats=True)

        tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
        if tracklet_stats:
            min_num = np.min(tracklet_stats)
            max_num = np.max(tracklet_stats)
            avg_num = np.mean(tracklet_stats)
        else:
            # np.min/np.max raise ValueError on an empty sequence; report
            # zeros so an empty dataset can still be summarized.
            min_num = 0
            max_num = 0
            avg_num = 0.0

        print('=> Loaded {}'.format(self.__class__.__name__))
        print(' -------------------------------------------')
        print(' subset | # ids | # tracklets | # cameras')
        print(' -------------------------------------------')
        print(' train | {:5d} | {:11d} | {:9d}'.format(num_train_pids, num_train_tracklets, num_train_cams))
        print(' query | {:5d} | {:11d} | {:9d}'.format(num_query_pids, num_query_tracklets, num_query_cams))
        print(' gallery | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
        print(' -------------------------------------------')
        print(' number of images per tracklet: {} ~ {}, average {:.2f}'.format(min_num, max_num, avg_num))
        print(' -------------------------------------------')