# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import copy
import os

# numpy and torch are only needed by the commented-out VideoDataset below.
import numpy as np
import torch


class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
    """
    _junk_pids = []  # contains useless person IDs, e.g. background, false detections

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

        # if self.verbose:
        #     self.show_summary()

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    # def __add__(self, other):
    #     """Adds two datasets together (only the train set)."""
    #     train = copy.deepcopy(self.train)
    #
    #     for img_path, pid, camid in other.train:
    #         pid += self.num_train_pids
    #         camid += self.num_train_cams
    #         train.append((img_path, pid, camid))
    #
    #     ###################################
    #     # Things to do beforehand:
    #     # 1. set verbose=False to avoid unnecessary print
    #     # 2. set combineall=False because combineall would have been applied
    #     #    if it was True for a specific dataset; setting it to True again
    #     #    would create new IDs that should already have been included
    #     ###################################
    #     if isinstance(train[0][0], str):
    #         return ImageDataset(
    #             train, self.query, self.gallery,
    #             transform=self.transform,
    #             mode=self.mode,
    #             combineall=False,
    #             verbose=False
    #         )
    #     else:
    #         return VideoDataset(
    #             train, self.query, self.gallery,
    #             transform=self.transform,
    #             mode=self.mode,
    #             combineall=False,
    #             verbose=False
    #         )

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3])."""
        if other == 0:
            return self
        else:
            # NOTE: this path requires __add__, which is commented out above;
            # re-enable it before summing more than one dataset.
            return self.__add__(other)
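    # A minimal usage sketch (hypothetical subclasses; ``__add__`` above must
    # be re-enabled for this to run). ``sum`` starts from 0, so ``__radd__``
    # returns the first dataset unchanged and delegates the rest to ``__add__``:
    #
    #     d1 = SomeImageDataset(mode='train')   # hypothetical ImageDataset subclass
    #     d2 = OtherImageDataset(mode='train')  # hypothetical ImageDataset subclass
    #     merged = sum([d1, d2])  # 0 + d1 -> d1 (via __radd__), then d1 + d2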
    def parse_data(self, data):
        """Parses the data list and returns the number of person IDs
        and the number of camera views.

        Args:
            data (list): contains tuples of (img_path(s), pid, camid)
        """
        pids = set()
        cams = set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of training person identities."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of training cameras."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training."""
        combined = copy.deepcopy(self.train)

        # relabel pids in gallery (query shares the same scope)
        g_pids = set()
        for _, pid, _ in self.gallery:
            if pid in self._junk_pids:
                continue
            g_pids.add(pid)
        pid2label = {pid: i for i, pid in enumerate(g_pids)}

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = pid2label[pid] + self.num_train_pids
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def __repr__(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        msg = '  ----------------------------------------\n' \
              '  subset   | # ids | # items | # cameras\n' \
              '  ----------------------------------------\n' \
              '  train    | {:5d} | {:7d} | {:9d}\n' \
              '  query    | {:5d} | {:7d} | {:9d}\n' \
              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
              '  ----------------------------------------\n' \
              '  items: images/tracklets for image/video dataset\n'.format(
                  num_train_pids, len(self.train), num_train_cams,
                  num_query_pids, len(self.query), num_query_cams,
                  num_gallery_pids, len(self.gallery), num_gallery_cams
              )

        return msg


class ImageDataset(Dataset):
    """A base class representing ImageDataset.

    All other image datasets should subclass it.

    ``__getitem__`` returns an image given index.
    It will return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width). As a result,
    data in each batch has shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def show_summary(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print('  ----------------------------------------')
        print('  subset   | # ids | # images | # cameras')
        print('  ----------------------------------------')
        print('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
        print('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
        print('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
        print('  ----------------------------------------')
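# A minimal sketch of a concrete dataset (hypothetical names and paths, not
# part of this repo): a subclass only has to build three lists of
# (img_path, pid, camid) tuples and hand them to ImageDataset.
#
# class ToyDataset(ImageDataset):
#     def __init__(self, root='datasets/toy', **kwargs):
#         self.check_before_run(root)  # accepts a str or a list of paths
#         train = [(os.path.join(root, 'train', '0001_c1.jpg'), 0, 0)]
#         query = [(os.path.join(root, 'query', '0001_c2.jpg'), 0, 1)]
#         gallery = [(os.path.join(root, 'gallery', '0001_c3.jpg'), 0, 2)]
#         super(ToyDataset, self).__init__(train, query, gallery, **kwargs)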
# class VideoDataset(Dataset):
#     """A base class representing VideoDataset.
#
#     All other video datasets should subclass it.
#
#     ``__getitem__`` returns an image given index.
#     It will return ``imgs``, ``pid`` and ``camid``
#     where ``imgs`` has shape (seq_len, channel, height, width). As a result,
#     data in each batch has shape (batch_size, seq_len, channel, height, width).
#     """
#
#     def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
#         super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
#         self.seq_len = seq_len
#         self.sample_method = sample_method
#
#         if self.transform is None:
#             raise RuntimeError('transform must not be None')
#
#     def __getitem__(self, index):
#         img_paths, pid, camid = self.data[index]
#         num_imgs = len(img_paths)
#
#         if self.sample_method == 'random':
#             # Randomly samples seq_len images from a tracklet of length num_imgs;
#             # if num_imgs is smaller than seq_len, replicates images.
#             indices = np.arange(num_imgs)
#             replace = False if num_imgs >= self.seq_len else True
#             indices = np.random.choice(indices, size=self.seq_len, replace=replace)
#             # sort indices to keep temporal order (comment it out to be order-agnostic)
#             indices = np.sort(indices)
#
#         elif self.sample_method == 'evenly':
#             # Evenly samples seq_len images from a tracklet.
#             if num_imgs >= self.seq_len:
#                 num_imgs -= num_imgs % self.seq_len
#                 indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
#             else:
#                 # if num_imgs is smaller than seq_len, simply replicate the last image
#                 # until the seq_len requirement is satisfied
#                 indices = np.arange(0, num_imgs)
#                 num_pads = self.seq_len - num_imgs
#                 indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
#             assert len(indices) == self.seq_len
#
#         elif self.sample_method == 'all':
#             # Samples all images in a tracklet. batch_size must be set to 1.
#             indices = np.arange(num_imgs)
#
#         else:
#             raise ValueError('Unknown sample method: {}'.format(self.sample_method))
#
#         imgs = []
#         for index in indices:
#             img_path = img_paths[int(index)]
#             # NOTE: read_image is not imported in this file; import it before
#             # re-enabling this class.
#             img = read_image(img_path)
#             if self.transform is not None:
#                 img = self.transform(img)
#             img = img.unsqueeze(0)  # img must be torch.Tensor
#             imgs.append(img)
#         imgs = torch.cat(imgs, dim=0)
#
#         return imgs, pid, camid
#
#     def show_summary(self):
#         num_train_pids, num_train_cams = self.parse_data(self.train)
#         num_query_pids, num_query_cams = self.parse_data(self.query)
#         num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
#
#         print('=> Loaded {}'.format(self.__class__.__name__))
#         print('  -------------------------------------------')
#         print('  subset   | # ids | # tracklets | # cameras')
#         print('  -------------------------------------------')
#         print('  train    | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
#         print('  query    | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
#         print('  gallery  | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
#         print('  -------------------------------------------')
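

if __name__ == '__main__':
    # Smoke test with synthetic in-memory annotations (illustrative paths;
    # nothing is read from disk, since the base class never opens images).
    train = [('train/0001_c1.jpg', 0, 0), ('train/0002_c2.jpg', 1, 1)]
    query = [('query/0001_c1.jpg', 0, 0)]
    gallery = [('gallery/0001_c2.jpg', 0, 1), ('gallery/0002_c1.jpg', 1, 0)]

    dataset = ImageDataset(train, query, gallery, mode='query')
    print(dataset)          # statistics table from Dataset.__repr__
    dataset.show_summary()  # per-subset summary from ImageDataset
    print('query size: {}'.format(len(dataset)))  # len of the active split

    # combineall=True relabels query/gallery identities into the train scope:
    # two original train ids plus two relabeled gallery-scope ids -> 4.
    combined = ImageDataset(train, query, gallery, combineall=True)
    print('combined train ids: {}'.format(combined.num_train_pids))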