deep-person-reid/torchreid/data_manager.py

244 lines
9.8 KiB
Python
Raw Normal View History

2018-11-06 05:19:27 +08:00
from __future__ import absolute_import
from __future__ import print_function
from torch.utils.data import DataLoader
2018-11-07 23:36:49 +08:00
from .dataset_loader import ImageDataset, VideoDataset
2018-11-06 05:19:27 +08:00
from .datasets import init_imgreid_dataset, init_vidreid_dataset
2018-11-08 01:09:23 +08:00
from .transforms import build_transforms
2018-11-06 05:19:27 +08:00
class BaseDataManager(object):
    """Shared read-only accessors for the image/video data managers.

    This base class stores nothing itself. Subclasses must set
    ``_num_train_pids``, ``_num_train_cams``, ``trainloader``,
    ``testloader_dict`` and ``testdataset_dict`` before the accessors below
    are used; otherwise an ``AttributeError`` is raised.
    """

    @property
    def num_train_pids(self):
        """Number of training identities, as accumulated by the subclass."""
        pid_count = self._num_train_pids
        return pid_count

    @property
    def num_train_cams(self):
        """Number of training cameras, as accumulated by the subclass."""
        cam_count = self._num_train_cams
        return cam_count

    def return_dataloaders(self):
        """
        Return the training loader and the per-target test loader dictionary.

        Returns:
            tuple: ``(trainloader, testloader_dict)``, where
            ``testloader_dict[name]`` maps ``'query'``/``'gallery'`` to loaders.
        """
        loaders = (self.trainloader, self.testloader_dict)
        return loaders

    def return_testdataset_by_name(self, name):
        """
        Return the raw query and gallery sample lists of target dataset ``name``.

        For image datasets each entry is ``(img_path, pid, camid)``; video
        datasets presumably store a tracklet's image paths in the first slot
        (TODO confirm against the video dataset classes).

        Raises:
            KeyError: if ``name`` was not one of the target datasets.
        """
        splits = self.testdataset_dict[name]
        return splits['query'], splits['gallery']
class ImageDataManager(BaseDataManager):
    """
    Image-ReID data manager.

    Merges one or more source datasets into a single training set and builds
    one query loader and one gallery loader per target dataset.

    Args:
        use_gpu: truthy value; also enables ``pin_memory`` on every loader.
        source_names: dataset names used for training.
        target_names: dataset names used for testing.
        root: root directory handed to ``init_imgreid_dataset``.
        split_id, cuhk03_labeled, cuhk03_classic_split: forwarded unchanged to
            ``init_imgreid_dataset``.
        height, width: forwarded to ``build_transforms``.
        train_batch_size, test_batch_size, workers: DataLoader settings.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 cuhk03_labeled=False,
                 cuhk03_classic_split=False
                 ):
        super(ImageDataManager, self).__init__()
        self.use_gpu = use_gpu
        self.source_names = source_names
        self.target_names = target_names
        self.root = root
        self.split_id = split_id
        self.height = height
        self.width = width
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.workers = workers
        self.cuhk03_labeled = cuhk03_labeled
        self.cuhk03_classic_split = cuhk03_classic_split
        self.pin_memory = bool(self.use_gpu)

        train_tf = build_transforms(self.height, self.width, is_train=True)
        test_tf = build_transforms(self.height, self.width, is_train=False)

        print("=> Initializing TRAIN (source) datasets")
        self.train = []
        self._num_train_pids = 0
        self._num_train_cams = 0
        for source in self.source_names:
            ds = self._init_dataset(source)
            # Shift labels by the running totals so that ids from different
            # source datasets never collide in the merged training set.
            pid_offset = self._num_train_pids
            cam_offset = self._num_train_cams
            self.train.extend(
                (path, pid + pid_offset, camid + cam_offset)
                for path, pid, camid in ds.train
            )
            self._num_train_pids += ds.num_train_pids
            self._num_train_cams += ds.num_train_cams

        self.trainloader = self._make_loader(self.train, train_tf, is_train=True)

        print("=> Initializing TEST (target) datasets")
        self.testloader_dict = {t: {'query': None, 'gallery': None} for t in self.target_names}
        self.testdataset_dict = {t: {'query': None, 'gallery': None} for t in self.target_names}
        for target in self.target_names:
            ds = self._init_dataset(target)
            loaders = self.testloader_dict[target]
            loaders['query'] = self._make_loader(ds.query, test_tf, is_train=False)
            loaders['gallery'] = self._make_loader(ds.gallery, test_tf, is_train=False)
            raw = self.testdataset_dict[target]
            raw['query'] = ds.query
            raw['gallery'] = ds.gallery

        summary_lines = [
            "\n",
            " **************** Summary ****************",
            " train names : {}".format(self.source_names),
            " # train datasets : {}".format(len(self.source_names)),
            " # train ids : {}".format(self._num_train_pids),
            " # train images : {}".format(len(self.train)),
            " # train cameras : {}".format(self._num_train_cams),
            " test names : {}".format(self.target_names),
            " *****************************************",
            "\n",
        ]
        for line in summary_lines:
            print(line)

    def _init_dataset(self, name):
        """Instantiate image-reid dataset ``name`` with this manager's split settings."""
        return init_imgreid_dataset(
            root=self.root,
            name=name,
            split_id=self.split_id,
            cuhk03_labeled=self.cuhk03_labeled,
            cuhk03_classic_split=self.cuhk03_classic_split,
        )

    def _make_loader(self, samples, transform, is_train):
        """Wrap ``samples`` in an ImageDataset/DataLoader.

        Training loaders use the train batch size and shuffle and drop the last
        incomplete batch. Test loaders use the test batch size and do neither.
        """
        return DataLoader(
            ImageDataset(samples, transform=transform),
            batch_size=self.train_batch_size if is_train else self.test_batch_size,
            shuffle=is_train,
            num_workers=self.workers,
            pin_memory=self.pin_memory,
            drop_last=is_train,
        )
class VideoDataManager(BaseDataManager):
    """
    Video-ReID data manager.

    Merges one or more source video datasets into a single training set and
    builds one query loader and one gallery loader per target dataset.

    Args:
        use_gpu: truthy value; also enables ``pin_memory`` on every loader.
        source_names: dataset names used for training.
        target_names: dataset names used for testing.
        root: root directory handed to ``init_vidreid_dataset``.
        split_id: forwarded unchanged to ``init_vidreid_dataset``.
        height, width: forwarded to ``build_transforms``.
        train_batch_size, test_batch_size, workers: DataLoader settings.
        seq_len, sample_method: forwarded to every ``VideoDataset`` built here.
        image_training: if True, every training tracklet is decomposed into
            individual ``(img_path, pid, camid)`` samples and served through
            ``ImageDataset``. If False, whole tracklets are served through
            ``VideoDataset``.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 seq_len=15,
                 sample_method='evenly',
                 image_training=True  # train the video-reid model with images rather than tracklets
                 ):
        super(VideoDataManager, self).__init__()
        self.use_gpu = use_gpu
        self.source_names = source_names
        self.target_names = target_names
        self.root = root
        self.split_id = split_id
        self.height = height
        self.width = width
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.workers = workers
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.image_training = image_training
        self.pin_memory = bool(self.use_gpu)

        # Build train and test transform functions
        transform_train = build_transforms(self.height, self.width, is_train=True)
        transform_test = build_transforms(self.height, self.width, is_train=False)

        print("=> Initializing TRAIN (source) datasets")
        self.train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in self.source_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
            for img_paths, pid, camid in dataset.train:
                # Offset labels by the running totals so ids from different
                # source datasets do not collide in the merged training set.
                pid += self._num_train_pids
                camid += self._num_train_cams
                if self.image_training:
                    # decompose tracklets into images
                    for img_path in img_paths:
                        self.train.append((img_path, pid, camid))
                else:
                    self.train.append((img_paths, pid, camid))
            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        if self.image_training:
            # each batch has image data of shape (batch, channel, height, width)
            self.trainloader = DataLoader(
                ImageDataset(self.train, transform=transform_train),
                batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
                pin_memory=self.pin_memory, drop_last=True
            )
        else:
            # each batch has image data of shape (batch, seq_len, channel, height, width)
            # Training must use the training transform; the test transform was
            # previously used here, which silently disabled train-time augmentation.
            self.trainloader = DataLoader(
                VideoDataset(self.train, seq_len=self.seq_len, sample_method=self.sample_method,
                             transform=transform_train),
                batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
                pin_memory=self.pin_memory, drop_last=True
            )

        print("=> Initializing TEST (target) datasets")
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}

        for name in self.target_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)

            self.testloader_dict[name]['query'] = DataLoader(
                VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method,
                             transform=transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.pin_memory, drop_last=False,
            )

            self.testloader_dict[name]['gallery'] = DataLoader(
                VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method,
                             transform=transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.pin_memory, drop_last=False,
            )

            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print("\n")
        print(" **************** Summary ****************")
        print(" train names : {}".format(self.source_names))
        print(" # train datasets : {}".format(len(self.source_names)))
        print(" # train ids : {}".format(self._num_train_pids))
        if self.image_training:
            print(" # train images : {}".format(len(self.train)))
        else:
            print(" # train tracklets: {}".format(len(self.train)))
        print(" # train cameras : {}".format(self._num_train_cams))
        print(" test names : {}".format(self.target_names))
        print(" *****************************************")
        print("\n")