266 lines
9.6 KiB
Python
266 lines
9.6 KiB
Python
|
from __future__ import absolute_import
|
||
|
from __future__ import print_function
|
||
|
from __future__ import division
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from torchreid.data.sampler import build_train_sampler
|
||
|
from torchreid.data.transforms import build_transforms
|
||
|
from torchreid.data.datasets import init_image_dataset, init_video_dataset
|
||
|
|
||
|
|
||
|
class DataManager(object):
    """Base data manager shared by the image and video re-ID data managers.

    Normalises ``sources``/``targets`` into lists of dataset names, builds the
    train/test transforms via ``build_transforms`` and records in ``use_gpu``
    whether CUDA should be used. The properties and accessors below read
    attributes (``_num_train_pids``, ``_num_train_cams``, ``trainloader``,
    ``testloader``, ``testdataset``) that this base class never sets; concrete
    subclasses must populate them.

    Args:
        sources (str or list): name(s) of the training (source) dataset(s).
        targets (str or list, optional): name(s) of the test (target)
            dataset(s). ``None`` means "same as ``sources``".
        height (int): passed to ``build_transforms``.
        width (int): passed to ``build_transforms``.
        random_erase (bool): passed to ``build_transforms``.
        color_jitter (bool): passed to ``build_transforms``.
        color_aug (bool): passed to ``build_transforms``.
        use_cpu (bool): if True, ``use_gpu`` is False even when CUDA is available.

    Raises:
        ValueError: if ``sources`` is None or empty. Without this check the
            failure surfaced later and obscurely, e.g. ``sum([])`` returning
            the int ``0`` instead of a dataset.
    """

    def __init__(self, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False):
        # A single dataset name is accepted as shorthand for a one-item list.
        if isinstance(sources, str):
            sources = [sources]
        if isinstance(targets, str):
            targets = [targets]

        # Fail fast: subclasses iterate over and sum() the source datasets,
        # which breaks in confusing ways for None or an empty collection.
        if not sources:
            raise ValueError(
                'sources must contain at least one dataset name, got {!r}'.format(sources))

        self.sources = sources
        # Evaluate on the training datasets unless explicit targets were given.
        self.targets = sources if targets is None else targets

        self.transform_tr, self.transform_te = build_transforms(
            height, width,
            random_erase=random_erase,
            color_jitter=color_jitter,
            color_aug=color_aug
        )

        self.use_gpu = (torch.cuda.is_available() and not use_cpu)

    @property
    def num_train_pids(self):
        """Number of training identities (``_num_train_pids``, set by subclasses)."""
        return self._num_train_pids

    @property
    def num_train_cams(self):
        """Number of training cameras (``_num_train_cams``, set by subclasses)."""
        return self._num_train_cams

    def return_dataloaders(self):
        """Return the tuple ``(trainloader, testloader)``."""
        return self.trainloader, self.testloader

    def return_testdataset_by_name(self, name):
        """Return the ``(query, gallery)`` entries stored for test dataset ``name``.

        Raises:
            KeyError: if ``name`` is not one of the configured targets.
        """
        return self.testdataset[name]['query'], self.testdataset[name]['gallery']
|
||
|
|
||
|
|
||
|
class ImageDataManager(DataManager):
    """Data manager for image-based re-ID datasets.

    Builds one training DataLoader over all ``sources`` (combined with
    ``sum``) and, for every name in ``targets``, a query and a gallery
    DataLoader stored in ``self.testloader[name]``. The raw query/gallery
    entries are kept in ``self.testdataset[name]``.

    Args:
        root (str): root directory passed to ``init_image_dataset``.
        sources, targets, height, width, random_erase, color_jitter,
            color_aug, use_cpu: forwarded to :class:`DataManager`.
        split_id (int): passed to ``init_image_dataset``.
        combineall (bool): passed to ``init_image_dataset``.
        batch_size (int): batch size of the training loader. Also used for
            the query/gallery loaders unless ``batch_size_test`` is given.
        workers (int): ``num_workers`` of every DataLoader.
        num_instances (int): passed to ``build_train_sampler``.
        train_sampler: passed to ``build_train_sampler``.
        cuhk03_labeled (bool): passed to ``init_image_dataset``.
        cuhk03_classic_split (bool): passed to ``init_image_dataset``.
        market1501_500k (bool): passed to ``init_image_dataset``.
        batch_size_test (int, optional): batch size of the query/gallery
            loaders. Defaults to ``batch_size``, which was the fixed behaviour.

    Raises:
        ValueError: if no source dataset is loaded.
    """

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False,
                 batch_size_test=None):
        super(ImageDataManager, self).__init__(sources, targets=targets, height=height, width=width,
                                               random_erase=random_erase, color_jitter=color_jitter,
                                               color_aug=color_aug, use_cpu=use_cpu)

        # Test loaders used to be hard-wired to the training batch size; keep
        # that as the default so existing callers see no change.
        if batch_size_test is None:
            batch_size_test = batch_size

        # Keyword options passed unchanged to every init_image_dataset call.
        dataset_kwargs = dict(
            combineall=combineall,
            root=root,
            split_id=split_id,
            cuhk03_labeled=cuhk03_labeled,
            cuhk03_classic_split=cuhk03_classic_split,
            market1501_500k=market1501_500k
        )

        print('=> Loading train (source) dataset')
        trainset = self._build_trainset(dataset_kwargs)
        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams
        self.trainloader = self._build_trainloader(
            trainset, train_sampler, batch_size, num_instances, workers
        )

        print('=> Loading test (target) dataset')
        self._build_testloaders(dataset_kwargs, batch_size_test, workers)

        self._print_summary(len(trainset))

    def _build_trainset(self, dataset_kwargs):
        """Load every source dataset in 'train' mode and combine them with ``sum``."""
        trainsets = []
        for name in self.sources:
            trainsets.append(init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                **dataset_kwargs
            ))
        if not trainsets:
            # sum([]) would return the int 0 and fail later with an obscure AttributeError.
            raise ValueError('At least one source dataset must be specified')
        # sum() starts from 0, so this relies on the dataset type supporting
        # both ``0 + dataset`` and ``dataset + dataset``.
        return sum(trainsets)

    def _build_trainloader(self, trainset, train_sampler, batch_size, num_instances, workers):
        """Return the training DataLoader; iteration order comes from the sampler."""
        sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )
        return torch.utils.data.DataLoader(
            trainset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

    def _make_eval_loader(self, dataset, batch_size, workers):
        """Return an unshuffled DataLoader that keeps the final partial batch."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

    def _build_testloaders(self, dataset_kwargs, batch_size, workers):
        """Populate ``self.testloader`` and ``self.testdataset`` for every target name."""
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                **dataset_kwargs
            )
            self.testloader[name]['query'] = self._make_eval_loader(queryset, batch_size, workers)

            # build gallery loader
            galleryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                verbose=False,
                **dataset_kwargs
            )
            self.testloader[name]['gallery'] = self._make_eval_loader(galleryset, batch_size, workers)

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

    def _print_summary(self, num_train_images):
        """Print a summary of the loaded training/test configuration."""
        print('\n')
        print(' **************** Summary ****************')
        print(' train : {}'.format(self.sources))
        print(' # train datasets : {}'.format(len(self.sources)))
        print(' # train ids : {}'.format(self.num_train_pids))
        print(' # train images : {}'.format(num_train_images))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test : {}'.format(self.targets))
        print(' *****************************************')
        print('\n')
|
||
|
|
||
|
|
||
|
class VideoDataManager(DataManager):
    """Data manager for video (tracklet) based re-ID datasets.

    Builds one training DataLoader over all ``sources`` (combined with
    ``sum``) and, for every name in ``targets``, a query and a gallery
    DataLoader stored in ``self.testloader[name]``. The raw query/gallery
    entries are kept in ``self.testdataset[name]``.

    Args:
        root (str): root directory passed to ``init_video_dataset``.
        sources, targets, height, width, random_erase, color_jitter,
            color_aug, use_cpu: forwarded to :class:`DataManager`.
        split_id (int): passed to ``init_video_dataset``.
        combineall (bool): passed to ``init_video_dataset``.
        batch_size (int): batch size of the training loader. Also used for
            the query/gallery loaders unless ``batch_size_test`` is given.
        workers (int): ``num_workers`` of every DataLoader.
        num_instances (int): passed to ``build_train_sampler``.
        train_sampler: passed to ``build_train_sampler``.
        seq_len (int): passed to ``init_video_dataset``.
        sample_method (str): passed to ``init_video_dataset``.
        batch_size_test (int, optional): batch size of the query/gallery
            loaders. Defaults to ``batch_size``, which was the fixed behaviour.

    Raises:
        ValueError: if no source dataset is loaded.
    """

    def __init__(self, root, sources, targets=None, height=256, width=128, random_erase=False,
                 color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
                 batch_size=32, workers=4, num_instances=4, train_sampler=None,
                 seq_len=15, sample_method='evenly', batch_size_test=None):
        super(VideoDataManager, self).__init__(sources, targets=targets, height=height, width=width,
                                               random_erase=random_erase, color_jitter=color_jitter,
                                               color_aug=color_aug, use_cpu=use_cpu)

        # Test loaders used to be hard-wired to the training batch size; keep
        # that as the default so existing callers see no change.
        if batch_size_test is None:
            batch_size_test = batch_size

        # Keyword options passed unchanged to every init_video_dataset call.
        dataset_kwargs = dict(
            combineall=combineall,
            root=root,
            split_id=split_id,
            seq_len=seq_len,
            sample_method=sample_method
        )

        print('=> Loading train (source) dataset')
        trainset = self._build_trainset(dataset_kwargs)
        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams
        self.trainloader = self._build_trainloader(
            trainset, train_sampler, batch_size, num_instances, workers
        )

        print('=> Loading test (target) dataset')
        self._build_testloaders(dataset_kwargs, batch_size_test, workers)

        self._print_summary(len(trainset))

    def _build_trainset(self, dataset_kwargs):
        """Load every source dataset in 'train' mode and combine them with ``sum``."""
        trainsets = []
        for name in self.sources:
            trainsets.append(init_video_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                **dataset_kwargs
            ))
        if not trainsets:
            # sum([]) would return the int 0 and fail later with an obscure AttributeError.
            raise ValueError('At least one source dataset must be specified')
        # sum() starts from 0, so this relies on the dataset type supporting
        # both ``0 + dataset`` and ``dataset + dataset``.
        return sum(trainsets)

    def _build_trainloader(self, trainset, train_sampler, batch_size, num_instances, workers):
        """Return the training DataLoader; iteration order comes from the sampler."""
        sampler = build_train_sampler(
            trainset.train, train_sampler,
            batch_size=batch_size,
            num_instances=num_instances
        )
        return torch.utils.data.DataLoader(
            trainset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

    def _make_eval_loader(self, dataset, batch_size, workers):
        """Return an unshuffled DataLoader that keeps the final partial batch."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

    def _build_testloaders(self, dataset_kwargs, batch_size, workers):
        """Populate ``self.testloader`` and ``self.testdataset`` for every target name."""
        self.testloader = {name: {'query': None, 'gallery': None} for name in self.targets}
        self.testdataset = {name: {'query': None, 'gallery': None} for name in self.targets}

        for name in self.targets:
            # build query loader
            queryset = init_video_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                **dataset_kwargs
            )
            self.testloader[name]['query'] = self._make_eval_loader(queryset, batch_size, workers)

            # build gallery loader
            galleryset = init_video_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                verbose=False,
                **dataset_kwargs
            )
            self.testloader[name]['gallery'] = self._make_eval_loader(galleryset, batch_size, workers)

            self.testdataset[name]['query'] = queryset.query
            self.testdataset[name]['gallery'] = galleryset.gallery

    def _print_summary(self, num_train_tracklets):
        """Print a summary of the loaded training/test configuration."""
        print('\n')
        print(' **************** Summary ****************')
        print(' train : {}'.format(self.sources))
        print(' # train datasets : {}'.format(len(self.sources)))
        print(' # train ids : {}'.format(self.num_train_pids))
        print(' # train tracklets : {}'.format(num_train_tracklets))
        print(' # train cameras : {}'.format(self.num_train_cams))
        print(' test : {}'.format(self.targets))
        print(' *****************************************')
        print('\n')
|