fast-reid/fastreid/modeling/losses/focal_loss.py

93 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'mean') -> torch.Tensor:
    r"""Criterion that computes Focal loss.

    See :class:`fastreid.modeling.losses.FocalLoss` for details.
    According to [1], the Focal loss is computed as follows:

    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:
       - :math:`p_t` is the model's estimated probability for the target class.

    Arguments:
        input (torch.Tensor): Raw, unnormalized class scores (logits).
        target (torch.Tensor): Integer class indices (``torch.long``).
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'mean'.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 \leq targets[i] \leq C-1`.
        - Output: scalar for 'mean' / 'sum', :math:`(N, *)` for 'none'.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if shapes, batch sizes or devices of ``input`` and
            ``target`` are inconsistent.
        NotImplementedError: if ``reduction`` is not a supported mode.

    Examples:
        >>> num_classes = 5
        >>> input = torch.randn(1, num_classes, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(num_classes)
        >>> output = focal_loss(input, target, alpha=0.25, gamma=2.0)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        # Report both devices; the previous message dropped the target's.
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".format(
                input.device, target.device))

    # Log-probabilities over the class axis. log_softmax is used instead of
    # log(softmax(.)) so saturated logits never produce -inf (and then NaN).
    log_prob = F.log_softmax(input, dim=1)

    # Select the log-probability of the target class at every position.
    # Gathering along dim=1 keeps the (N, *) layout for any number of
    # trailing dimensions, unlike a channels-last one-hot mask.
    log_pt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    # Focal modulation: down-weight well-classified examples.
    weight = torch.pow(1. - pt, gamma)
    loss_tmp = -alpha * weight * log_pt

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss