mirror of https://github.com/JDAI-CV/fast-reid.git
make CrossEntropy and Focal Loss filter -1 label
parent
c576948ec5
commit
fc29be547c
|
@ -3,17 +3,21 @@
|
|||
# @Author : zuchen.wang@vipshop.com
|
||||
# @File : contrastive_loss.py
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .utils import normalize, euclidean_dist
|
||||
|
||||
__all__ = ['contrastive_loss']
|
||||
|
||||
|
||||
def contrastive_loss(
|
||||
query_feat: torch.Tensor,
|
||||
gallery_feat: torch.Tensor,
|
||||
feats: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
margin: float) -> torch.Tensor:
|
||||
feats_len = feats.size(0)
|
||||
feats = F.normalize(feats, dim=1)
|
||||
query_feat = feats[0:feats_len:2, :]
|
||||
gallery_feat = feats[1:feats_len:2, :]
|
||||
distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
|
||||
loss1 = 0.5 * targets * torch.pow(distance, 2)
|
||||
loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)
|
||||
loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=1e-6), 2)
|
||||
return torch.mean(loss1 + loss2)
|
||||
|
|
|
@ -13,6 +13,16 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
|
|||
"""
|
||||
Log the accuracy metrics to EventStorage.
|
||||
"""
|
||||
storage = get_event_storage()
|
||||
|
||||
index = torch.where(gt_classes != -1)[0]
|
||||
if len(index) == 0:
|
||||
storage.put_scalar("cls_accuracy", 0)
|
||||
return
|
||||
|
||||
pred_class_logits = pred_class_logits[index, :]
|
||||
gt_classes = gt_classes[index]
|
||||
|
||||
bsz = pred_class_logits.size(0)
|
||||
maxk = max(topk)
|
||||
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
|
||||
|
@ -24,13 +34,19 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
|
|||
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
|
||||
ret.append(correct_k.mul_(1. / bsz))
|
||||
|
||||
storage = get_event_storage()
|
||||
storage.put_scalar("cls_accuracy", ret[0])
|
||||
|
||||
|
||||
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
|
||||
num_classes = pred_class_outputs.size(1)
|
||||
|
||||
index = torch.where(gt_classes != -1)[0]
|
||||
if len(index) == 0:
|
||||
return torch.tensor(0, dtype=torch.float32, device=pred_class_outputs.device)
|
||||
|
||||
pred_class_outputs = pred_class_outputs[index, :]
|
||||
gt_classes = gt_classes[index]
|
||||
|
||||
if eps >= 0:
|
||||
smooth_param = eps
|
||||
else:
|
||||
|
|
|
@ -57,6 +57,10 @@ def focal_loss(
|
|||
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
|
||||
.format(input.size(0), target.size(0)))
|
||||
|
||||
index = torch.where(target != -1)[0]
|
||||
target = target[index]
|
||||
input = input[index, :]
|
||||
|
||||
n = input.size(0)
|
||||
out_size = (n,) + input.size()[2:]
|
||||
if target.size()[1:] != input.size()[2:]:
|
||||
|
@ -96,7 +100,6 @@ def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
|
|||
'''
|
||||
Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
|
||||
'''
|
||||
# __import__('ipdb').set_trace()
|
||||
if alpha < 0:
|
||||
raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
|
||||
if gamma < 0:
|
||||
|
|
Loading…
Reference in New Issue