from __future__ import absolute_import import os from PIL import Image from torch.utils.data import Dataset def read_image(img_path): """Keep reading image until succeed. This can avoid IOError incurred by heavy IO process.""" got_img = False while not got_img: try: img = Image.open(img_path).convert('RGB') got_img = True except IOError: print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(fpath)) pass return img class ImageDataset(Dataset): """Image Person ReID Dataset""" def __init__(self, dataset, transform=None): self.dataset = dataset self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, index): img_path, pid, camid = self.dataset[index] img = read_image(img_path) if self.transform is not None: img = self.transform(img) return img, pid, camid class VideoDataset(Dataset): """Video Person ReID Dataset. Note batch data has shape (batch, seq_len, channel, height, width). """ sample_methods = ['evenly', 'random', 'all'] def __init__(self, dataset, seq_len=15, sample='evenly', transform=None): self.dataset = dataset self.seq_len = seq_len self.sample = sample self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, index): img_paths, pid, camid = self.dataset[index] num = len(img_paths) if self.sample == 'random': """ Randomly sample seq_len items from num items, if num is smaller than seq_len, then replicate items """ indices = np.arange(num) replace = False if num >= self.seq_len else True indices = np.random.choice(indices, size=self.seq_len, replace=replace) # sort indices to keep temporal order # comment it to be order-agnostic indices = np.sort(indices) elif self.sample == 'evenly': """Evenly sample seq_len items from num items.""" if num >= self.seq_len: num -= num % self.seq_len indices = np.arange(0, num, num/self.seq_len) else: # if num is smaller than seq_len, simply replicate the last image # until the seq_len requirement is satisfied indices = np.arange(0, num) num_pads = self.seq_len - num indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)]) assert len(indices) == self.seq_len elif self.sample == 'all': """ Sample all items, seq_len is useless now and batch_size needs to be set to 1. """ indices = np.arange(num) else: raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods)) imgs = [] for index in indices: img_path = img_paths[index] img = read_image(img_path) if self.transform is not None: img = self.transform(img) img = img.unsqueeze(0) imgs.append(img) imgs = torch.cat(imgs, dim=0) return imgs, pid, camid