fast-reid/fastreid/data/datasets/bases.py

184 lines
5.8 KiB
Python

# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
import os
from tabulate import tabulate
from termcolor import colored
# Module-level logger; the dataset classes below use it to report split statistics.
logger = logging.getLogger(__name__)
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Each split may be passed either as a list or as a zero-argument callable
    that returns the list. Callables are invoked lazily the first time the
    matching ``train`` / ``query`` / ``gallery`` property is read, and the
    result replaces the callable so it is only invoked once.

    Args:
        train (list or Callable): contains tuples of (img_path(s), pid, camid).
        query (list or Callable): contains tuples of (img_path(s), pid, camid).
        gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
        transform: transform function. Stored on the instance; not applied by
            this base class.
        mode (str): 'train', 'query' or 'gallery'. Selects which split is
            exposed as ``self.data`` (and therefore measured by ``len()``).
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information. Stored on the instance; not read by
            this base class.
        **kwargs: accepted and ignored, so subclasses can forward extra options.

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'query' or 'gallery'.
    """

    # Contains useless person IDs, e.g. background, false detections.
    # Samples with these IDs are skipped by ``combine_all``.
    _junk_pids = []

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        # Validate the mode up front: resolving lazy splits and combining them
        # can be expensive, so an invalid mode should fail before any of that.
        if mode not in ('train', 'query', 'gallery'):
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(mode))

        self._train = train
        self._query = query
        self._gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        if self.combineall:
            self.combine_all()

        # ``mode`` names one of the split properties; reading it resolves a
        # lazy (callable) split if necessary.
        self.data = getattr(self, self.mode)

    @property
    def train(self):
        """Training split; a callable source is resolved and cached on first access."""
        if callable(self._train):
            self._train = self._train()
        return self._train

    @property
    def query(self):
        """Query split; a callable source is resolved and cached on first access."""
        if callable(self._query):
            self._query = self._query()
        return self._query

    @property
    def gallery(self):
        """Gallery split; a callable source is resolved and cached on first access."""
        if callable(self._gallery):
            self._gallery = self._gallery()
        return self._gallery

    def __getitem__(self, index):
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples in the split selected by ``mode``."""
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from the integer 0, so ``0 + dataset`` returns the
        dataset itself; any other left operand is delegated to ``__add__``.

        NOTE(review): ``__add__`` is not defined on this base class;
        presumably subclasses are expected to provide it -- confirm.
        """
        if other == 0:
            return self
        else:
            return self.__add__(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid)

        Returns:
            tuple: (number of distinct pids, number of distinct camids).
        """
        pids = set()
        cams = set()
        for info in data:
            pids.add(info[1])
            cams.add(info[2])
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics. No-op in the base class."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Query and gallery samples whose pid is in ``_junk_pids`` are dropped.
        The remaining ones are appended to a deep copy of the training list
        with their pid and camid rewritten to the strings
        ``"<dataset_name>_test_<pid>"`` / ``"<dataset_name>_test_<camid>"``.
        ``dataset_name`` falls back to ``"Unknown"`` when the instance has no
        such attribute. The result replaces the training split.
        """
        combined = copy.deepcopy(self.train)
        # Loop-invariant: look the name up once instead of twice per sample.
        dataset_name = getattr(self, "dataset_name", "Unknown")

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = dataset_name + "_test_" + str(pid)
                camid = dataset_name + "_test_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self._train = combined

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any of the given paths does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))
class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.

    On top of ``Dataset`` this class only adds two logging helpers that
    summarise the training split and the test (query + gallery) splits.

    NOTE(review): ``__getitem__`` is not implemented here. The original
    documentation states it returns ``img``, ``pid``, ``camid`` and
    ``img_path``, where ``img`` has shape (channel, height, width), so a batch
    has shape (batch_size, channel, height, width) -- presumably that is the
    contract for subclasses; confirm against the concrete datasets.
    """

    def show_train(self):
        """Log a table with the id, image and camera counts of the train split."""
        samples = self.train
        id_count, cam_count = self.parse_data(samples)
        rows = [['train', id_count, len(samples), cam_count]]

        # Render as a pipe-style (markdown-like) table.
        rendered = tabulate(
            rows,
            headers=['subset', '# ids', '# images', '# cameras'],
            tablefmt="pipe",
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(rendered, "cyan"))

    def show_test(self):
        """Log a table with the id, image and camera counts of query and gallery."""
        rows = []
        for subset in ('query', 'gallery'):
            # Reading the property resolves a lazily-loaded split if needed.
            samples = getattr(self, subset)
            id_count, cam_count = self.parse_data(samples)
            rows.append([subset, id_count, len(samples), cam_count])

        # Render as a pipe-style (markdown-like) table.
        rendered = tabulate(
            rows,
            headers=['subset', '# ids', '# images', '# cameras'],
            tablefmt="pipe",
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(rendered, "cyan"))