From fc29be547cc667c99d7fd048602482b68137da09 Mon Sep 17 00:00:00 2001
From: "zuchen.wang" <zuchen.wang@vipshop.com>
Date: Wed, 24 Nov 2021 20:11:23 +0800
Subject: [PATCH] make CrossEntropy and Focal Loss filter -1 label

---
 fastreid/modeling/losses/contrastive_loss.py  | 10 +++++++---
 fastreid/modeling/losses/cross_entroy_loss.py | 18 +++++++++++++++++-
 fastreid/modeling/losses/focal_loss.py        |  5 ++++-
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/fastreid/modeling/losses/contrastive_loss.py b/fastreid/modeling/losses/contrastive_loss.py
index 2476dbe..0947903 100644
--- a/fastreid/modeling/losses/contrastive_loss.py
+++ b/fastreid/modeling/losses/contrastive_loss.py
@@ -3,17 +3,21 @@
 # @Author : zuchen.wang@vipshop.com
 # @File : contrastive_loss.py
 import torch
+import torch.nn.functional as F
 
 from .utils import normalize, euclidean_dist
 
 __all__ = ['contrastive_loss']
 
 def contrastive_loss(
-        query_feat: torch.Tensor,
-        gallery_feat: torch.Tensor,
+        feats: torch.Tensor,
         targets: torch.Tensor,
         margin: float) -> torch.Tensor:
+    feats_len = feats.size(0)
+    feats = F.normalize(feats, dim=1)
+    query_feat = feats[0:feats_len:2, :]
+    gallery_feat = feats[1:feats_len:2, :]
     distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
     loss1 = 0.5 * targets * torch.pow(distance, 2)
-    loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)
+    loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=1e-6), 2)
     return torch.mean(loss1 + loss2)
diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py
index 9986d33..ced5c2d 100644
--- a/fastreid/modeling/losses/cross_entroy_loss.py
+++ b/fastreid/modeling/losses/cross_entroy_loss.py
@@ -13,6 +13,16 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
     """
     Log the accuracy metrics to EventStorage.
     """
+    storage = get_event_storage()
+
+    index = torch.where(gt_classes != -1)[0]
+    if len(index) == 0:
+        storage.put_scalar("cls_accuracy", 0)
+        return
+
+    pred_class_logits = pred_class_logits[index, :]
+    gt_classes = gt_classes[index]
+
     bsz = pred_class_logits.size(0)
     maxk = max(topk)
     _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
@@ -24,13 +34,19 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
         correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
         ret.append(correct_k.mul_(1. / bsz))
 
-    storage = get_event_storage()
     storage.put_scalar("cls_accuracy", ret[0])
 
 
 def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
     num_classes = pred_class_outputs.size(1)
 
+    index = torch.where(gt_classes != -1)[0]
+    if len(index) == 0:
+        return torch.tensor(0, dtype=torch.float32, device=pred_class_outputs.device)
+
+    pred_class_outputs = pred_class_outputs[index, :]
+    gt_classes = gt_classes[index]
+
     if eps >= 0:
         smooth_param = eps
     else:
diff --git a/fastreid/modeling/losses/focal_loss.py b/fastreid/modeling/losses/focal_loss.py
index 5a481b9..ad6d99a 100644
--- a/fastreid/modeling/losses/focal_loss.py
+++ b/fastreid/modeling/losses/focal_loss.py
@@ -57,6 +57,10 @@ def focal_loss(
         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                          .format(input.size(0), target.size(0)))
 
+    index = torch.where(target != -1)[0]
+    target = target[index]
+    input = input[index, :]
+
     n = input.size(0)
     out_size = (n,) + input.size()[2:]
     if target.size()[1:] != input.size()[2:]:
@@ -96,7 +100,6 @@ def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
     '''
     Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
     '''
-# __import__('ipdb').set_trace()
     if alpha < 0:
         raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
     if gamma < 0:
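
Note (not part of the patch): below is a minimal standalone sketch of the -1-filtering idea the patch applies, using plain F.cross_entropy rather than fastreid's label-smoothed cross_entropy_loss. The function name cross_entropy_ignore_unlabeled and the toy batch are illustrative only, not code from the repository.

import torch
import torch.nn.functional as F

def cross_entropy_ignore_unlabeled(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Keep only samples whose label is not the ignore value -1,
    # mirroring the `index = torch.where(gt_classes != -1)[0]` step in the patch.
    index = torch.where(targets != -1)[0]
    if len(index) == 0:
        # Whole batch is unlabeled: return a zero loss so the training step still runs.
        return torch.tensor(0.0, device=logits.device)
    return F.cross_entropy(logits[index, :], targets[index])

# Toy batch: 4 samples, 3 classes; samples 1 and 3 carry the ignore label -1.
logits = torch.randn(4, 3)
targets = torch.tensor([0, -1, 2, -1])
print(cross_entropy_ignore_unlabeled(logits, targets))  # loss over samples 0 and 2 only

For the plain (non-smoothed) case, F.cross_entropy(logits, targets, ignore_index=-1) gives the same effect; the explicit indexing shown here matches the patch, where the filtered tensors are also needed by the label-smoothing path and the accuracy logging.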