make CrossEntropy and Focal Loss filter -1 label

Branch: pull/608/head
Author: zuchen.wang
Date: 2021-11-24 20:11:23 +08:00
Parent: c576948ec5
Commit: fc29be547c
3 changed files with 28 additions and 5 deletions
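The common pattern across all three files is to drop rows whose label is -1 before the loss (or accuracy) is computed. A minimal standalone sketch of that filtering idiom, not part of the commit, with made-up tensors for illustration:

import torch

logits = torch.randn(4, 10)
labels = torch.tensor([3, -1, 7, -1])

# keep only the rows with a valid (non -1) label
index = torch.where(labels != -1)[0]
logits, labels = logits[index, :], labels[index]
# logits.shape == torch.Size([2, 10]), labels == tensor([3, 7])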

contrastive_loss.py

@@ -3,17 +3,21 @@
 # @Author : zuchen.wang@vipshop.com
 # @File : contrastive_loss.py
 import torch
+import torch.nn.functional as F
 from .utils import normalize, euclidean_dist
 __all__ = ['contrastive_loss']
 def contrastive_loss(
-        query_feat: torch.Tensor,
-        gallery_feat: torch.Tensor,
+        feats: torch.Tensor,
         targets: torch.Tensor,
         margin: float) -> torch.Tensor:
+    feats_len = feats.size(0)
+    feats = F.normalize(feats, dim=1)
+    query_feat = feats[0:feats_len:2, :]
+    gallery_feat = feats[1:feats_len:2, :]
     distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
     loss1 = 0.5 * targets * torch.pow(distance, 2)
-    loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)
+    loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=1e-6), 2)
     return torch.mean(loss1 + loss2)
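For reference, a self-contained sketch of what the updated contrastive_loss computes, assuming queries occupy the even rows of feats and their galleries the odd rows (the pairing convention is an assumption; sizes and the margin value are made up):

import torch
import torch.nn.functional as F

feats = F.normalize(torch.randn(6, 128), dim=1)   # 3 (query, gallery) pairs, interleaved
targets = torch.tensor([1., 0., 1.])              # 1 = matching pair, 0 = non-matching
margin = 0.3                                      # illustrative margin

query, gallery = feats[0::2], feats[1::2]
dist = (query - gallery).pow(2).sum(-1).sqrt()
loss = (0.5 * targets * dist.pow(2)
        + 0.5 * (1 - targets) * (margin - dist).clamp(min=1e-6).pow(2)).mean()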

cross_entropy_loss.py

@@ -13,6 +13,16 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
     """
     Log the accuracy metrics to EventStorage.
     """
+    storage = get_event_storage()
+    index = torch.where(gt_classes != -1)[0]
+    if len(index) == 0:
+        storage.put_scalar("cls_accuracy", 0)
+        return
+    pred_class_logits = pred_class_logits[index, :]
+    gt_classes = gt_classes[index]
     bsz = pred_class_logits.size(0)
     maxk = max(topk)
     _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
@@ -24,13 +34,19 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
         correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
         ret.append(correct_k.mul_(1. / bsz))
-    storage = get_event_storage()
     storage.put_scalar("cls_accuracy", ret[0])
 def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
     num_classes = pred_class_outputs.size(1)
+    index = torch.where(gt_classes != -1)[0]
+    if len(index) == 0:
+        return torch.tensor(0, dtype=torch.float32, device=pred_class_outputs.device)
+    pred_class_outputs = pred_class_outputs[index, :]
+    gt_classes = gt_classes[index]
     if eps >= 0:
         smooth_param = eps
     else:
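With the early return added above, a batch in which every label is -1 now yields a zero loss instead of indexing an empty tensor. A hedged sketch of the two paths, where torch.nn.functional.cross_entropy stands in for the label-smoothed loss actually computed in this file:

import torch
import torch.nn.functional as F

pred = torch.randn(3, 10)
gt = torch.tensor([-1, -1, -1])            # every sample ignored

index = torch.where(gt != -1)[0]
if len(index) == 0:
    loss = torch.tensor(0, dtype=torch.float32, device=pred.device)
else:
    loss = F.cross_entropy(pred[index, :], gt[index])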

focal_loss.py

@@ -57,6 +57,10 @@ def focal_loss(
         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                          .format(input.size(0), target.size(0)))
+    index = torch.where(target != -1)[0]
+    target = target[index]
+    input = input[index, :]
     n = input.size(0)
     out_size = (n,) + input.size()[2:]
     if target.size()[1:] != input.size()[2:]:
@@ -96,7 +100,6 @@ def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
     '''
     Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
     '''
-    # __import__('ipdb').set_trace()
     if alpha < 0:
         raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
     if gamma < 0:
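The same filtering happens in focal_loss before any shape checks, so samples labeled -1 never contribute to the focal term. A simplified multi-class focal loss sketch (constant alpha; not the exact implementation in this file) showing the effect:

import torch

input = torch.randn(5, 10)                 # logits, one row per sample
target = torch.tensor([2, -1, 4, -1, 9])   # -1 marks ignored samples
alpha, gamma = 0.25, 2.0

# drop ignored rows before computing the focal term
index = torch.where(target != -1)[0]
input, target = input[index, :], target[index]

log_pt = torch.log_softmax(input, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
pt = log_pt.exp()
loss = (-alpha * (1 - pt) ** gamma * log_pt).mean()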