mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
make CrossEntropy and Focal Loss filter -1 label
This commit is contained in:
parent
c576948ec5
commit
fc29be547c
@ -3,17 +3,21 @@
|
|||||||
# @Author : zuchen.wang@vipshop.com
|
# @Author : zuchen.wang@vipshop.com
|
||||||
# @File : contrastive_loss.py
|
# @File : contrastive_loss.py
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from .utils import normalize, euclidean_dist
|
from .utils import normalize, euclidean_dist
|
||||||
|
|
||||||
__all__ = ['contrastive_loss']
|
__all__ = ['contrastive_loss']
|
||||||
|
|
||||||
|
|
||||||
def contrastive_loss(
|
def contrastive_loss(
|
||||||
query_feat: torch.Tensor,
|
feats: torch.Tensor,
|
||||||
gallery_feat: torch.Tensor,
|
|
||||||
targets: torch.Tensor,
|
targets: torch.Tensor,
|
||||||
margin: float) -> torch.Tensor:
|
margin: float) -> torch.Tensor:
|
||||||
|
feats_len = feats.size(0)
|
||||||
|
feats = F.normalize(feats, dim=1)
|
||||||
|
query_feat = feats[0:feats_len:2, :]
|
||||||
|
gallery_feat = feats[1:feats_len:2, :]
|
||||||
distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
|
distance = torch.sqrt(torch.sum(torch.pow(query_feat - gallery_feat, 2), -1))
|
||||||
loss1 = 0.5 * targets * torch.pow(distance, 2)
|
loss1 = 0.5 * targets * torch.pow(distance, 2)
|
||||||
loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=0), 2)
|
loss2 = 0.5 * (1 - targets) * torch.pow(torch.clamp(margin - distance, min=1e-6), 2)
|
||||||
return torch.mean(loss1 + loss2)
|
return torch.mean(loss1 + loss2)
|
||||||
|
@ -13,6 +13,16 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
|
|||||||
"""
|
"""
|
||||||
Log the accuracy metrics to EventStorage.
|
Log the accuracy metrics to EventStorage.
|
||||||
"""
|
"""
|
||||||
|
storage = get_event_storage()
|
||||||
|
|
||||||
|
index = torch.where(gt_classes != -1)[0]
|
||||||
|
if len(index) == 0:
|
||||||
|
storage.put_scalar("cls_accuracy", 0)
|
||||||
|
return
|
||||||
|
|
||||||
|
pred_class_logits = pred_class_logits[index, :]
|
||||||
|
gt_classes = gt_classes[index]
|
||||||
|
|
||||||
bsz = pred_class_logits.size(0)
|
bsz = pred_class_logits.size(0)
|
||||||
maxk = max(topk)
|
maxk = max(topk)
|
||||||
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
|
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
|
||||||
@ -24,13 +34,19 @@ def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
|
|||||||
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
|
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
|
||||||
ret.append(correct_k.mul_(1. / bsz))
|
ret.append(correct_k.mul_(1. / bsz))
|
||||||
|
|
||||||
storage = get_event_storage()
|
|
||||||
storage.put_scalar("cls_accuracy", ret[0])
|
storage.put_scalar("cls_accuracy", ret[0])
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
|
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
|
||||||
num_classes = pred_class_outputs.size(1)
|
num_classes = pred_class_outputs.size(1)
|
||||||
|
|
||||||
|
index = torch.where(gt_classes != -1)[0]
|
||||||
|
if len(index) == 0:
|
||||||
|
return torch.tensor(0, dtype=torch.float32, device=pred_class_outputs.device)
|
||||||
|
|
||||||
|
pred_class_outputs = pred_class_outputs[index, :]
|
||||||
|
gt_classes = gt_classes[index]
|
||||||
|
|
||||||
if eps >= 0:
|
if eps >= 0:
|
||||||
smooth_param = eps
|
smooth_param = eps
|
||||||
else:
|
else:
|
||||||
|
@ -57,6 +57,10 @@ def focal_loss(
|
|||||||
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
|
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
|
||||||
.format(input.size(0), target.size(0)))
|
.format(input.size(0), target.size(0)))
|
||||||
|
|
||||||
|
index = torch.where(target != -1)[0]
|
||||||
|
target = target[index]
|
||||||
|
input = input[index, :]
|
||||||
|
|
||||||
n = input.size(0)
|
n = input.size(0)
|
||||||
out_size = (n,) + input.size()[2:]
|
out_size = (n,) + input.size()[2:]
|
||||||
if target.size()[1:] != input.size()[2:]:
|
if target.size()[1:] != input.size()[2:]:
|
||||||
@ -96,7 +100,6 @@ def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
|
|||||||
'''
|
'''
|
||||||
Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
|
Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
|
||||||
'''
|
'''
|
||||||
# __import__('ipdb').set_trace()
|
|
||||||
if alpha < 0:
|
if alpha < 0:
|
||||||
raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
|
raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
|
||||||
if gamma < 0:
|
if gamma < 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user