add pair collator_fn

This commit is contained in:
zuchen.wang 2021-10-11 15:34:44 +08:00
parent bced0d04ff
commit 7201a82840
2 changed files with 27 additions and 2 deletions

View File

@ -91,7 +91,7 @@ def build_reid_train_loader(
dataset=train_set,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=fast_batch_collator,
collate_fn=pair_batch_collator,
pin_memory=True,
)
@ -152,7 +152,7 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
dataset=test_set,
batch_sampler=batch_sampler,
num_workers=num_workers, # save some memory
collate_fn=fast_batch_collator,
collate_fn=pair_batch_collator,
pin_memory=True,
)
return test_loader, num_query
@ -185,3 +185,21 @@ def fast_batch_collator(batched_inputs):
return torch.tensor(batched_inputs)
elif isinstance(elem, string_classes):
return batched_inputs
def pair_batch_collator(batched_inputs):
"""
A pair batch collator for paired tasks:
"""
images = []
targets = []
for elem in batched_inputs:
images.append(elem['img1'])
images.append(elem['img2'])
targets.append(elem['target'])
images = torch.stack(images, dim=0)
targets = torch.tensor(targets)
return {'images': images, 'targets': targets}

View File

@ -5,6 +5,7 @@
import os
import random
import torch
from torch.utils.data import Dataset
from fastreid.data.data_utils import read_image
@ -28,6 +29,7 @@ class PairDataset(Dataset):
img_path1, img_path2 = random.sample(pf, 2)
else:
# generate negative pair
label = 0
img_path1, img_path2 = random.choice(pf), random.choice(nf)
img_path1 = os.path.join(self.img_root, img_path1)
@ -45,3 +47,8 @@ class PairDataset(Dataset):
'img2': img2,
'target': label
}
@property
def num_classes(self):
return 2