# Mirror of https://github.com/JDAI-CV/fast-reid.git (184 lines, 5.8 KiB, Python)
# encoding: utf-8
|
|
"""
|
|
@author: sherlock
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import copy
|
|
import logging
|
|
import os
|
|
|
|
from tabulate import tabulate
|
|
from termcolor import colored
|
|
|
|
# Module-level logger; the dataset summary tables are emitted through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Each subset may be given either as a list or as a zero-argument callable
    returning the list; a callable is invoked the first time the matching
    property (``train``, ``query`` or ``gallery``) is read and its result is
    cached.

    Args:
        train (list or Callable): contains tuples of (img_path(s), pid, camid).
        query (list or Callable): contains tuples of (img_path(s), pid, camid).
        gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``mode`` is not 'train', 'query' or 'gallery'.
    """

    # Useless person IDs (e.g. background, false detections). Samples carrying
    # one of these IDs are dropped by ``combine_all``. Subclasses may override.
    _junk_pids = []

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self._train = train
        self._query = query
        self._gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        # Merging must happen before ``self.data`` is bound so that the
        # 'train' mode sees the combined list.
        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

    @property
    def train(self):
        """Training samples; a callable source is resolved once and cached."""
        if callable(self._train):
            self._train = self._train()
        return self._train

    @property
    def query(self):
        """Query samples; a callable source is resolved once and cached."""
        if callable(self._query):
            self._query = self._query()
        return self._query

    @property
    def gallery(self):
        """Gallery samples; a callable source is resolved once and cached."""
        if callable(self._gallery):
            self._gallery = self._gallery()
        return self._gallery

    def __getitem__(self, index):
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples in the subset selected by ``mode``."""
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from ``0``, so ``0 + dataset`` simply yields the
        dataset. Any other left operand is delegated to ``__add__``. This
        base class defines no ``__add__``; when a subclass does not provide
        one either, ``NotImplemented`` is returned so that Python raises its
        regular ``TypeError`` instead of an ``AttributeError`` from the
        missing method.
        """
        if other == 0:
            return self
        add = getattr(self, '__add__', None)
        if add is None:
            return NotImplemented
        return add(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid)

        Returns:
            tuple: (number of distinct pids, number of distinct camids).
        """
        pids = set()
        cams = set()
        for info in data:
            pids.add(info[1])
            cams.add(info[2])
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics (no-op in the base class)."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Query and gallery samples whose pid is listed in ``_junk_pids`` are
        skipped. The remaining ones get their pid and camid rewritten to the
        strings ``"<dataset_name>_test_<pid>"`` and
        ``"<dataset_name>_test_<camid>"`` (``dataset_name`` falls back to
        ``"Unknown"`` when the attribute is absent). The result replaces the
        training list; the original training list is deep-copied, not
        mutated.
        """
        combined = copy.deepcopy(self.train)
        # The prefix is identical for every sample, so build it only once.
        prefix = getattr(self, "dataset_name", "Unknown") + "_test_"

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                combined.append((img_path, prefix + str(pid), prefix + str(camid)))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self._train = combined

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any of the given paths does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))
|
|
|
|
|
|
class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.

    ``__getitem__`` returns an image given index.
    It will return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width). As a result,
    data in each batch has shape (batch_size, channel, height, width).
    """

    def _log_table(self, rows):
        """Render ``rows`` as a pipe-format table and log it, coloured cyan."""
        rendered = tabulate(
            rows,
            headers=['subset', '# ids', '# images', '# cameras'],
            tablefmt="pipe",
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(rendered, "cyan"))

    def show_train(self):
        """Log the id / image / camera counts of the training subset."""
        samples = self.train
        pid_count, cam_count = self.parse_data(samples)
        self._log_table([['train', pid_count, len(samples), cam_count]])

    def show_test(self):
        """Log the id / image / camera counts of the query and gallery subsets."""
        stats = {}
        # Parse both subsets first (query, then gallery), as the table rows
        # are only assembled once all counts are known.
        for subset_name in ('query', 'gallery'):
            stats[subset_name] = self.parse_data(getattr(self, subset_name))

        rows = []
        for subset_name in ('query', 'gallery'):
            pid_count, cam_count = stats[subset_name]
            rows.append([subset_name, pid_count, len(getattr(self, subset_name)), cam_count])

        self._log_table(rows)
|