# encoding: utf-8
"""
@author:  sherlock
@contact: sherlockliao01@gmail.com
"""

import copy
import logging
import os

from tabulate import tabulate
from termcolor import colored

logger = logging.getLogger(__name__)


class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list or Callable): contains tuples of (img_path(s), pid, camid).
        query (list or Callable): contains tuples of (img_path(s), pid, camid).
        gallery (list or Callable): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information. Stored on the instance; nothing in
            this module reads it.

    Raises:
        ValueError: if ``mode`` is not one of 'train', 'query' or 'gallery'.
    """

    # Contains useless person IDs, e.g. background, false detections.
    # Samples from query/gallery with these pids are dropped by combine_all().
    _junk_pids = []

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        # Each split may be a list or a zero-argument callable producing one;
        # callables are resolved lazily by the properties below.
        # Extra **kwargs are accepted and ignored here.
        self._train = train
        self._query = query
        self._gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

    @property
    def train(self):
        """Training split; a callable source is invoked once and cached."""
        if callable(self._train):
            self._train = self._train()
        return self._train

    @property
    def query(self):
        """Query split; a callable source is invoked once and cached."""
        if callable(self._query):
            self._query = self._query()
        return self._query

    @property
    def gallery(self):
        """Gallery split; a callable source is invoked once and cached."""
        if callable(self._gallery):
            self._gallery = self._gallery()
        return self._gallery

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).

        ``sum`` starts from ``0``, so ``0 + dataset`` returns the dataset
        itself. Any other left operand is forwarded to ``__add__``. This base
        class does not define ``__add__``; when no subclass provides one,
        ``NotImplemented`` is returned so Python raises its regular
        ``TypeError`` instead of an ``AttributeError`` for the missing method.
        """
        if other == 0:
            return self
        add = getattr(self, "__add__", None)
        if add is None:
            return NotImplemented
        return add(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid)

        Returns:
            tuple: (number of distinct pids, number of distinct camids).
        """
        pids = set()
        cams = set()
        for info in data:
            pids.add(info[1])
            cams.add(info[2])
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Query and gallery samples are appended to a deep copy of the training
        list. Their pid and camid are rewritten as strings of the form
        ``<dataset_name>_test_<id>``, where ``dataset_name`` falls back to
        "Unknown" when the attribute is absent. Samples whose pid is listed in
        ``_junk_pids`` are skipped. The result replaces ``self._train``.
        """
        combined = copy.deepcopy(self.train)
        # Loop-invariant prefix, computed once instead of twice per sample.
        prefix = getattr(self, "dataset_name", "Unknown") + "_test_"

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                combined.append((img_path, prefix + str(pid), prefix + str(camid)))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self._train = combined

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any of the given paths does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))


class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.

    NOTE(review): ``__getitem__`` is not implemented in this class. The
    contract stated by the original documentation is that ``__getitem__``
    returns an image given index: it will return ``img``, ``pid``, ``camid``
    and ``img_path`` where ``img`` has shape (channel, height, width), so data
    in each batch has shape (batch_size, channel, height, width). Confirm
    against the concrete subclasses.
    """

    def _log_table(self, rows):
        """Logs ``rows`` as a pipe-format table with the subset headers."""
        headers = ['subset', '# ids', '# images', '# cameras']

        # tabulate it
        table = tabulate(
            rows,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

    def show_train(self):
        """Logs id, image and camera counts for the training split."""
        num_train_pids, num_train_cams = self.parse_data(self.train)
        self._log_table(
            [['train', num_train_pids, len(self.train), num_train_cams]]
        )

    def show_test(self):
        """Logs id, image and camera counts for the query and gallery splits."""
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
        self._log_table([
            ['query', num_query_pids, len(self.query), num_query_cams],
            ['gallery', num_gallery_pids, len(self.gallery), num_gallery_cams],
        ])