deep-person-reid/torchreid/data_manager.py

from __future__ import absolute_import
from __future__ import print_function

from torch.utils.data import DataLoader

from .dataset_loader import ImageDataset, VideoDataset
from .datasets import init_imgreid_dataset, init_vidreid_dataset
from .transforms import build_transforms


class BaseDataManager(object):

    def return_dataloaders(self):
        return self.trainloader, self.testloader_dict


class ImageDataManager(BaseDataManager):
    """
    Image-ReID data manager
    """

    def __init__(self,
                 use_gpu,
                 train_names,
                 test_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 cuhk03_labeled=False,
                 cuhk03_classic_split=False
                 ):
        super(ImageDataManager, self).__init__()

        pin_memory = True if use_gpu else False

        transform_train = build_transforms(height, width, is_train=True)
        transform_test = build_transforms(height, width, is_train=False)

        self.train_names = train_names
        self.test_names = test_names
        self.train = []
        self.num_train_pids = 0
        self.num_train_cams = 0
print("=> Initializing TRAIN datasets")
for name in self.train_names:
dataset = init_imgreid_dataset(
root=root, name=name, split_id=split_id, cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split
)
for img_path, pid, camid in dataset.train:
pid += self.num_train_pids
camid += self.num_train_cams
self.train.append((img_path, pid, camid))
self.num_train_pids += dataset.num_train_pids
self.num_train_cams += dataset.num_train_cams
self.trainloader = DataLoader(
ImageDataset(self.train, transform=transform_train),
batch_size=train_batch_size, shuffle=True, num_workers=workers,
pin_memory=pin_memory, drop_last=True
)
print("=> Initializing TEST datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
for name in self.test_names:
dataset = init_imgreid_dataset(
root=root, name=name, split_id=split_id, cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split
)
self.testloader_dict[name]['query'] = DataLoader(
ImageDataset(dataset.query, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=workers,
pin_memory=pin_memory, drop_last=False
)
self.testloader_dict[name]['gallery'] = DataLoader(
ImageDataset(dataset.gallery, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=workers,
pin_memory=pin_memory, drop_last=False
)
print("\n")
print(" **************** Summary ****************")
print(" train names : {}".format(self.train_names))
print(" # train datasets : {}".format(len(self.train_names)))
print(" # train ids : {}".format(self.num_train_pids))
print(" # train images : {}".format(len(self.train)))
print(" # train cameras : {}".format(self.num_train_cams))
print(" test names : {}".format(self.test_names))
print(" *****************************************")
print("\n")


class VideoDataManager(BaseDataManager):
    """
    Video-ReID data manager
    """

    def __init__(self,
                 use_gpu,
                 train_names,
                 test_names,
                 root,
                 split_id=0,
                 height=256,
                 width=128,
                 train_batch_size=32,
                 test_batch_size=100,
                 workers=4,
                 seq_len=15,
                 sample='evenly',
                 image_training=True
                 ):
        super(VideoDataManager, self).__init__()

        pin_memory = True if use_gpu else False

        transform_train = build_transforms(height, width, is_train=True)
        transform_test = build_transforms(height, width, is_train=False)

        self.train_names = train_names
        self.test_names = test_names
        self.train = []
        self.num_train_pids = 0
        self.num_train_cams = 0
print("=> Initializing TRAIN datasets")
for name in self.train_names:
dataset = init_vidreid_dataset(root=root, name=name, split_id=split_id)
for img_paths, pid, camid in dataset.train:
pid += self.num_train_pids
camid += self.num_train_cams
if image_training:
# decompose tracklets into images
for img_path in img_paths:
self.train.append((img_path, pid, camid))
else:
self.train.append((img_paths, pid, camid))
self.num_train_pids += dataset.num_train_pids
self.num_train_cams += dataset.num_train_cams
if image_training:
# each batch has image data of shape (batch, channel, height, width)
self.trainloader = DataLoader(
ImageDataset(self.train, transform=transform_train),
batch_size=train_batch_size, shuffle=True, num_workers=workers,
pin_memory=pin_memory, drop_last=True
)
else:
# each batch has image data of shape (batch, seq_len, channel, height, width)
self.trainloader = DataLoader(
                VideoDataset(self.train, seq_len=seq_len, sample=sample, transform=transform_train),
                batch_size=train_batch_size, shuffle=True, num_workers=workers,
                pin_memory=pin_memory, drop_last=True
            )
print("=> Initializing TEST datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
for name in self.test_names:
dataset = init_vidreid_dataset(root=root, name=name, split_id=split_id)
self.testloader_dict[name]['query'] = DataLoader(
VideoDataset(dataset.query, seq_len=seq_len, sample=sample, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=workers,
pin_memory=pin_memory, drop_last=False,
)
self.testloader_dict[name]['gallery'] = DataLoader(
VideoDataset(dataset.gallery, seq_len=seq_len, sample=sample, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=workers,
pin_memory=pin_memory, drop_last=False,
)
print("\n")
print(" **************** Summary ****************")
print(" train names : {}".format(self.train_names))
print(" # train datasets : {}".format(len(self.train_names)))
print(" # train ids : {}".format(self.num_train_pids))
if image_training:
print(" # train images : {}".format(len(self.train)))
else:
print(" # train tracklets: {}".format(len(self.train)))
print(" # train cameras : {}".format(self.num_train_cams))
print(" test names : {}".format(self.test_names))
print(" *****************************************")
print("\n")