mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
add pair collator_fn
This commit is contained in:
parent
bced0d04ff
commit
7201a82840
@ -91,7 +91,7 @@ def build_reid_train_loader(
|
|||||||
dataset=train_set,
|
dataset=train_set,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
batch_sampler=batch_sampler,
|
batch_sampler=batch_sampler,
|
||||||
collate_fn=fast_batch_collator,
|
collate_fn=pair_batch_collator,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
|
|||||||
dataset=test_set,
|
dataset=test_set,
|
||||||
batch_sampler=batch_sampler,
|
batch_sampler=batch_sampler,
|
||||||
num_workers=num_workers, # save some memory
|
num_workers=num_workers, # save some memory
|
||||||
collate_fn=fast_batch_collator,
|
collate_fn=pair_batch_collator,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
return test_loader, num_query
|
return test_loader, num_query
|
||||||
@ -185,3 +185,21 @@ def fast_batch_collator(batched_inputs):
|
|||||||
return torch.tensor(batched_inputs)
|
return torch.tensor(batched_inputs)
|
||||||
elif isinstance(elem, string_classes):
|
elif isinstance(elem, string_classes):
|
||||||
return batched_inputs
|
return batched_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def pair_batch_collator(batched_inputs):
|
||||||
|
"""
|
||||||
|
A pair batch collator for paired tasks:
|
||||||
|
"""
|
||||||
|
|
||||||
|
images = []
|
||||||
|
targets = []
|
||||||
|
for elem in batched_inputs:
|
||||||
|
images.append(elem['img1'])
|
||||||
|
images.append(elem['img2'])
|
||||||
|
targets.append(elem['target'])
|
||||||
|
|
||||||
|
images = torch.stack(images, dim=0)
|
||||||
|
targets = torch.tensor(targets)
|
||||||
|
return {'images': images, 'targets': targets}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from fastreid.data.data_utils import read_image
|
from fastreid.data.data_utils import read_image
|
||||||
@ -28,6 +29,7 @@ class PairDataset(Dataset):
|
|||||||
img_path1, img_path2 = random.sample(pf, 2)
|
img_path1, img_path2 = random.sample(pf, 2)
|
||||||
else:
|
else:
|
||||||
# generate negative pair
|
# generate negative pair
|
||||||
|
label = 0
|
||||||
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
||||||
|
|
||||||
img_path1 = os.path.join(self.img_root, img_path1)
|
img_path1 = os.path.join(self.img_root, img_path1)
|
||||||
@ -45,3 +47,8 @@ class PairDataset(Dataset):
|
|||||||
'img2': img2,
|
'img2': img2,
|
||||||
'target': label
|
'target': label
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user