2018-11-06 05:19:27 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
2018-11-07 23:36:49 +08:00
|
|
|
from .dataset_loader import ImageDataset, VideoDataset
|
2018-11-06 05:19:27 +08:00
|
|
|
from .datasets import init_imgreid_dataset, init_vidreid_dataset
|
2018-11-08 01:09:23 +08:00
|
|
|
from .transforms import build_transforms
|
2018-11-06 05:19:27 +08:00
|
|
|
|
|
|
|
|
2018-11-08 04:48:21 +08:00
|
|
|
class BaseDataManager(object):
    """
    Shared read-only accessors for the data managers in this module.

    This class does not build anything itself. Subclass constructors are
    expected to set ``_num_train_pids``, ``_num_train_cams``, ``trainloader``,
    ``testloader_dict`` and ``testdataset_dict``. The members below expose
    those attributes.
    """

    @property
    def num_train_pids(self):
        """Number of training person IDs, as accumulated by the subclass."""
        return self._num_train_pids

    @property
    def num_train_cams(self):
        """Number of training camera IDs, as accumulated by the subclass."""
        return self._num_train_cams

    def return_dataloaders(self):
        """
        Return the training loader together with the per-target test loaders.

        Returns:
            tuple: ``(trainloader, testloader_dict)``. In the subclasses of this
            module, ``testloader_dict`` maps each target dataset name to a dict
            with ``'query'`` and ``'gallery'`` loaders.
        """
        loaders = (self.trainloader, self.testloader_dict)
        return loaders

    def return_testdataset_by_name(self, name):
        """
        Return the raw ``(query, gallery)`` sample lists for one target dataset.

        Args:
            name: key of ``testdataset_dict`` (a target dataset name).

        Returns:
            tuple: ``(query, gallery)`` as stored for ``name``.

        Raises:
            KeyError: if ``name`` is not a key of ``testdataset_dict``.
        """
        splits = self.testdataset_dict[name]
        return splits['query'], splits['gallery']
|
|
|
|
|
2018-11-08 04:48:21 +08:00
|
|
|
|
|
|
|
class ImageDataManager(BaseDataManager):
    """
    Image-ReID data manager.

    Builds one shuffled training DataLoader over the concatenated training
    splits of all ``source_names`` datasets. It also builds one unshuffled
    query/gallery DataLoader pair for each dataset in ``target_names``.

    While the sources are merged, each dataset's person IDs and camera IDs are
    shifted by the totals accumulated from the sources merged before it. Labels
    coming from different datasets therefore do not overlap.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 cuhk03_labeled=False,
                 cuhk03_classic_split=False
                 ):
        super(ImageDataManager, self).__init__()

        # Remember every constructor argument as a public attribute.
        self.use_gpu = use_gpu
        self.source_names = source_names
        self.target_names = target_names
        self.root = root
        self.split_id = split_id
        self.height = height
        self.width = width
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.workers = workers
        self.cuhk03_labeled = cuhk03_labeled
        self.cuhk03_classic_split = cuhk03_classic_split
        # Request pinned host memory only when running on a GPU.
        self.pin_memory = bool(self.use_gpu)

        train_tf = build_transforms(self.height, self.width, is_train=True)
        test_tf = build_transforms(self.height, self.width, is_train=False)

        def load_dataset(dataset_name):
            # Sources and targets are created with the same shared options.
            return init_imgreid_dataset(
                root=self.root,
                name=dataset_name,
                split_id=self.split_id,
                cuhk03_labeled=self.cuhk03_labeled,
                cuhk03_classic_split=self.cuhk03_classic_split,
            )

        def make_eval_loader(samples):
            # Evaluation keeps sample order and never drops a partial batch.
            return DataLoader(
                ImageDataset(samples, transform=test_tf),
                batch_size=self.test_batch_size,
                shuffle=False,
                num_workers=self.workers,
                pin_memory=self.pin_memory,
                drop_last=False,
            )

        # ---- training data (sources) ----
        print("=> Initializing TRAIN (source) datasets")
        self.train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for source in self.source_names:
            dataset = load_dataset(source)
            pid_offset = self._num_train_pids
            cam_offset = self._num_train_cams
            # Shift labels past everything merged so far.
            self.train.extend(
                (img_path, pid + pid_offset, camid + cam_offset)
                for img_path, pid, camid in dataset.train
            )
            self._num_train_pids = pid_offset + dataset.num_train_pids
            self._num_train_cams = cam_offset + dataset.num_train_cams

        self.trainloader = DataLoader(
            ImageDataset(self.train, transform=train_tf),
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

        # ---- evaluation data (targets) ----
        print("=> Initializing TEST (target) datasets")
        self.testloader_dict = dict((t, {'query': None, 'gallery': None}) for t in self.target_names)
        self.testdataset_dict = dict((t, {'query': None, 'gallery': None}) for t in self.target_names)

        for target in self.target_names:
            dataset = load_dataset(target)

            loaders = self.testloader_dict[target]
            loaders['query'] = make_eval_loader(dataset.query)
            loaders['gallery'] = make_eval_loader(dataset.gallery)

            splits = self.testdataset_dict[target]
            splits['query'] = dataset.query
            splits['gallery'] = dataset.gallery

        # ---- summary ----
        summary_lines = (
            "\n",
            " **************** Summary ****************",
            " train names : {}".format(self.source_names),
            " # train datasets : {}".format(len(self.source_names)),
            " # train ids : {}".format(self._num_train_pids),
            " # train images : {}".format(len(self.train)),
            " # train cameras : {}".format(self._num_train_cams),
            " test names : {}".format(self.target_names),
            " *****************************************",
            "\n",
        )
        for line in summary_lines:
            print(line)
|
|
|
|
|
|
|
|
|
2018-11-08 04:48:21 +08:00
|
|
|
class VideoDataManager(BaseDataManager):
    """
    Video-ReID data manager.

    Builds one shuffled training DataLoader over the concatenated training
    tracklets of all ``source_names`` datasets. It also builds one unshuffled
    query/gallery DataLoader pair (served as tracklets through ``VideoDataset``)
    for each dataset in ``target_names``.

    While the sources are merged, each dataset's person IDs and camera IDs are
    shifted by the totals accumulated from the sources merged before it.

    Args (beyond those shared with ImageDataManager):
        seq_len: number of frames per tracklet passed to ``VideoDataset``.
        sample_method: frame-sampling strategy name passed to ``VideoDataset``.
        image_training: if True, every training tracklet is decomposed into
            individual frames and served through ``ImageDataset``. Otherwise
            whole tracklets are served through ``VideoDataset``.
    """

    def __init__(self,
                 use_gpu,
                 source_names,
                 target_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 seq_len=15,
                 sample_method='evenly',
                 image_training=True  # train the video-reid model with images rather than tracklets
                 ):
        super(VideoDataManager, self).__init__()
        self.use_gpu = use_gpu
        self.source_names = source_names
        self.target_names = target_names
        self.root = root
        self.split_id = split_id
        self.height = height
        self.width = width
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.workers = workers
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.image_training = image_training
        self.pin_memory = True if self.use_gpu else False

        # Build train and test transform functions
        transform_train = build_transforms(self.height, self.width, is_train=True)
        transform_test = build_transforms(self.height, self.width, is_train=False)

        print("=> Initializing TRAIN (source) datasets")
        self.train = []
        self._num_train_pids = 0
        self._num_train_cams = 0

        for name in self.source_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)

            for img_paths, pid, camid in dataset.train:
                # Shift labels past everything merged from earlier sources.
                pid += self._num_train_pids
                camid += self._num_train_cams
                if self.image_training:
                    # decompose tracklets into images
                    for img_path in img_paths:
                        self.train.append((img_path, pid, camid))
                else:
                    self.train.append((img_paths, pid, camid))

            self._num_train_pids += dataset.num_train_pids
            self._num_train_cams += dataset.num_train_cams

        if self.image_training:
            # each batch has image data of shape (batch, channel, height, width)
            train_set = ImageDataset(self.train, transform=transform_train)
        else:
            # each batch has image data of shape (batch, seq_len, channel, height, width)
            # Training tracklets must use the training transform (is_train=True),
            # the same as the image-training branch. transform_test is reserved
            # for the query/gallery loaders below.
            train_set = VideoDataset(
                self.train, seq_len=self.seq_len, sample_method=self.sample_method,
                transform=transform_train
            )

        self.trainloader = DataLoader(
            train_set,
            batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
            pin_memory=self.pin_memory, drop_last=True
        )

        def make_test_loader(tracklets):
            # Evaluation keeps tracklet order and never drops a partial batch.
            return DataLoader(
                VideoDataset(tracklets, seq_len=self.seq_len, sample_method=self.sample_method,
                             transform=transform_test),
                batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
                pin_memory=self.pin_memory, drop_last=False,
            )

        print("=> Initializing TEST (target) datasets")
        self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
        self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}

        for name in self.target_names:
            dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)

            self.testloader_dict[name]['query'] = make_test_loader(dataset.query)
            self.testloader_dict[name]['gallery'] = make_test_loader(dataset.gallery)

            self.testdataset_dict[name]['query'] = dataset.query
            self.testdataset_dict[name]['gallery'] = dataset.gallery

        print("\n")
        print(" **************** Summary ****************")
        print(" train names : {}".format(self.source_names))
        print(" # train datasets : {}".format(len(self.source_names)))
        print(" # train ids : {}".format(self._num_train_pids))
        if self.image_training:
            print(" # train images : {}".format(len(self.train)))
        else:
            print(" # train tracklets: {}".format(len(self.train)))
        print(" # train cameras : {}".format(self._num_train_cams))
        print(" test names : {}".format(self.target_names))
        print(" *****************************************")
        print("\n")
|