From c576948ec56b904c66173b9d33f35388f9c384d8 Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Wed, 24 Nov 2021 20:07:31 +0800 Subject: [PATCH] add binary target and multi target in ShoeDataset --- fastreid/data/build.py | 11 +++++++---- projects/Shoe/shoe/data/excel_dataset.py | 3 ++- projects/Shoe/shoe/data/shoe_pair.py | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/fastreid/data/build.py b/fastreid/data/build.py index 2d89064..d7d5c35 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -212,13 +212,16 @@ def pair_batch_collator(batched_inputs): """ images = [] - targets = [] + binary_targets = [] + multi_targets = [] for elem in batched_inputs: images.append(elem['img1']) images.append(elem['img2']) - targets.append(elem['target']) + binary_targets.append(elem['binary_target']) + multi_targets += elem['multi_target'] images = torch.stack(images, dim=0) - targets = torch.tensor(targets) - return {'images': images, 'targets': targets} + binary_targets = torch.tensor(binary_targets) + multi_targets = torch.tensor(multi_targets) + return {'images': images, 'binary_targets': binary_targets, 'multi_targets': multi_targets} diff --git a/projects/Shoe/shoe/data/excel_dataset.py b/projects/Shoe/shoe/data/excel_dataset.py index 3f26e2c..70ff878 100644 --- a/projects/Shoe/shoe/data/excel_dataset.py +++ b/projects/Shoe/shoe/data/excel_dataset.py @@ -46,7 +46,8 @@ class ExcelDataset(ImageDataset): return { 'img1': img1, 'img2': img2, - 'target': label + 'binary_target': label, + 'multi_target': [label, label] # multi_target is positional and ununsed } def __len__(self): diff --git a/projects/Shoe/shoe/data/shoe_pair.py b/projects/Shoe/shoe/data/shoe_pair.py index 25563a1..979b6b5 100644 --- a/projects/Shoe/shoe/data/shoe_pair.py +++ b/projects/Shoe/shoe/data/shoe_pair.py @@ -30,11 +30,16 @@ class ShoePairDataset(ShoeDataset): self.pos_folders = [] self.neg_folders = [] + self.image_label_dict = {} for data in self.all_data: if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1: self.pos_folders.append(data['positive_img_list']) self.neg_folders.append(data['negative_img_list']) + for idx, folder in enumerate(self.pos_folders): + for img_path in folder: + self.image_label_dict[img_path] = idx + def __len__(self): return len(self.pos_folders) @@ -72,6 +77,13 @@ class ShoePairDataset(ShoeDataset): img_path1, img_path2 = random.choice(pf), random.choice(nf) + if label == 1: + multi_label = [self.image_label_dict[img_path1], self.image_label_dict[img_path1]] + else: + # -1 indicate it is a negative sample which has no multi class label + # this negative sample will be ignored in computing multi class related loss + multi_label = [self.image_label_dict[img_path1], -1] + img_path1 = os.path.join(self.img_root, img_path1) img1 = read_image(img_path1) @@ -91,12 +103,14 @@ class ShoePairDataset(ShoeDataset): return { 'img1': img1, 'img2': img2, - 'target': label + 'binary_target': label, + 'multi_target': multi_label } #-------------下面是辅助信息------------------# @property def num_classes(self): + # return len(self.pos_folders) return 2 @property