deep-person-reid/torchreid/data/datamanager.py

530 lines
18 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-21 20:59:54 +08:00
import torch
from torchreid.data.sampler import build_train_sampler
from torchreid.data.datasets import init_image_dataset, init_video_dataset
2019-12-01 10:35:44 +08:00
from torchreid.data.transforms import build_transforms
2019-03-21 20:59:54 +08:00
class DataManager(object):
    r"""Base data manager.

    Normalises the source/target dataset names and builds the train/test
    transforms. Concrete subclasses are responsible for loading the data and
    setting ``_num_train_pids`` / ``_num_train_cams``.

    Args:
        sources (str or list): source dataset(s).
        targets (str or list, optional): target dataset(s). If not given,
            it equals to ``sources``.
        height (int, optional): target image height. Default is 256.
        width (int, optional): target image width. Default is 128.
        transforms (str or list of str, optional): transformations applied to model training.
            Default is 'random_flip'.
        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
        norm_std (list or None, optional): data std. Default is None (use imagenet std).
        use_gpu (bool, optional): use gpu. Default is False. GPU use is only
            enabled when CUDA is also available.

    Raises:
        ValueError: if ``sources`` is None or empty.
    """

    def __init__(
        self,
        sources=None,
        targets=None,
        height=256,
        width=128,
        transforms='random_flip',
        norm_mean=None,
        norm_std=None,
        use_gpu=False
    ):
        if sources is None:
            raise ValueError('sources must not be None')

        # Accept a single dataset name as a convenience; downstream code
        # always iterates over lists of names.
        if isinstance(sources, str):
            sources = [sources]
        if not sources:
            # An empty list would otherwise surface much later as an opaque
            # error when subclasses merge zero training sets with ``sum``.
            raise ValueError('sources must not be empty')

        if targets is None:
            targets = sources
        if isinstance(targets, str):
            targets = [targets]

        self.sources = sources
        self.targets = targets
        self.height = height
        self.width = width

        self.transform_tr, self.transform_te = build_transforms(
            self.height,
            self.width,
            transforms=transforms,
            norm_mean=norm_mean,
            norm_std=norm_std
        )

        # Only use the GPU if it was requested and CUDA is actually present.
        self.use_gpu = (torch.cuda.is_available() and use_gpu)

    @property
    def num_train_pids(self):
        """Returns the number of training person identities.

        ``_num_train_pids`` is set by subclasses; accessing this on a bare
        ``DataManager`` raises ``AttributeError``.
        """
        return self._num_train_pids

    @property
    def num_train_cams(self):
        """Returns the number of training cameras.

        ``_num_train_cams`` is set by subclasses; accessing this on a bare
        ``DataManager`` raises ``AttributeError``.
        """
        return self._num_train_cams

    def fetch_test_loaders(self, name):
        """Returns query and gallery of a test dataset, each containing
        tuples of (img_path(s), pid, camid).

        Despite its name, this returns the entries stored in
        ``self.test_dataset`` (not the ``DataLoader`` objects held in
        ``self.test_loader``). The name is kept for backward compatibility.

        Args:
            name (str): dataset name.

        Returns:
            tuple: ``(query, gallery)`` for the dataset ``name``.

        Raises:
            KeyError: if ``name`` is not one of the loaded target datasets.
        """
        entry = self.test_dataset[name]
        return entry['query'], entry['gallery']

    def preprocess_pil_img(self, img):
        """Transforms a PIL image to torch tensor for testing.

        Applies the test-time transform ``self.transform_te`` to ``img``.
        """
        return self.transform_te(img)
2019-03-21 20:59:54 +08:00
class ImageDataManager(DataManager):
    r"""Image data manager.

    Args:
        root (str): root path to datasets.
        sources (str or list): source dataset(s).
        targets (str or list, optional): target dataset(s). If not given,
            it equals to ``sources``.
        height (int, optional): target image height. Default is 256.
        width (int, optional): target image width. Default is 128.
        transforms (str or list of str, optional): transformations applied to model training.
            Default is 'random_flip'.
        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
        norm_std (list or None, optional): data std. Default is None (use imagenet std).
        use_gpu (bool, optional): use gpu. Default is True.
        split_id (int, optional): split id (*0-based*). Default is 0.
        combineall (bool, optional): combine train, query and gallery in a dataset for
            training. Default is False.
        load_train_targets (bool, optional): construct train-loader for target datasets.
            Default is False. This is useful for domain adaptation research.
        batch_size_train (int, optional): number of images in a training batch. Default is 32.
        batch_size_test (int, optional): number of images in a test batch. Default is 32.
        workers (int, optional): number of workers. Default is 4.
        num_instances (int, optional): number of instances per identity in a batch.
            Default is 4.
        train_sampler (str, optional): sampler. Default is RandomSampler.
        train_sampler_t (str, optional): sampler for target train loader. Default is RandomSampler.
        cuhk03_labeled (bool, optional): use cuhk03 labeled images.
            Default is False (default is to use detected images).
        cuhk03_classic_split (bool, optional): use the classic split in cuhk03.
            Default is False.
        market1501_500k (bool, optional): add 500K distractors to the gallery
            set in market1501. Default is False.

    Raises:
        ValueError: if ``sources`` is None or empty, or if
            ``load_train_targets`` is True and ``sources`` and ``targets``
            share at least one dataset.

    Examples::

        datamanager = torchreid.data.ImageDataManager(
            root='path/to/reid-data',
            sources='market1501',
            height=256,
            width=128,
            batch_size_train=32,
            batch_size_test=100
        )

        # return train loader of source data
        train_loader = datamanager.train_loader

        # return test loader of target data
        test_loader = datamanager.test_loader

        # return train loader of target data
        train_loader_t = datamanager.train_loader_t
    """
    data_type = 'image'

    def __init__(
        self,
        root='',
        sources=None,
        targets=None,
        height=256,
        width=128,
        transforms='random_flip',
        norm_mean=None,
        norm_std=None,
        use_gpu=True,
        split_id=0,
        combineall=False,
        load_train_targets=False,
        batch_size_train=32,
        batch_size_test=32,
        workers=4,
        num_instances=4,
        train_sampler='RandomSampler',
        train_sampler_t='RandomSampler',
        cuhk03_labeled=False,
        cuhk03_classic_split=False,
        market1501_500k=False
    ):

        super(ImageDataManager, self).__init__(
            sources=sources,
            targets=targets,
            height=height,
            width=width,
            transforms=transforms,
            norm_mean=norm_mean,
            norm_std=norm_std,
            use_gpu=use_gpu
        )

        if load_train_targets:
            # Validate before any (potentially slow) dataset loading. An
            # explicit raise is used instead of ``assert`` so the check is not
            # stripped when Python runs with ``-O``.
            if set(self.sources) & set(self.targets):
                raise ValueError(
                    'sources={} and targets={} must not have overlap'.format(
                        self.sources, self.targets
                    )
                )

        # Keyword arguments shared by every init_image_dataset call below.
        dataset_kwargs = dict(
            root=root,
            split_id=split_id,
            cuhk03_labeled=cuhk03_labeled,
            cuhk03_classic_split=cuhk03_classic_split,
            market1501_500k=market1501_500k
        )

        print('=> Loading train (source) dataset')
        trainset = self._load_trainset(
            self.sources, combineall, dataset_kwargs
        )
        self._num_train_pids = trainset.num_train_pids
        self._num_train_cams = trainset.num_train_cams
        self.train_loader = self._build_train_loader(
            trainset, train_sampler, batch_size_train, num_instances, workers
        )

        self.train_loader_t = None
        trainset_t = None
        if load_train_targets:
            print('=> Loading train (target) dataset')
            # combineall=False: only use the training data of the targets.
            trainset_t = self._load_trainset(
                self.targets, False, dataset_kwargs
            )
            self.train_loader_t = self._build_train_loader(
                trainset_t, train_sampler_t, batch_size_train, num_instances,
                workers
            )

        print('=> Loading test (target) dataset')
        self.test_loader = {
            name: {
                'query': None,
                'gallery': None
            }
            for name in self.targets
        }
        self.test_dataset = {
            name: {
                'query': None,
                'gallery': None
            }
            for name in self.targets
        }

        for name in self.targets:
            # build query loader
            queryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='query',
                combineall=combineall,
                **dataset_kwargs
            )
            self.test_loader[name]['query'] = self._build_test_loader(
                queryset, batch_size_test, workers
            )

            # build gallery loader
            galleryset = init_image_dataset(
                name,
                transform=self.transform_te,
                mode='gallery',
                combineall=combineall,
                verbose=False,
                **dataset_kwargs
            )
            self.test_loader[name]['gallery'] = self._build_test_loader(
                galleryset, batch_size_test, workers
            )

            self.test_dataset[name]['query'] = queryset.query
            self.test_dataset[name]['gallery'] = galleryset.gallery

        self._print_summary(trainset, trainset_t)

    def _load_trainset(self, names, combineall, dataset_kwargs):
        """Loads the train split of every dataset in ``names`` and merges them.

        Merging uses the builtin ``sum``, which starts from ``0``; this assumes
        the dataset class supports ``0 + dataset`` (e.g. via ``__radd__``).
        """
        datasets = [
            init_image_dataset(
                name,
                transform=self.transform_tr,
                mode='train',
                combineall=combineall,
                **dataset_kwargs
            ) for name in names
        ]
        return sum(datasets)

    def _build_train_loader(
        self, dataset, sampler_name, batch_size, num_instances, workers
    ):
        """Builds a training DataLoader whose ordering comes from the sampler."""
        return torch.utils.data.DataLoader(
            dataset,
            sampler=build_train_sampler(
                dataset.train,
                sampler_name,
                batch_size=batch_size,
                num_instances=num_instances
            ),
            batch_size=batch_size,
            # Must stay False: DataLoader rejects shuffle=True with a sampler.
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

    def _build_test_loader(self, dataset, batch_size, workers):
        """Builds a sequential, non-dropping DataLoader for evaluation."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )

    def _print_summary(self, trainset, trainset_t):
        """Prints statistics of the loaded source (and optional target) data."""
        print('\n')
        print(' **************** Summary ****************')
        print(' source : {}'.format(self.sources))
        print(' # source datasets : {}'.format(len(self.sources)))
        print(' # source ids : {}'.format(self.num_train_pids))
        print(' # source images : {}'.format(len(trainset)))
        print(' # source cameras : {}'.format(self.num_train_cams))
        if trainset_t is not None:
            print(
                ' # target images : {} (unlabeled)'.format(len(trainset_t))
            )
        print(' target : {}'.format(self.targets))
        print(' *****************************************')
        print('\n')
class VideoDataManager(DataManager):
    r"""Data manager for video (tracklet-based) re-id datasets.

    Builds one training loader over the merged ``sources`` and, for each
    dataset in ``targets``, a query loader and a gallery loader.

    Args:
        root (str): root path to datasets.
        sources (str or list): source dataset(s).
        targets (str or list, optional): target dataset(s). Falls back to
            ``sources`` when not given.
        height (int, optional): target image height. Default is 256.
        width (int, optional): target image width. Default is 128.
        transforms (str or list of str, optional): transformations applied
            during training. Default is 'random_flip'.
        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
        norm_std (list or None, optional): data std. Default is None (use imagenet std).
        use_gpu (bool, optional): use gpu. Default is True.
        split_id (int, optional): split id (*0-based*). Default is 0.
        combineall (bool, optional): merge train, query and gallery of a
            dataset for training. Default is False.
        batch_size_train (int, optional): tracklets per training batch. Default is 3.
        batch_size_test (int, optional): tracklets per test batch. Default is 3.
        workers (int, optional): number of data-loading workers. Default is 4.
        num_instances (int, optional): instances per identity in a batch.
            Default is 4.
        train_sampler (str, optional): name of the training sampler.
            Default is RandomSampler.
        seq_len (int, optional): number of images taken from each tracklet.
            Default is 15.
        sample_method (str, optional): how images are picked inside a
            tracklet. One of "evenly" (default), "random" or "all". "evenly"
            and "random" take ``seq_len`` images; "all" takes every image of
            the tracklet, which requires a batch size of 1.

    Examples::

        datamanager = torchreid.data.VideoDataManager(
            root='path/to/reid-data',
            sources='mars',
            height=256,
            width=128,
            batch_size_train=3,
            batch_size_test=3,
            seq_len=15,
            sample_method='evenly'
        )

        # training loader over the source data
        train_loader = datamanager.train_loader

        # query/gallery loaders of the target data
        test_loader = datamanager.test_loader

    .. note::
        Only image-like training is implemented: each image of a sampled
        tracklet is transformed independently. For tracklet-aware training,
        the video transforms must be changed so that every function applies
        the same operation to all images of a tracklet.
    """
    data_type = 'video'

    def __init__(
        self,
        root='',
        sources=None,
        targets=None,
        height=256,
        width=128,
        transforms='random_flip',
        norm_mean=None,
        norm_std=None,
        use_gpu=True,
        split_id=0,
        combineall=False,
        batch_size_train=3,
        batch_size_test=3,
        workers=4,
        num_instances=4,
        train_sampler='RandomSampler',
        seq_len=15,
        sample_method='evenly'
    ):
        super(VideoDataManager, self).__init__(
            sources=sources,
            targets=targets,
            height=height,
            width=width,
            transforms=transforms,
            norm_mean=norm_mean,
            norm_std=norm_std,
            use_gpu=use_gpu
        )

        # Options forwarded unchanged to every init_video_dataset call.
        common = {
            'root': root,
            'split_id': split_id,
            'seq_len': seq_len,
            'sample_method': sample_method,
        }

        print('=> Loading train (source) dataset')
        train_data = sum(
            [
                init_video_dataset(
                    src,
                    transform=self.transform_tr,
                    mode='train',
                    combineall=combineall,
                    **common
                ) for src in self.sources
            ]
        )
        self._num_train_pids = train_data.num_train_pids
        self._num_train_cams = train_data.num_train_cams

        # Keep the sampler object under its own name instead of rebinding the
        # ``train_sampler`` argument (which holds the sampler's name).
        sampler_obj = build_train_sampler(
            train_data.train,
            train_sampler,
            batch_size=batch_size_train,
            num_instances=num_instances
        )
        self.train_loader = torch.utils.data.DataLoader(
            train_data,
            sampler=sampler_obj,
            batch_size=batch_size_train,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=True
        )

        print('=> Loading test (target) dataset')
        self.test_loader = {
            tgt: {'query': None, 'gallery': None}
            for tgt in self.targets
        }
        self.test_dataset = {
            tgt: {'query': None, 'gallery': None}
            for tgt in self.targets
        }

        # Query first, then gallery; the gallery is loaded quietly.
        split_options = (('query', {}), ('gallery', {'verbose': False}))
        for tgt in self.targets:
            loaded = {}
            for split, extra in split_options:
                data = init_video_dataset(
                    tgt,
                    transform=self.transform_te,
                    mode=split,
                    combineall=combineall,
                    **dict(common, **extra)
                )
                self.test_loader[tgt][split] = self._eval_loader(
                    data, batch_size_test, workers
                )
                loaded[split] = data
            self.test_dataset[tgt]['query'] = loaded['query'].query
            self.test_dataset[tgt]['gallery'] = loaded['gallery'].gallery

        report = [
            '\n',
            ' **************** Summary ****************',
            ' source : {}'.format(self.sources),
            ' # source datasets : {}'.format(len(self.sources)),
            ' # source ids : {}'.format(self.num_train_pids),
            ' # source tracklets : {}'.format(len(train_data)),
            ' # source cameras : {}'.format(self.num_train_cams),
            ' target : {}'.format(self.targets),
            ' *****************************************',
            '\n',
        ]
        for row in report:
            print(row)

    def _eval_loader(self, data, batch_size, workers):
        """Wraps ``data`` in a sequential DataLoader that keeps partial batches."""
        return torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=self.use_gpu,
            drop_last=False
        )