mirror of https://github.com/JDAI-CV/fast-reid.git
add binary target and multi target in ShoeDataset
parent
95755377ab
commit
c576948ec5
fastreid/data
projects/Shoe/shoe/data
|
@ -212,13 +212,16 @@ def pair_batch_collator(batched_inputs):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
targets = []
|
binary_targets = []
|
||||||
|
multi_targets = []
|
||||||
for elem in batched_inputs:
|
for elem in batched_inputs:
|
||||||
images.append(elem['img1'])
|
images.append(elem['img1'])
|
||||||
images.append(elem['img2'])
|
images.append(elem['img2'])
|
||||||
targets.append(elem['target'])
|
binary_targets.append(elem['binary_target'])
|
||||||
|
multi_targets += elem['multi_target']
|
||||||
|
|
||||||
images = torch.stack(images, dim=0)
|
images = torch.stack(images, dim=0)
|
||||||
targets = torch.tensor(targets)
|
binary_targets = torch.tensor(binary_targets)
|
||||||
return {'images': images, 'targets': targets}
|
multi_targets = torch.tensor(multi_targets)
|
||||||
|
return {'images': images, 'binary_targets': binary_targets, 'multi_targets': multi_targets}
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,8 @@ class ExcelDataset(ImageDataset):
|
||||||
return {
|
return {
|
||||||
'img1': img1,
|
'img1': img1,
|
||||||
'img2': img2,
|
'img2': img2,
|
||||||
'target': label
|
'binary_target': label,
|
||||||
|
'multi_target': [label, label] # multi_target is positional and ununsed
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
|
@ -30,11 +30,16 @@ class ShoePairDataset(ShoeDataset):
|
||||||
|
|
||||||
self.pos_folders = []
|
self.pos_folders = []
|
||||||
self.neg_folders = []
|
self.neg_folders = []
|
||||||
|
self.image_label_dict = {}
|
||||||
for data in self.all_data:
|
for data in self.all_data:
|
||||||
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
|
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
|
||||||
self.pos_folders.append(data['positive_img_list'])
|
self.pos_folders.append(data['positive_img_list'])
|
||||||
self.neg_folders.append(data['negative_img_list'])
|
self.neg_folders.append(data['negative_img_list'])
|
||||||
|
|
||||||
|
for idx, folder in enumerate(self.pos_folders):
|
||||||
|
for img_path in folder:
|
||||||
|
self.image_label_dict[img_path] = idx
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.pos_folders)
|
return len(self.pos_folders)
|
||||||
|
|
||||||
|
@ -72,6 +77,13 @@ class ShoePairDataset(ShoeDataset):
|
||||||
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
img_path1, img_path2 = random.choice(pf), random.choice(nf)
|
||||||
|
|
||||||
|
|
||||||
|
if label == 1:
|
||||||
|
multi_label = [self.image_label_dict[img_path1], self.image_label_dict[img_path1]]
|
||||||
|
else:
|
||||||
|
# -1 indicate it is a negative sample which has no multi class label
|
||||||
|
# this negative sample will be ignored in computing multi class related loss
|
||||||
|
multi_label = [self.image_label_dict[img_path1], -1]
|
||||||
|
|
||||||
img_path1 = os.path.join(self.img_root, img_path1)
|
img_path1 = os.path.join(self.img_root, img_path1)
|
||||||
img1 = read_image(img_path1)
|
img1 = read_image(img_path1)
|
||||||
|
|
||||||
|
@ -91,12 +103,14 @@ class ShoePairDataset(ShoeDataset):
|
||||||
return {
|
return {
|
||||||
'img1': img1,
|
'img1': img1,
|
||||||
'img2': img2,
|
'img2': img2,
|
||||||
'target': label
|
'binary_target': label,
|
||||||
|
'multi_target': multi_label
|
||||||
}
|
}
|
||||||
|
|
||||||
#-------------下面是辅助信息------------------#
|
#-------------下面是辅助信息------------------#
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
|
# return len(self.pos_folders)
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
Loading…
Reference in New Issue