add binary target and multi target in ShoeDataset

pull/608/head
zuchen.wang 2021-11-24 20:07:31 +08:00
parent 95755377ab
commit c576948ec5
3 changed files with 24 additions and 6 deletions

View File

@ -212,13 +212,16 @@ def pair_batch_collator(batched_inputs):
"""
images = []
targets = []
binary_targets = []
multi_targets = []
for elem in batched_inputs:
images.append(elem['img1'])
images.append(elem['img2'])
targets.append(elem['target'])
binary_targets.append(elem['binary_target'])
multi_targets += elem['multi_target']
images = torch.stack(images, dim=0)
targets = torch.tensor(targets)
return {'images': images, 'targets': targets}
binary_targets = torch.tensor(binary_targets)
multi_targets = torch.tensor(multi_targets)
return {'images': images, 'binary_targets': binary_targets, 'multi_targets': multi_targets}

View File

@ -46,7 +46,8 @@ class ExcelDataset(ImageDataset):
return {
'img1': img1,
'img2': img2,
'target': label
'binary_target': label,
'multi_target': [label, label] # multi_target is positional and ununsed
}
def __len__(self):

View File

@ -30,11 +30,16 @@ class ShoePairDataset(ShoeDataset):
self.pos_folders = []
self.neg_folders = []
self.image_label_dict = {}
for data in self.all_data:
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
self.pos_folders.append(data['positive_img_list'])
self.neg_folders.append(data['negative_img_list'])
for idx, folder in enumerate(self.pos_folders):
for img_path in folder:
self.image_label_dict[img_path] = idx
def __len__(self):
return len(self.pos_folders)
@ -72,6 +77,13 @@ class ShoePairDataset(ShoeDataset):
img_path1, img_path2 = random.choice(pf), random.choice(nf)
if label == 1:
multi_label = [self.image_label_dict[img_path1], self.image_label_dict[img_path1]]
else:
# -1 indicate it is a negative sample which has no multi class label
# this negative sample will be ignored in computing multi class related loss
multi_label = [self.image_label_dict[img_path1], -1]
img_path1 = os.path.join(self.img_root, img_path1)
img1 = read_image(img_path1)
@ -91,12 +103,14 @@ class ShoePairDataset(ShoeDataset):
return {
'img1': img1,
'img2': img2,
'target': label
'binary_target': label,
'multi_target': multi_label
}
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
# return len(self.pos_folders)
return 2
@property