deep-person-reid/torchreid/dataset_loader.py

115 lines
3.7 KiB
Python
Raw Normal View History

2018-07-04 10:32:43 +01:00
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
2018-07-02 10:17:14 +01:00
2018-03-11 21:17:48 +00:00
import os
from PIL import Image
2018-03-12 18:38:12 +00:00
import numpy as np
2018-05-01 16:30:10 +01:00
import os.path as osp
2018-07-02 12:54:45 +01:00
import io
2018-03-11 21:17:48 +00:00
2018-03-12 18:39:28 +00:00
import torch
2018-03-11 21:17:48 +00:00
from torch.utils.data import Dataset
2018-07-02 10:17:14 +01:00
2018-03-11 21:17:48 +00:00
def read_image(img_path):
    """Load an image from disk as an RGB PIL image, retrying on IOError.

    Heavy concurrent IO can make a read fail transiently, so the read is
    repeated until it succeeds.

    Args:
        img_path (str): path to the image file.

    Returns:
        PIL.Image.Image: the decoded image, converted to RGB mode.

    Raises:
        IOError: if ``img_path`` does not exist when the function is called.

    Note:
        The retry loop has no upper bound: a file that exists but can never
        be decoded keeps this function looping (and printing) forever.
    """
    if not osp.exists(img_path):
        raise IOError('{} does not exist'.format(img_path))
    while True:
        try:
            # Open the file ourselves so the handle is closed as soon as the
            # pixels are decoded (or the attempt fails), instead of relying
            # on garbage collection. convert() fully loads the image, so the
            # returned object no longer needs the open file.
            with open(img_path, 'rb') as f:
                return Image.open(f).convert('RGB')
        except IOError:
            print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path))
2018-07-02 10:17:14 +01:00
2018-03-11 21:17:48 +00:00
class ImageDataset(Dataset):
    """Person ReID dataset whose samples are single images.

    Args:
        dataset: indexable collection; each entry is unpacked as
            ``(img_path, pid, camid)``.
        transform (callable, optional): applied to the loaded image when it
            is not None.
    """

    def __init__(self, dataset, transform=None):
        self.transform = transform
        self.dataset = dataset

    def __len__(self):
        """Return the number of entries in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(img, pid, camid, img_path)`` for the entry at ``index``."""
        path, person_id, camera_id = self.dataset[index]
        image = read_image(path)
        if self.transform is None:
            return image, person_id, camera_id, path
        return self.transform(image), person_id, camera_id, path
2018-03-11 21:17:48 +00:00
2018-07-02 10:17:14 +01:00
2018-03-11 21:17:48 +00:00
class VideoDataset(Dataset):
    """Video Person ReID Dataset.

    Note batch data has shape (batch, seq_len, channel, height, width).

    Args:
        dataset: indexable collection; each entry is unpacked as
            ``(img_paths, pid, camid)`` where ``img_paths`` is the sequence
            of frame paths of one tracklet.
        seq_len (int): number of frames sampled per tracklet (ignored by the
            'all' sample method).
        sample_method (str): one of ``_sample_methods``.
        transform (callable, optional): applied to every loaded frame when it
            is not None. The frames are stacked with ``torch.stack``, so each
            frame must be a tensor by then; in practice a transform that
            produces tensors is required.
    """
    _sample_methods = ['evenly', 'random', 'all']

    def __init__(self, dataset, seq_len=15, sample_method='evenly', transform=None):
        self.dataset = dataset
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.transform = transform

    def __len__(self):
        """Return the number of tracklets in the wrapped dataset."""
        return len(self.dataset)

    def _sample_indices(self, num):
        """Return the integer frame indices to load from a tracklet of ``num`` frames."""
        if self.sample_method == 'random':
            # Randomly sample seq_len items from num items; if num is smaller
            # than seq_len, sample with replacement so items get replicated.
            picked = np.random.choice(num, size=self.seq_len, replace=num < self.seq_len)
            # Sort indices to keep temporal order (drop the sort to be
            # order-agnostic).
            return np.sort(picked)

        if self.sample_method == 'evenly':
            # Evenly sample seq_len items from num items.
            if num >= self.seq_len:
                # Trim the tail so the frame count is a multiple of seq_len,
                # then step through it with an integer stride. Floor division
                # keeps the indices integral (true division would yield
                # floats under ``from __future__ import division``).
                usable = num - num % self.seq_len
                indices = np.arange(0, usable, usable // self.seq_len)
            else:
                # If num is smaller than seq_len, simply replicate the last
                # image until the seq_len requirement is satisfied.
                head = np.arange(num)
                tail = np.full(self.seq_len - num, num - 1, dtype=head.dtype)
                indices = np.concatenate([head, tail])
            assert len(indices) == self.seq_len
            return indices

        if self.sample_method == 'all':
            # Sample all items; seq_len is useless now and batch_size needs
            # to be set to 1.
            return np.arange(num)

        raise ValueError('Unknown sample method: {}. Expected one of {}'.format(self.sample_method, self._sample_methods))

    def __getitem__(self, index):
        """Return ``(imgs, pid, camid)`` for the tracklet at ``index``.

        ``imgs`` stacks the sampled frames along a new leading dimension.

        Raises:
            ValueError: if the tracklet has no frames or ``sample_method`` is
                not one of ``_sample_methods``.
        """
        img_paths, pid, camid = self.dataset[index]
        num = len(img_paths)
        if num == 0:
            # Fail with one clear error instead of a sampling-mode-dependent
            # IndexError / ValueError / RuntimeError further down.
            raise ValueError('Cannot sample frames from an empty tracklet (index {})'.format(index))

        frames = []
        for frame_idx in self._sample_indices(num):
            img = read_image(img_paths[int(frame_idx)])
            if self.transform is not None:
                img = self.transform(img)
            frames.append(img)
        # Equivalent to unsqueeze(0) on every frame followed by cat(dim=0).
        imgs = torch.stack(frames, dim=0)
        return imgs, pid, camid