fast-reid/fastreid/data/datasets/bases.py

171 lines
5.6 KiB
Python

# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
import os
from tabulate import tabulate
from termcolor import colored
# Module-level logger; the dataset summary helpers in this file report through it.
logger = logging.getLogger(__name__)
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information. Only stored on the instance;
            nothing in this class reads it.
        **kwargs: accepted and ignored so subclasses can forward extra
            options without filtering them.

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'query', 'gallery'.
    """

    _junk_pids = []  # contains useless person IDs, e.g. background, false detections

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)

        # Merging happens before ``self.data`` is chosen so that the 'train'
        # mode exposes the combined list.
        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

    def __getitem__(self, index):
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples in the split selected by ``mode``."""
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from ``0``, so ``0 + dataset`` simply yields the
        dataset. Any other left operand is delegated to ``__add__``, which
        this base class does not define. When no ``__add__`` is available
        ``NotImplemented`` is returned so Python raises its regular
        ``TypeError`` for unsupported operands instead of an opaque
        ``AttributeError``.
        """
        if other == 0:
            return self
        add = getattr(self, '__add__', None)
        if add is None:
            return NotImplemented
        return add(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid)

        Returns:
            tuple: ``(num_pids, num_cams)``, the counts of distinct values
            found at positions 1 and 2 of the entries.
        """
        pids = set()
        cams = set()
        for info in data:
            pids.add(info[1])
            cams.add(info[2])
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics (no-op in the base class)."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Train entries are deep-copied unchanged. Query and gallery entries
        whose pid is listed in ``_junk_pids`` are dropped; the remaining
        ones get their pid and camid rewritten as
        ``"<dataset_name>_<id>"`` before being appended.

        ``self.dataset_name`` is not defined by this base class, so it must
        be provided by the subclass (or instance) before this method runs.

        Both ``num_train_pids`` and ``num_train_cams`` are recomputed from
        the merged list.
        """
        combined = copy.deepcopy(self.train)

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        # The merged samples can introduce new cameras as well as new
        # identities, so refresh both counters in a single pass.
        self.num_train_pids, self.num_train_cams = self.parse_data(self.train)

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: for the first path that does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))
class ImageDataset(Dataset):
    """Base class for image datasets; concrete image datasets subclass it.

    On top of ``Dataset`` it only adds two logging helpers that summarise
    the train split and the query/gallery splits as a table.

    NOTE(review): the upstream description states that ``__getitem__``
    returns ``img``, ``pid``, ``camid`` and ``img_path`` with ``img`` of
    shape (channel, height, width), giving batches of shape
    (batch_size, channel, height, width). No ``__getitem__`` override is
    visible in this class, so confirm that against the code that actually
    loads the images.
    """

    def __init__(self, train, query, gallery, **kwargs):
        super().__init__(train, query, gallery, **kwargs)

    def _format_summary(self, rows):
        """Build the log message for ``rows`` of [subset, ids, images, cameras]."""
        rendered = tabulate(
            rows,
            tablefmt="pipe",
            headers=['subset', '# ids', '# images', '# cameras'],
            numalign="left",
        )
        return f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(rendered, "cyan")

    def show_train(self):
        """Log the identity, image and camera counts of the train split."""
        id_count, cam_count = self.parse_data(self.train)
        rows = [['train', id_count, len(self.train), cam_count]]
        logger.info(self._format_summary(rows))

    def show_test(self):
        """Log the identity, image and camera counts of query and gallery."""
        rows = []
        for subset, samples in (('query', self.query), ('gallery', self.gallery)):
            id_count, cam_count = self.parse_data(samples)
            rows.append([subset, id_count, len(samples), cam_count])
        logger.info(self._format_summary(rows))