添加binary cross entropy loss和binary focal loss

pull/608/head
zuchen.wang 2021-11-10 16:55:40 +08:00
parent da8b623ce8
commit dfd7e5f61e
6 changed files with 104 additions and 39 deletions

View File

@ -88,6 +88,10 @@ _C.MODEL.HEADS.SCALE = 1
_C.MODEL.LOSSES = CN()
_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
# Binary Cross Entropy Loss
_C.MODEL.LOSSES.BCE = CN()
_C.MODEL.LOSSES.BCE.SCALE = 1.0
# Cross Entropy Loss options
_C.MODEL.LOSSES.CE = CN()
# if epsilon == 0, it means no label smooth regularization,
@ -96,6 +100,12 @@ _C.MODEL.LOSSES.CE.EPSILON = 0.0
_C.MODEL.LOSSES.CE.ALPHA = 0.2
_C.MODEL.LOSSES.CE.SCALE = 1.0
# Binary Focal Loss options
_C.MODEL.LOSSES.BFL = CN()
_C.MODEL.LOSSES.BFL.ALPHA = 0.25
_C.MODEL.LOSSES.BFL.GAMMA = 2
_C.MODEL.LOSSES.BFL.SCALE = 1.0
# Focal Loss options
_C.MODEL.LOSSES.FL = CN()
_C.MODEL.LOSSES.FL.ALPHA = 0.25

View File

@ -6,8 +6,8 @@
from .circle_loss import *
from .contrastive_loss import contrastive_loss
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .cross_entroy_loss import binary_cross_entropy_loss, cross_entropy_loss, log_accuracy
from .focal_loss import binary_focal_loss, focal_loss
from .triplet_loss import triplet_loss
__all__ = [k for k in globals().keys() if not k.startswith("_")]
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -52,3 +52,7 @@ def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
loss = loss.sum() / non_zero_cnt
return loss
def binary_cross_entropy_loss(inputs, targets):
return F.binary_cross_entropy_with_logits(inputs, targets)

View File

@ -90,3 +90,51 @@ def focal_loss(
raise NotImplementedError("Invalid reduction mode: {}"
.format(reduction))
return loss
def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
'''
Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
'''
# __import__('ipdb').set_trace()
if alpha < 0:
raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
if gamma < 0:
raise ValueError(f'Value of gamma should be greater than or equal to zero, but get {gamma}')
if not torch.is_tensor(inputs):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(inputs)))
if not len(inputs.shape) >= 2:
raise ValueError("Invalid input shape, we expect BxCx*. Got: {}".format(inputs.shape))
if inputs.size(0) != targets.size(0):
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
.format(inputs.size(0), targets.size(0)))
if not inputs.device == targets.device:
raise ValueError(
"input and target must be in the same device. Got: {}".format(
inputs.device, targets.device))
if len(targets.shape) == 1:
targets = torch.unsqueeze(targets, 1)
if targets.dtype != inputs.dtype:
targets = targets.to(inputs.dtype)
bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
pred_prob = torch.sigmoid(inputs)
p_t = targets * pred_prob + (1 - targets) * (1 - pred_prob)
alpha_factor = 1.0
modulating_factor = 1.0
if alpha > 0:
alpha_factor = targets * alpha + (1 - targets) * (1 - alpha)
if gamma > 0:
modulating_factor = torch.pow(1.0 - p_t, gamma)
loss = torch.mean(alpha_factor * modulating_factor * bce)
return loss

View File

@ -68,11 +68,24 @@ class Baseline(nn.Module):
'loss_names': cfg.MODEL.LOSSES.NAME,
# loss hyperparameters
'bce': {
'scale': cfg.MODEL.LOSSES.BCE.SCALE
},
'ce': {
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
'scale': cfg.MODEL.LOSSES.CE.SCALE
},
'bfl': {
'alpha': cfg.MODEL.LOSSES.BFL.ALPHA,
'gamma': cfg.MODEL.LOSSES.BFL.GAMMA,
'scale': cfg.MODEL.LOSSES.BFL.SCALE
},
'fl': {
'alpha': cfg.MODEL.LOSSES.FL.ALPHA,
'gamma': cfg.MODEL.LOSSES.FL.GAMMA,
'scale': cfg.MODEL.LOSSES.FL.SCALE
},
'tri': {
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
@ -152,6 +165,13 @@ class Baseline(nn.Module):
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if 'BinaryCrossEntropyLoss' in loss_names:
bce_kwargs = self.loss_kwargs.get('bce')
loss_dict['loss_bcls'] = binary_cross_entropy_loss(
cls_outputs,
gt_labels,
) * bce_kwargs.get('scale')
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
@ -161,6 +181,24 @@ class Baseline(nn.Module):
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
if 'BinaryFocalLoss' in loss_names:
bfl_kwargs = self.loss_kwargs.get('bfl')
loss_dict['loss_bfl'] = binary_focal_loss(
cls_outputs,
gt_labels,
bfl_kwargs.get('alpha'),
bfl_kwargs.get('gamma')
) * bfl_kwargs.get('scale')
if 'FocalLoss' in loss_names:
fl_kwargs = self.loss_kwargs.get('fl')
loss_dict['loss_fl'] = focal_loss(
cls_outputs,
gt_labels,
fl_kwargs.get('alpha'),
fl_kwargs.get('gamma')
) * fl_kwargs.get('scale')
if 'TripletLoss' in loss_names:
tri_kwargs = self.loss_kwargs.get('tri')
loss_dict['loss_triplet'] = triplet_loss(

View File

@ -27,6 +27,7 @@ class PcbOnline(Baseline):
outputs['query_feature'] = qf
outputs['gallery_feature'] = xf
outputs['features'] = {}
if self.training:
targets = batched_inputs['targets']
losses = self.losses(outputs, targets)
@ -34,39 +35,3 @@ class PcbOnline(Baseline):
else:
return outputs
def losses(self, outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
pred_query_feature = outputs['query_feature']
pred_gallery_feature = outputs['gallery_feature']
pred_class_logits = outputs['pred_class_logits'].detach()
cls_outputs = outputs['cls_outputs']
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
if 'ContrastiveLoss' in loss_names:
contrastive_kwargs = self.loss_kwargs.get('contrastive')
loss_dict['loss_contrastive'] = contrastive_loss(
pred_query_feature,
pred_gallery_feature,
gt_labels,
contrastive_kwargs.get('margin')
) * contrastive_kwargs.get('scale')
return loss_dict