From 1bd41a79fc2285a1a0bc6e5f777b6f19bca677e8 Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Tue, 12 Oct 2021 10:39:04 +0800 Subject: [PATCH] add cls target in pair dataset --- fastreid/data/build.py | 5 ++++- fastreid/modeling/losses/contrastive_loss.py | 6 +++--- fastreid/modeling/meta_arch/baseline.py | 14 ++++++++++---- projects/FastShoe/fastshoe/data/pair_dataset.py | 6 +++--- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/fastreid/data/build.py b/fastreid/data/build.py index faab320..3cf8a00 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -194,12 +194,15 @@ def pair_batch_collator(batched_inputs): images = [] targets = [] + cls_targets = [] for elem in batched_inputs: images.append(elem['img1']) images.append(elem['img2']) targets.append(elem['target']) + cls_targets.append(elem['cls_target']) images = torch.stack(images, dim=0) targets = torch.tensor(targets) - return {'images': images, 'targets': targets} + cls_targets = torch.tensor(cls_targets) + return {'images': images, 'targets': targets, 'cls_targets': cls_targets} diff --git a/fastreid/modeling/losses/contrastive_loss.py b/fastreid/modeling/losses/contrastive_loss.py index bfeaea3..52b00fd 100644 --- a/fastreid/modeling/losses/contrastive_loss.py +++ b/fastreid/modeling/losses/contrastive_loss.py @@ -3,7 +3,7 @@ # @Author : zuchen.wang@vipshop.com # @File : contrastive_loss.py import torch -import torch.nn.functional as F +from .utils import normalize, euclidean_dist __all__ = ['contrastive_loss'] @@ -12,9 +12,9 @@ def contrastive_loss( embedding: torch.Tensor, targets: torch.Tensor, margin: float) -> torch.Tensor: - embedding = F.normalize(embedding, dim=1) + embedding = normalize(embedding, dim=1) embed1 = embedding[0:len(embedding):2, :] embed2 = embedding[1:len(embedding):2, :] - euclidean_distance = F.pairwise_distance(embed1, embed2) + euclidean_distance = euclidean_dist(embed1, embed2) return torch.mean(targets * torch.pow(euclidean_distance, 2) + (1 - targets) * torch.pow(torch.clamp(margin - euclidean_distance, min=0), 2)) diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index 631ba52..8bb239d 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -108,13 +108,19 @@ class Baseline(nn.Module): assert "targets" in batched_inputs, "Person ID annotation are missing in training!" targets = batched_inputs["targets"] + if "cls_targets" in batched_inputs: + cls_targets = batched_inputs['cls_targets'] + # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset # may be larger than that in the original dataset, so the circle/arcface will # throw an error. We just set all the targets to 0 to avoid this problem. if targets.sum() < 0: targets.zero_() - outputs = self.heads(features, targets) - losses = self.losses(outputs, targets) + if "cls_targets" in batched_inputs: + outputs = self.heads(features, cls_targets) + else: + outputs = self.heads(features, targets) + losses = self.losses(outputs, cls_targets, targets) return losses else: outputs = self.heads(features) @@ -134,7 +140,7 @@ class Baseline(nn.Module): images.sub_(self.pixel_mean).div_(self.pixel_std) return images - def losses(self, outputs, gt_labels): + def losses(self, outputs, gt_cls_labels, gt_metric_labels): """ Compute loss from modeling's outputs, the loss function input arguments must be the same as the outputs of the model forwarding. @@ -147,7 +153,7 @@ class Baseline(nn.Module): # fmt: on # Log prediction accuracy - # log_accuracy(pred_class_logits, gt_labels) + log_accuracy(pred_class_logits, gt_cls_labels) loss_dict = {} loss_names = self.loss_kwargs['loss_names'] diff --git a/projects/FastShoe/fastshoe/data/pair_dataset.py b/projects/FastShoe/fastshoe/data/pair_dataset.py index ae25fcd..7aa3c0a 100644 --- a/projects/FastShoe/fastshoe/data/pair_dataset.py +++ b/projects/FastShoe/fastshoe/data/pair_dataset.py @@ -45,10 +45,10 @@ class PairDataset(Dataset): return { 'img1': img1, 'img2': img2, - 'target': label + 'target': label, + 'cls_target': idx } @property def num_classes(self): - return 2 - + return len(self.pos_folders)