fast-reid/projects/FastAttr/fastattr/modeling/bce_loss.py

34 lines
974 B
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
def ratio2weight(targets, ratio):
    """Turn targets and a ratio into element-wise exponential loss weights.

    For each element, the exponent is ``t * (1 - r) + (1 - t) * r``, where
    ``t`` is the target and ``r`` is ``ratio`` broadcast against ``targets``.
    With binary targets, a positive (t=1) gets ``exp(1 - r)`` and a negative
    (t=0) gets ``exp(r)``. Elements whose target is greater than 1 get a
    weight of 0, so they are excluded from a weighted loss.

    Args:
        targets: target tensor (binary or soft labels; values > 1 are ignored).
        ratio: value(s) broadcastable against ``targets``.

    Returns:
        Tensor of weights with the broadcast shape of ``targets`` and ``ratio``.
    """
    exponent = targets * (1 - ratio) + (1 - targets) * ratio
    weights = torch.exp(exponent)
    # Targets above 1 act as "ignore" labels: zero their weight in place.
    ignore_mask = targets > 1
    weights[ignore_mask] = 0.0
    return weights
def cross_entropy_sigmoid_loss(pred_class_logits, gt_classes, sample_weight=None):
    """Sigmoid binary cross-entropy, averaged over the non-zero loss terms.

    Args:
        pred_class_logits: raw (pre-sigmoid) logits.
        gt_classes: targets; must have the same shape as ``pred_class_logits``
            (required by ``F.binary_cross_entropy_with_logits``).
        sample_weight: optional value passed as ``ratio`` to ``ratio2weight``
            to re-weight positive and negative elements; broadcast against
            the targets. ``None`` disables re-weighting.

    Returns:
        0-dim tensor: the sum of element-wise (optionally weighted) losses
        divided by the number of non-zero loss elements, with that count
        clamped to at least 1.
    """
    loss = F.binary_cross_entropy_with_logits(pred_class_logits, gt_classes, reduction='none')

    if sample_weight is not None:
        # Binarize the (possibly smoothed) targets at 0.5. Deriving the mask from
        # the comparison keeps it on gt_classes' device, instead of a hard-coded
        # "cuda" that broke CPU runs and mismatched non-default GPUs.
        # NOTE(review): since the mask is strictly 0/1, ratio2weight's zeroing of
        # targets > 1 can never trigger from this call -- confirm that is intended.
        targets_mask = (gt_classes.detach() > 0.5).float()  # dtype float32
        weight = ratio2weight(targets_mask, sample_weight)
        loss = loss * weight

    with torch.no_grad():
        # Count non-zero terms on-device (no index tensor, no host sync); clamp
        # to 1 so an all-zero loss does not divide by zero.
        non_zero_cnt = (loss != 0).sum().clamp(min=1)

    return loss.sum() / non_zero_cnt