mirror of https://github.com/JDAI-CV/fast-reid.git
添加binary cross entropy loss和binary focal loss
parent
da8b623ce8
commit
dfd7e5f61e
|
@ -88,6 +88,10 @@ _C.MODEL.HEADS.SCALE = 1
|
|||
_C.MODEL.LOSSES = CN()
|
||||
_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss",)
|
||||
|
||||
# Binary Cross Entropy Loss
|
||||
_C.MODEL.LOSSES.BCE = CN()
|
||||
_C.MODEL.LOSSES.BCE.SCALE = 1.0
|
||||
|
||||
# Cross Entropy Loss options
|
||||
_C.MODEL.LOSSES.CE = CN()
|
||||
# if epsilon == 0, it means no label smooth regularization,
|
||||
|
@ -96,6 +100,12 @@ _C.MODEL.LOSSES.CE.EPSILON = 0.0
|
|||
_C.MODEL.LOSSES.CE.ALPHA = 0.2
|
||||
_C.MODEL.LOSSES.CE.SCALE = 1.0
|
||||
|
||||
# Binary Focal Loss options
|
||||
_C.MODEL.LOSSES.BFL = CN()
|
||||
_C.MODEL.LOSSES.BFL.ALPHA = 0.25
|
||||
_C.MODEL.LOSSES.BFL.GAMMA = 2
|
||||
_C.MODEL.LOSSES.BFL.SCALE = 1.0
|
||||
|
||||
# Focal Loss options
|
||||
_C.MODEL.LOSSES.FL = CN()
|
||||
_C.MODEL.LOSSES.FL.ALPHA = 0.25
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
|
||||
from .circle_loss import *
|
||||
from .contrastive_loss import contrastive_loss
|
||||
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
|
||||
from .focal_loss import focal_loss
|
||||
from .cross_entroy_loss import binary_cross_entropy_loss, cross_entropy_loss, log_accuracy
|
||||
from .focal_loss import binary_focal_loss, focal_loss
|
||||
from .triplet_loss import triplet_loss
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
|
|
@ -52,3 +52,7 @@ def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
|
|||
loss = loss.sum() / non_zero_cnt
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def binary_cross_entropy_loss(inputs, targets):
|
||||
return F.binary_cross_entropy_with_logits(inputs, targets)
|
||||
|
|
|
@ -90,3 +90,51 @@ def focal_loss(
|
|||
raise NotImplementedError("Invalid reduction mode: {}"
|
||||
.format(reduction))
|
||||
return loss
|
||||
|
||||
|
||||
def binary_focal_loss(inputs, targets, alpha=0.25, gamma=2):
|
||||
'''
|
||||
Reference: https://github.com/tensorflow/addons/blob/v0.14.0/tensorflow_addons/losses/focal_loss.py
|
||||
'''
|
||||
# __import__('ipdb').set_trace()
|
||||
if alpha < 0:
|
||||
raise ValueError(f'Value of alpha should be greater than or equal to zero, but get {alpha}')
|
||||
if gamma < 0:
|
||||
raise ValueError(f'Value of gamma should be greater than or equal to zero, but get {gamma}')
|
||||
|
||||
if not torch.is_tensor(inputs):
|
||||
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(inputs)))
|
||||
|
||||
if not len(inputs.shape) >= 2:
|
||||
raise ValueError("Invalid input shape, we expect BxCx*. Got: {}".format(inputs.shape))
|
||||
|
||||
if inputs.size(0) != targets.size(0):
|
||||
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
|
||||
.format(inputs.size(0), targets.size(0)))
|
||||
|
||||
if not inputs.device == targets.device:
|
||||
raise ValueError(
|
||||
"input and target must be in the same device. Got: {}".format(
|
||||
inputs.device, targets.device))
|
||||
|
||||
if len(targets.shape) == 1:
|
||||
targets = torch.unsqueeze(targets, 1)
|
||||
|
||||
if targets.dtype != inputs.dtype:
|
||||
targets = targets.to(inputs.dtype)
|
||||
|
||||
bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
|
||||
pred_prob = torch.sigmoid(inputs)
|
||||
|
||||
p_t = targets * pred_prob + (1 - targets) * (1 - pred_prob)
|
||||
alpha_factor = 1.0
|
||||
modulating_factor = 1.0
|
||||
|
||||
if alpha > 0:
|
||||
alpha_factor = targets * alpha + (1 - targets) * (1 - alpha)
|
||||
|
||||
if gamma > 0:
|
||||
modulating_factor = torch.pow(1.0 - p_t, gamma)
|
||||
|
||||
loss = torch.mean(alpha_factor * modulating_factor * bce)
|
||||
return loss
|
||||
|
|
|
@ -68,11 +68,24 @@ class Baseline(nn.Module):
|
|||
'loss_names': cfg.MODEL.LOSSES.NAME,
|
||||
|
||||
# loss hyperparameters
|
||||
'bce': {
|
||||
'scale': cfg.MODEL.LOSSES.BCE.SCALE
|
||||
},
|
||||
'ce': {
|
||||
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
|
||||
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
|
||||
'scale': cfg.MODEL.LOSSES.CE.SCALE
|
||||
},
|
||||
'bfl': {
|
||||
'alpha': cfg.MODEL.LOSSES.BFL.ALPHA,
|
||||
'gamma': cfg.MODEL.LOSSES.BFL.GAMMA,
|
||||
'scale': cfg.MODEL.LOSSES.BFL.SCALE
|
||||
},
|
||||
'fl': {
|
||||
'alpha': cfg.MODEL.LOSSES.FL.ALPHA,
|
||||
'gamma': cfg.MODEL.LOSSES.FL.GAMMA,
|
||||
'scale': cfg.MODEL.LOSSES.FL.SCALE
|
||||
},
|
||||
'tri': {
|
||||
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
|
||||
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
|
||||
|
@ -152,6 +165,13 @@ class Baseline(nn.Module):
|
|||
loss_dict = {}
|
||||
loss_names = self.loss_kwargs['loss_names']
|
||||
|
||||
if 'BinaryCrossEntropyLoss' in loss_names:
|
||||
bce_kwargs = self.loss_kwargs.get('bce')
|
||||
loss_dict['loss_bcls'] = binary_cross_entropy_loss(
|
||||
cls_outputs,
|
||||
gt_labels,
|
||||
) * bce_kwargs.get('scale')
|
||||
|
||||
if 'CrossEntropyLoss' in loss_names:
|
||||
ce_kwargs = self.loss_kwargs.get('ce')
|
||||
loss_dict['loss_cls'] = cross_entropy_loss(
|
||||
|
@ -161,6 +181,24 @@ class Baseline(nn.Module):
|
|||
ce_kwargs.get('alpha')
|
||||
) * ce_kwargs.get('scale')
|
||||
|
||||
if 'BinaryFocalLoss' in loss_names:
|
||||
bfl_kwargs = self.loss_kwargs.get('bfl')
|
||||
loss_dict['loss_bfl'] = binary_focal_loss(
|
||||
cls_outputs,
|
||||
gt_labels,
|
||||
bfl_kwargs.get('alpha'),
|
||||
bfl_kwargs.get('gamma')
|
||||
) * bfl_kwargs.get('scale')
|
||||
|
||||
if 'FocalLoss' in loss_names:
|
||||
fl_kwargs = self.loss_kwargs.get('fl')
|
||||
loss_dict['loss_fl'] = focal_loss(
|
||||
cls_outputs,
|
||||
gt_labels,
|
||||
fl_kwargs.get('alpha'),
|
||||
fl_kwargs.get('gamma')
|
||||
) * fl_kwargs.get('scale')
|
||||
|
||||
if 'TripletLoss' in loss_names:
|
||||
tri_kwargs = self.loss_kwargs.get('tri')
|
||||
loss_dict['loss_triplet'] = triplet_loss(
|
||||
|
|
|
@ -27,6 +27,7 @@ class PcbOnline(Baseline):
|
|||
|
||||
outputs['query_feature'] = qf
|
||||
outputs['gallery_feature'] = xf
|
||||
outputs['features'] = {}
|
||||
if self.training:
|
||||
targets = batched_inputs['targets']
|
||||
losses = self.losses(outputs, targets)
|
||||
|
@ -34,39 +35,3 @@ class PcbOnline(Baseline):
|
|||
else:
|
||||
return outputs
|
||||
|
||||
def losses(self, outputs, gt_labels):
|
||||
"""
|
||||
Compute loss from modeling's outputs, the loss function input arguments
|
||||
must be the same as the outputs of the model forwarding.
|
||||
"""
|
||||
# model predictions
|
||||
pred_query_feature = outputs['query_feature']
|
||||
pred_gallery_feature = outputs['gallery_feature']
|
||||
pred_class_logits = outputs['pred_class_logits'].detach()
|
||||
cls_outputs = outputs['cls_outputs']
|
||||
|
||||
# Log prediction accuracy
|
||||
log_accuracy(pred_class_logits, gt_labels)
|
||||
|
||||
loss_dict = {}
|
||||
loss_names = self.loss_kwargs['loss_names']
|
||||
|
||||
if 'CrossEntropyLoss' in loss_names:
|
||||
ce_kwargs = self.loss_kwargs.get('ce')
|
||||
loss_dict['loss_cls'] = cross_entropy_loss(
|
||||
cls_outputs,
|
||||
gt_labels,
|
||||
ce_kwargs.get('eps'),
|
||||
ce_kwargs.get('alpha')
|
||||
) * ce_kwargs.get('scale')
|
||||
|
||||
if 'ContrastiveLoss' in loss_names:
|
||||
contrastive_kwargs = self.loss_kwargs.get('contrastive')
|
||||
loss_dict['loss_contrastive'] = contrastive_loss(
|
||||
pred_query_feature,
|
||||
pred_gallery_feature,
|
||||
gt_labels,
|
||||
contrastive_kwargs.get('margin')
|
||||
) * contrastive_kwargs.get('scale')
|
||||
|
||||
return loss_dict
|
||||
|
|
Loading…
Reference in New Issue