# deep-person-reid/torchreid/data_manager.py

from __future__ import absolute_import
from __future__ import print_function

from torch.utils.data import DataLoader

from .dataset_loader import ImageDataset, VideoDataset
from .datasets import init_imgreid_dataset, init_vidreid_dataset
from .transforms import build_transforms
from .samplers import build_train_sampler


class BaseDataManager(object):

    @property
    def num_train_pids(self):
        return self._num_train_pids

    @property
    def num_train_cams(self):
        return self._num_train_cams

    def return_dataloaders(self):
        """Return the train loader and the dictionary of test loaders."""
        return self.trainloader, self.testloader_dict

    def return_testdataset_by_name(self, name):
        """Return query and gallery, each containing a list of (img_path, pid, camid)."""
        return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
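
    # Illustrative sketch (not part of the API): `dm` is a hypothetical
    # subclass instance, and 'market1501' assumes that dataset was listed
    # in target_names:
    #     query, gallery = dm.return_testdataset_by_name('market1501')
    #     # each element is an (img_path, pid, camid) tuple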


class ImageDataManager(BaseDataManager):
    """Image-ReID data manager."""

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root='data',
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 train_sampler='',
                 augdata_re=False,  # use random erasing for data augmentation
                 num_instances=4,  # number of instances per identity (for RandomIdentitySampler)
                 cuhk03_labeled=False,  # use cuhk03's labeled or detected images
                 cuhk03_classic_split=False,  # use cuhk03's classic split or the 767/700 split
                 market1501_500k=False,  # add 500k distractors to the gallery set for market1501
                 ):
        super(ImageDataManager, self).__init__()

        print('=> Initializing TRAIN (source) datasets')
        train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in source_names:
            dataset = init_imgreid_dataset(
                root=root, name=name, split_id=split_id, cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split, market1501_500k=market1501_500k
            )

            for img_path, pid, camid in dataset.train:
                # offset pid and camid so identities and cameras from different
                # source datasets do not collide
                pid += self._num_train_pids
                camid += self._num_train_cams
                train.append((img_path, pid, camid))

            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        transform_train, transform_test = build_transforms(height, width, augdata_re=augdata_re)

        train_sampler = build_train_sampler(
            train, train_sampler,
            train_batch_size=train_batch_size,
            num_instances=num_instances,
        )
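
        # shuffle stays False here because shuffling is delegated to
        # train_sampler; PyTorch's DataLoader rejects shuffle=True
        # combined with an explicit sampler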
        self.trainloader = DataLoader(
            ImageDataset(train, transform=transform_train), sampler=train_sampler,
            batch_size=train_batch_size, shuffle=False, num_workers=workers,
            pin_memory=use_gpu, drop_last=True
        )

        print('=> Initializing TEST (target) datasets')
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}

        for name in target_names:
            dataset = init_imgreid_dataset(
                root=root, name=name, split_id=split_id, cuhk03_labeled=cuhk03_labeled,
                cuhk03_classic_split=cuhk03_classic_split, market1501_500k=market1501_500k
            )

            self.testloader_dict[name]['query'] = DataLoader(
                ImageDataset(dataset.query, transform=transform_test),
                batch_size=test_batch_size, shuffle=False, num_workers=workers,
                pin_memory=use_gpu, drop_last=False
            )

            self.testloader_dict[name]['gallery'] = DataLoader(
                ImageDataset(dataset.gallery, transform=transform_test),
                batch_size=test_batch_size, shuffle=False, num_workers=workers,
                pin_memory=use_gpu, drop_last=False
            )

            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train names      : {}'.format(source_names))
        print('  # train datasets : {}'.format(len(source_names)))
        print('  # train ids      : {}'.format(self.num_train_pids))
        print('  # train images   : {}'.format(len(train)))
        print('  # train cameras  : {}'.format(self.num_train_cams))
        print('  test names       : {}'.format(target_names))
        print('  *****************************************')
        print('\n')
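

# Illustrative usage sketch (not part of the module; 'market1501' is an
# assumption and must be available under ./data):
#
#     dm = ImageDataManager(
#         use_gpu=True,
#         source_names=['market1501'],
#         target_names=['market1501'],
#     )
#     trainloader, testloader_dict = dm.return_dataloaders()
#     for batch in trainloader:
#         ...  # each batch holds images of shape (batch, channel, height, width)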


class VideoDataManager(BaseDataManager):
    """Video-ReID data manager."""

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root='data',
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 train_sampler='',
                 augdata_re=False,  # use random erasing for data augmentation
                 num_instances=4,  # number of instances per identity (for RandomIdentitySampler)
                 seq_len=15,  # number of images sampled per tracklet
                 sample_method='evenly',  # how images are sampled from each tracklet
                 image_training=True  # train the video-reid model with images rather than tracklets
                 ):
        super(VideoDataManager, self).__init__()

        print('=> Initializing TRAIN (source) datasets')
        train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in source_names:
            dataset = init_vidreid_dataset(root=root, name=name, split_id=split_id)

            for img_paths, pid, camid in dataset.train:
                # offset pid and camid so identities and cameras from different
                # source datasets do not collide
                pid += self._num_train_pids
                camid += self._num_train_cams
                if image_training:
                    # decompose tracklets into images
                    for img_path in img_paths:
                        train.append((img_path, pid, camid))
                else:
                    train.append((img_paths, pid, camid))

            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        transform_train, transform_test = build_transforms(height, width, augdata_re=augdata_re)

        train_sampler = build_train_sampler(
            train, train_sampler,
            train_batch_size=train_batch_size,
            num_instances=num_instances,
        )

        if image_training:
            # each batch has image data of shape (batch, channel, height, width)
            self.trainloader = DataLoader(
                ImageDataset(train, transform=transform_train), sampler=train_sampler,
                batch_size=train_batch_size, shuffle=False, num_workers=workers,
                pin_memory=use_gpu, drop_last=True
            )
        else:
            # each batch has image data of shape (batch, seq_len, channel, height, width)
            # note: this requires new training scripts
            self.trainloader = DataLoader(
                VideoDataset(train, seq_len=seq_len, sample_method=sample_method, transform=transform_train),
                batch_size=train_batch_size, shuffle=True, num_workers=workers,
                pin_memory=use_gpu, drop_last=True
            )

        print('=> Initializing TEST (target) datasets')
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}

        for name in target_names:
            dataset = init_vidreid_dataset(root=root, name=name, split_id=split_id)

            self.testloader_dict[name]['query'] = DataLoader(
                VideoDataset(dataset.query, seq_len=seq_len, sample_method=sample_method, transform=transform_test),
                batch_size=test_batch_size, shuffle=False, num_workers=workers,
                pin_memory=use_gpu, drop_last=False,
            )

            self.testloader_dict[name]['gallery'] = DataLoader(
                VideoDataset(dataset.gallery, seq_len=seq_len, sample_method=sample_method, transform=transform_test),
                batch_size=test_batch_size, shuffle=False, num_workers=workers,
                pin_memory=use_gpu, drop_last=False,
            )

            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print('\n')
        print('  **************** Summary ****************')
        print('  train names      : {}'.format(source_names))
        print('  # train datasets : {}'.format(len(source_names)))
        print('  # train ids      : {}'.format(self.num_train_pids))
        if image_training:
            print('  # train images   : {}'.format(len(train)))
        else:
            print('  # train tracklets: {}'.format(len(train)))
        print('  # train cameras  : {}'.format(self.num_train_cams))
        print('  test names       : {}'.format(target_names))
        print('  *****************************************')
        print('\n')
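

# Illustrative usage sketch (not part of the module; 'mars' is an assumption
# and must be available under ./data):
#
#     dm = VideoDataManager(
#         use_gpu=True,
#         source_names=['mars'],
#         target_names=['mars'],
#         image_training=True,  # train loader yields (batch, channel, height, width)
#     )
#     trainloader, testloader_dict = dm.return_dataloaders()
#     query_loader = testloader_dict['mars']['query']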