mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
add pair collator_fn
This commit is contained in:
parent
bced0d04ff
commit
7201a82840
@ -91,7 +91,7 @@ def build_reid_train_loader(
|
||||
dataset=train_set,
|
||||
num_workers=num_workers,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=fast_batch_collator,
|
||||
collate_fn=pair_batch_collator,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
@ -152,7 +152,7 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
|
||||
dataset=test_set,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers, # save some memory
|
||||
collate_fn=fast_batch_collator,
|
||||
collate_fn=pair_batch_collator,
|
||||
pin_memory=True,
|
||||
)
|
||||
return test_loader, num_query
|
||||
@ -185,3 +185,21 @@ def fast_batch_collator(batched_inputs):
|
||||
return torch.tensor(batched_inputs)
|
||||
elif isinstance(elem, string_classes):
|
||||
return batched_inputs
|
||||
|
||||
|
||||
def pair_batch_collator(batched_inputs):
|
||||
"""
|
||||
A pair batch collator for paired tasks:
|
||||
"""
|
||||
|
||||
images = []
|
||||
targets = []
|
||||
for elem in batched_inputs:
|
||||
images.append(elem['img1'])
|
||||
images.append(elem['img2'])
|
||||
targets.append(elem['target'])
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
targets = torch.tensor(targets)
|
||||
return {'images': images, 'targets': targets}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from fastreid.data.data_utils import read_image
|
||||
@ -28,6 +29,7 @@ class PairDataset(Dataset):
|
||||
img_path1, img_path2 = random.sample(pf, 2)
|
||||
else:
|
||||
# generate negative pair
|
||||
label = 0
|
||||
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
||||
|
||||
img_path1 = os.path.join(self.img_root, img_path1)
|
||||
@ -45,3 +47,8 @@ class PairDataset(Dataset):
|
||||
'img2': img2,
|
||||
'target': label
|
||||
}
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return 2
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user