mmpretrain/mmcls/models/losses/asymmetric_loss.py

113 lines
3.8 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weight_reduce_loss
def asymmetric_loss(pred,
target,
weight=None,
gamma_pos=1.0,
gamma_neg=4.0,
clip=0.05,
reduction='mean',
avg_factor=None):
r"""asymmetric loss.
Please refer to the `paper <https://arxiv.org/abs/2009.14119>`__ for
details.
Args:
pred (torch.Tensor): The prediction with shape (N, \*).
target (torch.Tensor): The ground truth label of the prediction with
shape (N, \*).
weight (torch.Tensor, optional): Sample-wise loss weight with shape
(N, ). Dafaults to None.
gamma_pos (float): positive focusing parameter. Defaults to 0.0.
gamma_neg (float): Negative focusing parameter. We usually set
gamma_neg > gamma_pos. Defaults to 4.0.
clip (float, optional): Probability margin. Defaults to 0.05.
reduction (str): The method used to reduce the loss.
Options are "none", "mean" and "sum". If reduction is 'none' , loss
is same shape as pred and label. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
Returns:
torch.Tensor: Loss.
"""
assert pred.shape == \
target.shape, 'pred and target should be in the same shape.'
eps = 1e-8
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
if clip and clip > 0:
pt = (1 - pred_sigmoid +
clip).clamp(max=1) * (1 - target) + pred_sigmoid * target
else:
pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target
asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg *
(1 - target))
loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight
if weight is not None:
assert weight.dim() == 1
weight = weight.float()
if pred.dim() > 1:
weight = weight.reshape(-1, 1)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module()
class AsymmetricLoss(nn.Module):
"""asymmetric loss.
Args:
gamma_pos (float): positive focusing parameter.
Defaults to 0.0.
gamma_neg (float): Negative focusing parameter. We
usually set gamma_neg > gamma_pos. Defaults to 4.0.
clip (float, optional): Probability margin. Defaults to 0.05.
reduction (str): The method used to reduce the loss into
a scalar.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(self,
gamma_pos=0.0,
gamma_neg=4.0,
clip=0.05,
reduction='mean',
loss_weight=1.0):
super(AsymmetricLoss, self).__init__()
self.gamma_pos = gamma_pos
self.gamma_neg = gamma_neg
self.clip = clip
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""asymmetric loss."""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_cls = self.loss_weight * asymmetric_loss(
pred,
target,
weight,
gamma_pos=self.gamma_pos,
gamma_neg=self.gamma_neg,
clip=self.clip,
reduction=reduction,
avg_factor=avg_factor)
return loss_cls