deep-person-reid/torchreid/data/datamanager.py

266 lines
9.6 KiB
Python
Raw Normal View History

2019-03-21 20:59:54 +08:00
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
from torchreid.data.sampler import build_train_sampler
from torchreid.data.transforms import build_transforms
from torchreid.data.datasets import init_image_dataset, init_video_dataset
class DataManager(object):
    """Base class holding the state shared by the image and video data managers.

    The constructor normalises the source/target dataset names, builds the
    train/test transforms via ``build_transforms`` and decides whether the GPU
    is used. It does not create any dataset or loader itself: the accessors
    below read ``trainloader``, ``testloader``, ``testdataset``,
    ``_num_train_pids`` and ``_num_train_cams``, which the subclasses in this
    file assign in their own constructors.

    Args:
        sources: a dataset name or a list of dataset names used for training.
        targets: a dataset name or a list of dataset names used for testing.
            When ``None``, ``sources`` is used (the very same list object).
        height: forwarded to ``build_transforms`` as first positional argument.
        width: forwarded to ``build_transforms`` as second positional argument.
        random_erase: forwarded to ``build_transforms``.
        color_jitter: forwarded to ``build_transforms``.
        color_aug: forwarded to ``build_transforms``.
        use_cpu: when true, ``use_gpu`` is False even if CUDA is available.
    """

    def __init__(self, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False):
        self.sources = self._as_name_list(sources)
        self.targets = self._as_name_list(targets)
        if self.targets is None:
            # No explicit targets: test on the training datasets.
            self.targets = self.sources

        self.transform_tr, self.transform_te = build_transforms(
            height, width,
            random_erase=random_erase,
            color_jitter=color_jitter,
            color_aug=color_aug
        )

        self.use_gpu = (torch.cuda.is_available() and not use_cpu)

    @staticmethod
    def _as_name_list(names):
        """Wrap a single ``str`` name in a list; return any other value unchanged."""
        if isinstance(names, str):
            return [names]
        return names

    @property
    def num_train_pids(self):
        """Value of ``_num_train_pids`` (set by the subclass constructor)."""
        return self._num_train_pids

    @property
    def num_train_cams(self):
        """Value of ``_num_train_cams`` (set by the subclass constructor)."""
        return self._num_train_cams

    def return_dataloaders(self):
        """Return the ``(trainloader, testloader)`` pair."""
        return self.trainloader, self.testloader

    def return_testdataset_by_name(self, name):
        """Return ``(query, gallery)`` stored in ``testdataset`` under ``name``.

        Raises:
            KeyError: if ``name`` is not a key of ``testdataset``.
        """
        return self.testdataset[name]['query'], self.testdataset[name]['gallery']
class ImageDataManager(DataManager):
    """Data manager that builds train, query and gallery loaders for image datasets.

    After construction the instance exposes:

    * ``trainloader``: a ``DataLoader`` over the sum of all source datasets,
      driven by the sampler returned from ``build_train_sampler`` and created
      with ``drop_last=True``.
    * ``testloader``: ``{target_name: {'query': DataLoader, 'gallery': DataLoader}}``.
    * ``testdataset``: ``{target_name: {'query': ..., 'gallery': ...}}`` holding
      the ``query`` attribute of the query dataset and the ``gallery`` attribute
      of the gallery dataset.

    A short summary is printed to stdout at the end of construction.

    Args:
        root: forwarded to ``init_image_dataset`` as ``root``.
        sources: training dataset name(s); see ``DataManager``.
        targets: test dataset name(s); see ``DataManager``.
        height, width, random_erase, color_jitter, color_aug, use_cpu:
            forwarded to ``DataManager.__init__``.
        split_id: forwarded to ``init_image_dataset``.
        combineall: forwarded to ``init_image_dataset``.
        batch_size: batch size of the train, query and gallery loaders; also
            forwarded to ``build_train_sampler``.
        workers: ``num_workers`` of every loader.
        num_instances: forwarded to ``build_train_sampler``.
        train_sampler: forwarded to ``build_train_sampler`` as its second
            positional argument.
        cuhk03_labeled: forwarded to ``init_image_dataset``.
        cuhk03_classic_split: forwarded to ``init_image_dataset``.
        market1501_500k: forwarded to ``init_image_dataset``.
    """

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False):
        super(ImageDataManager, self).__init__(
            sources, targets=targets, height=height, width=width,
            random_erase=random_erase, color_jitter=color_jitter,
            color_aug=color_aug, use_cpu=use_cpu
        )

        # Options passed unchanged to every init_image_dataset call below.
        dataset_kwargs = dict(
            combineall=combineall,
            root=root,
            split_id=split_id,
            cuhk03_labeled=cuhk03_labeled,
            cuhk03_classic_split=cuhk03_classic_split,
            market1501_500k=market1501_500k
        )
        # Query and gallery loaders use exactly the same settings.
        test_loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

        print('=> Loading train (source) dataset')
        # Create every source dataset first, then combine them.
        source_sets = [
            init_image_dataset(name, transform=self.transform_tr, mode='train', **dataset_kwargs)
            for name in self.sources
        ]
        # sum() starts from the int 0, so the dataset objects must support
        # both ``0 + dataset`` and ``dataset + dataset``.
        trainset = sum(source_sets)

        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams

        train_sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            sampler=train_sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_image_dataset(
                name, transform=self.transform_te, mode='query', **dataset_kwargs
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                queryset, **test_loader_kwargs
            )

            # build gallery loader (the only dataset created with verbose=False)
            galleryset = init_image_dataset(
                name, transform=self.transform_te, mode='gallery', verbose=False, **dataset_kwargs
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                galleryset, **test_loader_kwargs
            )

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

        print('\n')
        print(' **************** Summary ****************')
        print(' train : {}'.format(self.sources))
        print(' # train datasets : {}'.format(len(self.sources)))
        print(' # train ids : {}'.format(self.num_train_pids))
        print(' # train images : {}'.format(len(trainset)))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test : {}'.format(self.targets))
        print(' *****************************************')
        print('\n')
class VideoDataManager(DataManager):
    """Data manager that builds train, query and gallery loaders for video datasets.

    After construction the instance exposes:

    * ``trainloader``: a ``DataLoader`` over the sum of all source datasets,
      driven by the sampler returned from ``build_train_sampler`` and created
      with ``drop_last=True``.
    * ``testloader``: ``{target_name: {'query': DataLoader, 'gallery': DataLoader}}``.
    * ``testdataset``: ``{target_name: {'query': ..., 'gallery': ...}}`` holding
      the ``query`` attribute of the query dataset and the ``gallery`` attribute
      of the gallery dataset.

    A short summary is printed to stdout at the end of construction.

    Args:
        root: forwarded to ``init_video_dataset`` as ``root``.
        sources: training dataset name(s); see ``DataManager``.
        targets: test dataset name(s); see ``DataManager``.
        height, width, random_erase, color_jitter, color_aug, use_cpu:
            forwarded to ``DataManager.__init__``.
        split_id: forwarded to ``init_video_dataset``.
        combineall: forwarded to ``init_video_dataset``.
        batch_size: batch size of the train, query and gallery loaders; also
            forwarded to ``build_train_sampler``.
        workers: ``num_workers`` of every loader.
        num_instances: forwarded to ``build_train_sampler``.
        train_sampler: forwarded to ``build_train_sampler`` as its second
            positional argument.
        seq_len: forwarded to ``init_video_dataset``.
        sample_method: forwarded to ``init_video_dataset``.
    """

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 seq_len=15, sample_method='evenly'):
        super(VideoDataManager, self).__init__(
            sources, targets=targets, height=height, width=width,
            random_erase=random_erase, color_jitter=color_jitter,
            color_aug=color_aug, use_cpu=use_cpu
        )

        # Options passed unchanged to every init_video_dataset call below.
        dataset_kwargs = dict(
            combineall=combineall,
            root=root,
            split_id=split_id,
            seq_len=seq_len,
            sample_method=sample_method
        )
        # Query and gallery loaders use exactly the same settings.
        test_loader_kwargs = dict(
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

        print('=> Loading train (source) dataset')
        # Create every source dataset first, then combine them.
        source_sets = [
            init_video_dataset(name, transform=self.transform_tr, mode='train', **dataset_kwargs)
            for name in self.sources
        ]
        # sum() starts from the int 0, so the dataset objects must support
        # both ``0 + dataset`` and ``dataset + dataset``.
        trainset = sum(source_sets)

        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams

        train_sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )

        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            sampler=train_sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_video_dataset(
                name, transform=self.transform_te, mode='query', **dataset_kwargs
            )
            self.testloader[name]['query'] = torch.utils.data.DataLoader(
                queryset, **test_loader_kwargs
            )

            # build gallery loader (the only dataset created with verbose=False)
            galleryset = init_video_dataset(
                name, transform=self.transform_te, mode='gallery', verbose=False, **dataset_kwargs
            )
            self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                galleryset, **test_loader_kwargs
            )

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

        print('\n')
        print(' **************** Summary ****************')
        print(' train : {}'.format(self.sources))
        print(' # train datasets : {}'.format(len(self.sources)))
        print(' # train ids : {}'.format(self.num_train_pids))
        print(' # train tracklets : {}'.format(len(trainset)))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test : {}'.format(self.targets))
        print(' *****************************************')
        print('\n')