mirror of https://github.com/JDAI-CV/fast-reid.git
34 lines
974 B
Python
34 lines
974 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def ratio2weight(targets, ratio):
|
|
pos_weights = targets * (1 - ratio)
|
|
neg_weights = (1 - targets) * ratio
|
|
weights = torch.exp(neg_weights + pos_weights)
|
|
|
|
weights[targets > 1] = 0.0
|
|
return weights
|
|
|
|
|
|
def cross_entropy_sigmoid_loss(pred_class_logits, gt_classes, sample_weight=None):
|
|
loss = F.binary_cross_entropy_with_logits(pred_class_logits, gt_classes, reduction='none')
|
|
|
|
if sample_weight is not None:
|
|
targets_mask = torch.where(gt_classes.detach() > 0.5,
|
|
torch.ones(1, device="cuda"), torch.zeros(1, device="cuda")) # dtype float32
|
|
weight = ratio2weight(targets_mask, sample_weight)
|
|
loss = loss * weight
|
|
|
|
with torch.no_grad():
|
|
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
|
|
|
|
loss = loss.sum() / non_zero_cnt
|
|
return loss
|