source/target -> source_names/target_names

pull/119/head
KaiyangZhou 2018-11-08 21:40:18 +00:00
parent 8bfdbe24d1
commit e17bbcddf3
1 changed files with 31 additions and 31 deletions

View File

@ -38,8 +38,8 @@ class ImageDataManager(BaseDataManager):
def __init__(self,
use_gpu,
train_names,
test_names,
source_names,
target_names,
root,
split_id=0,
height=256,
@ -52,8 +52,8 @@ class ImageDataManager(BaseDataManager):
):
super(ImageDataManager, self).__init__()
self.use_gpu = use_gpu
self.train_names = train_names
self.test_names = test_names
self.source_names = source_names
self.target_names = target_names
self.root = root
self.split_id = split_id
self.height = height
@ -69,12 +69,12 @@ class ImageDataManager(BaseDataManager):
transform_train = build_transforms(self.height, self.width, is_train=True)
transform_test = build_transforms(self.height, self.width, is_train=False)
print("=> Initializing TRAIN datasets")
print("=> Initializing TRAIN (source) datasets")
self.train = []
self._num_train_pids = 0
self._num_train_cams = 0
for name in self.train_names:
for name in self.source_names:
dataset = init_imgreid_dataset(
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
cuhk03_classic_split=self.cuhk03_classic_split
@ -94,11 +94,11 @@ class ImageDataManager(BaseDataManager):
pin_memory=self.pin_memory, drop_last=True
)
print("=> Initializing TEST datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
print("=> Initializing TEST (target) datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
for name in self.test_names:
for name in self.target_names:
dataset = init_imgreid_dataset(
root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled,
cuhk03_classic_split=self.cuhk03_classic_split
@ -121,12 +121,12 @@ class ImageDataManager(BaseDataManager):
print("\n")
print(" **************** Summary ****************")
print(" train names : {}".format(self.train_names))
print(" # train datasets : {}".format(len(self.train_names)))
print(" train names : {}".format(self.source_names))
print(" # train datasets : {}".format(len(self.source_names)))
print(" # train ids : {}".format(self._num_train_pids))
print(" # train images : {}".format(len(self.train)))
print(" # train cameras : {}".format(self._num_train_cams))
print(" test names : {}".format(self.test_names))
print(" test names : {}".format(self.target_names))
print(" *****************************************")
print("\n")
@ -138,8 +138,8 @@ class VideoDataManager(BaseDataManager):
def __init__(self,
use_gpu,
train_names,
test_names,
source_names,
target_names,
root,
split_id=0,
height=256,
@ -148,13 +148,13 @@ class VideoDataManager(BaseDataManager):
test_batch_size=100,
workers=4,
seq_len=15,
sample='evenly',
sample_method='evenly',
image_training=True # train the video-reid model with images rather than tracklets
):
super(VideoDataManager, self).__init__()
self.use_gpu = use_gpu
self.train_names = train_names
self.test_names = test_names
self.source_names = source_names
self.target_names = target_names
self.root = root
self.split_id = split_id
self.height = height
@ -163,7 +163,7 @@ class VideoDataManager(BaseDataManager):
self.test_batch_size = test_batch_size
self.workers = workers
self.seq_len = seq_len
self.sample = sample
self.sample_method = sample_method
self.image_training = image_training
self.pin_memory = True if self.use_gpu else False
@ -171,12 +171,12 @@ class VideoDataManager(BaseDataManager):
transform_train = build_transforms(self.height, self.width, is_train=True)
transform_test = build_transforms(self.height, self.width, is_train=False)
print("=> Initializing TRAIN datasets")
print("=> Initializing TRAIN (source) datasets")
self.train = []
self._num_train_pids = 0
self._num_train_cams = 0
for name in self.train_names:
for name in self.source_names:
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
for img_paths, pid, camid in dataset.train:
@ -202,26 +202,26 @@ class VideoDataManager(BaseDataManager):
else:
# each batch has image data of shape (batch, seq_len, channel, height, width)
self.trainloader = DataLoader(
VideoDataset(self.train, seq_len=self.seq_len, sample=self.sample, transform=transform_test),
VideoDataset(self.train, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test),
batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
pin_memory=self.pin_memory, drop_last=True
)
print("=> Initializing TEST datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.test_names}
print("=> Initializing TEST (target) datasets")
self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
for name in self.test_names:
for name in self.target_names:
dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id)
self.testloader_dict[name]['query'] = DataLoader(
VideoDataset(dataset.query, seq_len=self.seq_len, sample=self.sample, transform=transform_test),
VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test),
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
pin_memory=self.pin_memory, drop_last=False,
)
self.testloader_dict[name]['gallery'] = DataLoader(
VideoDataset(dataset.gallery, seq_len=self.seq_len, sample=self.sample, transform=transform_test),
VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test),
batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
pin_memory=self.pin_memory, drop_last=False,
)
@ -231,14 +231,14 @@ class VideoDataManager(BaseDataManager):
print("\n")
print(" **************** Summary ****************")
print(" train names : {}".format(self.train_names))
print(" # train datasets : {}".format(len(self.train_names)))
print(" train names : {}".format(self.source_names))
print(" # train datasets : {}".format(len(self.source_names)))
print(" # train ids : {}".format(self._num_train_pids))
if self.image_training:
print(" # train images : {}".format(len(self.train)))
else:
print(" # train tracklets: {}".format(len(self.train)))
print(" # train cameras : {}".format(self._num_train_cams))
print(" test names : {}".format(self.test_names))
print(" test names : {}".format(self.target_names))
print(" *****************************************")
print("\n")