# deep-person-reid/torchreid/data_manager.py
from __future__ import absolute_import
from __future__ import print_function
from torch.utils.data import DataLoader
2018-11-07 15:36:49 +00:00
from .dataset_loader import ImageDataset, VideoDataset
2018-11-05 21:19:27 +00:00
from .datasets import init_imgreid_dataset, init_vidreid_dataset
2018-11-07 17:09:23 +00:00
from .transforms import build_transforms
2019-01-28 18:50:09 +00:00
from .samplers import build_train_sampler
2018-11-05 21:19:27 +00:00
class BaseDataManager(object):
    """Common state and accessors shared by the image/video re-id data managers.

    Holds dataset names, loader settings and augmentation flags, and builds the
    train/test transforms once so subclasses can reuse them. Subclasses are
    expected to populate ``_num_train_pids``, ``_num_train_cams``,
    ``trainloader``, ``testloader_dict`` and ``testdataset_dict``.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root='data',
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 train_sampler='',
                 random_erase=False,  # use random erasing for data augmentation
                 color_jitter=False,  # randomly change the brightness, contrast and saturation
                 color_aug=False,  # randomly alter the intensities of RGB channels
                 num_instances=4,  # number of instances per identity (for RandomIdentitySampler)
                 **kwargs
                 ):
        self.use_gpu = use_gpu
        self.source_names = source_names
        self.target_names = target_names
        self.root = root
        self.split_id = split_id
        self.height = height
        self.width = width
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.workers = workers
        self.train_sampler = train_sampler
        self.random_erase = random_erase
        self.color_jitter = color_jitter
        self.color_aug = color_aug
        self.num_instances = num_instances

        # Build the train/test image transforms once; subclasses reuse both.
        self.transform_train, self.transform_test = build_transforms(
            height, width,
            random_erase=random_erase,
            color_jitter=color_jitter,
            color_aug=color_aug,
        )

    @property
    def num_train_pids(self):
        # Total number of distinct training identities across source datasets.
        return self._num_train_pids

    @property
    def num_train_cams(self):
        # Total number of cameras across source datasets.
        return self._num_train_cams

    def return_dataloaders(self):
        """
        Return trainloader and testloader dictionary
        """
        return self.trainloader, self.testloader_dict

    def return_testdataset_by_name(self, name):
        """
        Return query and gallery, each containing a list of (img_path, pid, camid).
        """
        return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
class ImageDataManager(BaseDataManager):
    """
    Image-ReID data manager

    Merges the (possibly multiple) source datasets into a single training set,
    offsetting person/camera ids so they remain globally unique, and builds one
    query/gallery test loader per target dataset.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 cuhk03_labeled=False,  # use cuhk03's labeled or detected images
                 cuhk03_classic_split=False,  # use cuhk03's classic split or 767/700 split
                 market1501_500k=False,  # add 500k distractors to the gallery set for market1501
                 **kwargs
                 ):
        super(ImageDataManager, self).__init__(use_gpu, source_names, target_names, **kwargs)
        self.cuhk03_labeled = cuhk03_labeled
        self.cuhk03_classic_split = cuhk03_classic_split
        self.market1501_500k = market1501_500k

        print('=> Initializing TRAIN (source) datasets')
        train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in self.source_names:
            dataset = init_imgreid_dataset(
                root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
                cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
            )

            for img_path, pid, camid in dataset.train:
                # Offset ids so identities/cameras from different datasets
                # do not collide when the sources are concatenated.
                pid += self._num_train_pids
                camid += self._num_train_cams
                train.append((img_path, pid, camid))

            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        # NOTE: self.train_sampler is overwritten here, going from the sampler
        # name (a string, set by the base class) to the sampler instance.
        self.train_sampler = build_train_sampler(
            train, self.train_sampler,
            train_batch_size=self.train_batch_size,
            num_instances=self.num_instances,
        )

        self.trainloader = DataLoader(
            ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
            batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
            pin_memory=self.use_gpu, drop_last=True
        )

        print('=> Initializing TEST (target) datasets')
        # Consistency fix: iterate self.target_names (set by the base class),
        # matching the dataset loop below; the value is identical.
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}

        for name in self.target_names:
            dataset = init_imgreid_dataset(
                root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
                cuhk03_classic_split=self.cuhk03_classic_split, market1501_500k=self.market1501_500k
            )

            self.testloader_dict[name]['query'] = DataLoader(
                ImageDataset(dataset.query, transform=self.transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=False
            )

            self.testloader_dict[name]['gallery'] = DataLoader(
                ImageDataset(dataset.gallery, transform=self.transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=False
            )

            # Keep the raw (img_path, pid, camid) lists for evaluation code
            # that needs paths/labels rather than batches.
            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print('\n')
        print(' **************** Summary ****************')
        print(' train names : {}'.format(self.source_names))
        print(' # train datasets : {}'.format(len(self.source_names)))
        print(' # train ids : {}'.format(self.num_train_pids))
        print(' # train images : {}'.format(len(train)))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test names : {}'.format(self.target_names))
        print(' *****************************************')
        print('\n')
2018-11-07 15:36:49 +00:00
class VideoDataManager(BaseDataManager):
    """
    Video-ReID data manager

    Merges the source datasets into one training set — either decomposed into
    individual images (``image_training=True``) or kept as whole tracklets —
    and builds one query/gallery test loader per target dataset.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 seq_len=15,
                 sample_method='evenly',
                 image_training=True,  # train the video-reid model with images rather than tracklets
                 **kwargs
                 ):
        super(VideoDataManager, self).__init__(use_gpu, source_names, target_names, **kwargs)
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.image_training = image_training

        print('=> Initializing TRAIN (source) datasets')
        train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in self.source_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)

            for img_paths, pid, camid in dataset.train:
                # Offset ids so identities/cameras from different datasets
                # do not collide when the sources are concatenated.
                pid += self._num_train_pids
                camid += self._num_train_cams
                # Consistency fix: use the stored flag (as the summary below
                # does) instead of the raw argument; the value is identical.
                if self.image_training:
                    # decompose tracklets into images
                    for img_path in img_paths:
                        train.append((img_path, pid, camid))
                else:
                    train.append((img_paths, pid, camid))

            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        # NOTE: self.train_sampler is overwritten here, going from the sampler
        # name (a string, set by the base class) to the sampler instance.
        self.train_sampler = build_train_sampler(
            train, self.train_sampler,
            train_batch_size=self.train_batch_size,
            num_instances=self.num_instances,
        )

        if self.image_training:
            # each batch has image data of shape (batch, channel, height, width)
            self.trainloader = DataLoader(
                ImageDataset(train, transform=self.transform_train), sampler=self.train_sampler,
                batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=True
            )
        else:
            # each batch has image data of shape (batch, seq_len, channel, height, width)
            # note: this requires new training scripts
            # NOTE(review): the sampler built above is NOT used in this branch
            # (the loader shuffles instead) — confirm this is intentional.
            self.trainloader = DataLoader(
                VideoDataset(train, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_train),
                batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=True
            )

        print('=> Initializing TEST (target) datasets')
        # Consistency fix: iterate self.target_names (set by the base class),
        # matching the dataset loop below; the value is identical.
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}

        for name in self.target_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)

            self.testloader_dict[name]['query'] = DataLoader(
                VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=False,
            )

            self.testloader_dict[name]['gallery'] = DataLoader(
                VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=self.transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.use_gpu, drop_last=False,
            )

            # Keep the raw (img_paths, pid, camid) lists for evaluation code
            # that needs paths/labels rather than batches.
            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print('\n')
        print(' **************** Summary ****************')
        print(' train names : {}'.format(self.source_names))
        print(' # train datasets : {}'.format(len(self.source_names)))
        print(' # train ids : {}'.format(self.num_train_pids))
        if self.image_training:
            print(' # train images : {}'.format(len(train)))
        else:
            print(' # train tracklets: {}'.format(len(train)))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test names : {}'.format(self.target_names))
        print(' *****************************************')
        print('\n')