104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
from __future__ import print_function, absolute_import
|
|
import os
|
|
from PIL import Image
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
def read_image(img_path):
|
|
"""Keep reading image until succeed.
|
|
This can avoid IOError incurred by heavy IO process."""
|
|
got_img = False
|
|
while not got_img:
|
|
try:
|
|
img = Image.open(img_path).convert('RGB')
|
|
got_img = True
|
|
except IOError:
|
|
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(fpath))
|
|
pass
|
|
return img
|
|
|
|
class ImageDataset(Dataset):
|
|
"""Image Person ReID Dataset"""
|
|
def __init__(self, dataset, transform=None):
|
|
self.dataset = dataset
|
|
self.transform = transform
|
|
|
|
def __len__(self):
|
|
return len(self.dataset)
|
|
|
|
def __getitem__(self, index):
|
|
img_path, pid, camid = self.dataset[index]
|
|
img = read_image(img_path)
|
|
if self.transform is not None:
|
|
img = self.transform(img)
|
|
return img, pid, camid
|
|
|
|
class VideoDataset(Dataset):
|
|
"""Video Person ReID Dataset.
|
|
Note batch data has shape (batch, seq_len, channel, height, width).
|
|
"""
|
|
sample_methods = ['evenly', 'random', 'all']
|
|
|
|
def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
|
|
self.dataset = dataset
|
|
self.seq_len = seq_len
|
|
self.sample = sample
|
|
self.transform = transform
|
|
|
|
def __len__(self):
|
|
return len(self.dataset)
|
|
|
|
def __getitem__(self, index):
|
|
img_paths, pid, camid = self.dataset[index]
|
|
num = len(img_paths)
|
|
|
|
if self.sample == 'random':
|
|
"""
|
|
Randomly sample seq_len items from num items,
|
|
if num is smaller than seq_len, then replicate items
|
|
"""
|
|
indices = np.arange(num)
|
|
replace = False if num >= self.seq_len else True
|
|
indices = np.random.choice(indices, size=self.seq_len, replace=replace)
|
|
# sort indices to keep temporal order
|
|
# comment it to be order-agnostic
|
|
indices = np.sort(indices)
|
|
elif self.sample == 'evenly':
|
|
"""Evenly sample seq_len items from num items."""
|
|
if num >= self.seq_len:
|
|
num -= num % self.seq_len
|
|
indices = np.arange(0, num, num/self.seq_len)
|
|
else:
|
|
# if num is smaller than seq_len, simply replicate the last image
|
|
# until the seq_len requirement is satisfied
|
|
indices = np.arange(0, num)
|
|
num_pads = self.seq_len - num
|
|
indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)])
|
|
assert len(indices) == self.seq_len
|
|
elif self.sample == 'all':
|
|
"""
|
|
Sample all items, seq_len is useless now and batch_size needs
|
|
to be set to 1.
|
|
"""
|
|
indices = np.arange(num)
|
|
else:
|
|
raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
|
|
|
|
imgs = []
|
|
for index in indices:
|
|
img_path = img_paths[index]
|
|
img = read_image(img_path)
|
|
if self.transform is not None:
|
|
img = self.transform(img)
|
|
img = img.unsqueeze(0)
|
|
imgs.append(img)
|
|
imgs = torch.cat(imgs, dim=0)
|
|
|
|
return imgs, pid, camid
|
|
|
|
|
|
|
|
|
|
|
|
|