mirror of https://github.com/JDAI-CV/fast-reid.git
add binary target and multi target in ShoeDataset
parent
95755377ab
commit
c576948ec5
|
@ -212,13 +212,16 @@ def pair_batch_collator(batched_inputs):
|
|||
"""
|
||||
|
||||
images = []
|
||||
targets = []
|
||||
binary_targets = []
|
||||
multi_targets = []
|
||||
for elem in batched_inputs:
|
||||
images.append(elem['img1'])
|
||||
images.append(elem['img2'])
|
||||
targets.append(elem['target'])
|
||||
binary_targets.append(elem['binary_target'])
|
||||
multi_targets += elem['multi_target']
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
targets = torch.tensor(targets)
|
||||
return {'images': images, 'targets': targets}
|
||||
binary_targets = torch.tensor(binary_targets)
|
||||
multi_targets = torch.tensor(multi_targets)
|
||||
return {'images': images, 'binary_targets': binary_targets, 'multi_targets': multi_targets}
|
||||
|
||||
|
|
|
@ -46,7 +46,8 @@ class ExcelDataset(ImageDataset):
|
|||
return {
|
||||
'img1': img1,
|
||||
'img2': img2,
|
||||
'target': label
|
||||
'binary_target': label,
|
||||
'multi_target': [label, label] # multi_target is positional and ununsed
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -30,11 +30,16 @@ class ShoePairDataset(ShoeDataset):
|
|||
|
||||
self.pos_folders = []
|
||||
self.neg_folders = []
|
||||
self.image_label_dict = {}
|
||||
for data in self.all_data:
|
||||
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
|
||||
self.pos_folders.append(data['positive_img_list'])
|
||||
self.neg_folders.append(data['negative_img_list'])
|
||||
|
||||
for idx, folder in enumerate(self.pos_folders):
|
||||
for img_path in folder:
|
||||
self.image_label_dict[img_path] = idx
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pos_folders)
|
||||
|
||||
|
@ -72,6 +77,13 @@ class ShoePairDataset(ShoeDataset):
|
|||
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
||||
|
||||
|
||||
if label == 1:
|
||||
multi_label = [self.image_label_dict[img_path1], self.image_label_dict[img_path1]]
|
||||
else:
|
||||
# -1 indicate it is a negative sample which has no multi class label
|
||||
# this negative sample will be ignored in computing multi class related loss
|
||||
multi_label = [self.image_label_dict[img_path1], -1]
|
||||
|
||||
img_path1 = os.path.join(self.img_root, img_path1)
|
||||
img1 = read_image(img_path1)
|
||||
|
||||
|
@ -91,12 +103,14 @@ class ShoePairDataset(ShoeDataset):
|
|||
return {
|
||||
'img1': img1,
|
||||
'img2': img2,
|
||||
'target': label
|
||||
'binary_target': label,
|
||||
'multi_target': multi_label
|
||||
}
|
||||
|
||||
#-------------下面是辅助信息------------------#
|
||||
@property
|
||||
def num_classes(self):
|
||||
# return len(self.pos_folders)
|
||||
return 2
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue