2021-08-17 19:52:42 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2021-01-11 11:22:22 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from ..builder import LOSSES
|
|
|
|
from .utils import weight_reduce_loss
|
|
|
|
|
|
|
|
|
|
|
|
def asymmetric_loss(pred,
                    target,
                    weight=None,
                    gamma_pos=1.0,
                    gamma_neg=4.0,
                    clip=0.05,
                    reduction='mean',
                    avg_factor=None):
    r"""Asymmetric loss.

    Please refer to the `paper <https://arxiv.org/abs/2009.14119>`__ for
    details.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*). A sigmoid is
            applied to it inside this function.
        target (torch.Tensor): The ground truth label of the prediction with
            shape (N, \*). It is cast to the dtype of ``pred``.
        weight (torch.Tensor, optional): Sample-wise loss weight with shape
            (N, ). It is broadcast over all non-batch dimensions of ``pred``.
            Defaults to None.
        gamma_pos (float): Positive focusing parameter. Defaults to 1.0.
        gamma_neg (float): Negative focusing parameter. We usually set
            gamma_neg > gamma_pos. Defaults to 4.0.
        clip (float, optional): Probability margin added to the negative
            probability ``1 - sigmoid(pred)`` (result clamped to at most 1).
            Disabled when None or not positive. Defaults to 0.05.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none' , loss
            is same shape as pred and label. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: Loss.
    """
    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'

    # Lower bound applied before the log to avoid log(0).
    eps = 1e-8
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)

    # Probability assigned to the negative class, optionally shifted by the
    # margin so that easy negatives saturate at probability 1.
    neg_prob = 1 - pred_sigmoid
    if clip and clip > 0:
        neg_prob = (neg_prob + clip).clamp(max=1)

    # pt: probability of the ground-truth class for every element.
    pt = neg_prob * (1 - target) + pred_sigmoid * target

    # Element-wise focusing exponent: gamma_pos for positives, gamma_neg for
    # negatives.
    asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg *
                                     (1 - target))
    loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight

    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Reshape to (N, 1, ..., 1) so the per-sample weight lines up
            # with the batch axis for any (N, *) prediction, not only 2-D.
            weight = weight.reshape((-1, ) + (1, ) * (pred.dim() - 1))
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
|
|
|
|
|
|
|
|
|
|
|
|
@LOSSES.register_module()
class AsymmetricLoss(nn.Module):
    """Module wrapper around :func:`asymmetric_loss`.

    Args:
        gamma_pos (float): Focusing parameter applied to positive targets.
            Defaults to 0.0.
        gamma_neg (float): Focusing parameter applied to negative targets.
            It is usually chosen larger than ``gamma_pos``. Defaults to 4.0.
        clip (float, optional): Probability margin. Defaults to 0.05.
        reduction (str): How the element-wise loss is reduced. Defaults to
            'mean'.
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 gamma_pos=0.0,
                 gamma_neg=4.0,
                 clip=0.05,
                 reduction='mean',
                 loss_weight=1.0):
        super().__init__()
        # Store the hyper-parameters; they are forwarded unchanged on every
        # call to ``forward``.
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted asymmetric loss.

        Args:
            pred (torch.Tensor): Prediction passed to ``asymmetric_loss``.
            target (torch.Tensor): Ground truth passed to
                ``asymmetric_loss``.
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor forwarded to
                ``asymmetric_loss``. Defaults to None.
            reduction_override (str, optional): One of 'none', 'mean' or
                'sum'. When given, it replaces the reduction configured at
                construction time. Defaults to None.

        Returns:
            torch.Tensor: The loss multiplied by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A non-empty override wins; otherwise use the configured reduction.
        reduction = reduction_override or self.reduction
        raw_loss = asymmetric_loss(
            pred,
            target,
            weight,
            gamma_pos=self.gamma_pos,
            gamma_neg=self.gamma_neg,
            clip=self.clip,
            reduction=reduction,
            avg_factor=avg_factor)
        loss_cls = self.loss_weight * raw_loss
        return loss_cls
|