[Feature] Add asymmetric loss for multilabel task (#132)

* add asymmetric loss

* minor change

* fix docstring

* do not apply sum over classes and fix docstring

* fix docstring

* fix weight shape

* fix weight shape

* add reference

* fix linting issue

Co-authored-by: Y. Xiong <xiongyuxy@gmail.com>
LXXXXR 2021-01-11 11:22:22 +08:00 committed by GitHub
parent d062cdef0a
commit 6916f33d56
3 changed files with 150 additions and 4 deletions
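
For context, the new loss is the Asymmetric Loss (ASL) for multilabel classification, from the paper referenced in the docstrings below (https://arxiv.org/abs/2009.14119). In short, with per-class sigmoid probability $p = \sigma(x)$, probability margin $m$ (the `clip` argument), and shifted probability $p_m = \max(p - m, 0)$, ASL uses

$$L_{+} = (1 - p)^{\gamma_{+}} \log(p), \qquad L_{-} = p_m^{\gamma_{-}} \log(1 - p_m),$$

taking $-L_{+}$ on positive labels and $-L_{-}$ on negative ones, where $\gamma_{+}$ is `gamma_pos` and $\gamma_{-}$ is `gamma_neg`; setting $\gamma_{-} > \gamma_{+}$ down-weights the many easy negatives typical of multilabel data.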

mmcls/models/losses/__init__.py

@@ -1,4 +1,5 @@
 from .accuracy import Accuracy, accuracy
+from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy)
 from .eval_metrics import f1_score, precision, recall
@@ -7,8 +8,8 @@ from .label_smooth_loss import LabelSmoothLoss, label_smooth
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
 
 __all__ = [
-    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
-    'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth',
-    'LabelSmoothLoss', 'weighted_loss', 'precision', 'recall', 'f1_score',
-    'FocalLoss', 'sigmoid_focal_loss'
+    'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
+    'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+    'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
+    'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
 ]
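
With the exports above in place, both interfaces are importable from `mmcls.models.losses`. A minimal sketch of the functional form, borrowing the tensors from the new unit test (the import path is inferred from this package layout):

import torch
from mmcls.models.losses import asymmetric_loss

pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])  # raw logits, shape (N, C)
target = torch.Tensor([[1, 0, 1], [0, 1, 0]])  # multi-hot labels, same shape
# reduction='none' keeps the per-element losses, shape (N, C)
loss = asymmetric_loss(pred, target, reduction='none')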

mmcls/models/losses/asymmetric_loss.py

@@ -0,0 +1,113 @@
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weight_reduce_loss

def asymmetric_loss(pred,
                    target,
                    weight=None,
                    gamma_pos=1.0,
                    gamma_neg=4.0,
                    clip=0.05,
                    reduction='mean',
                    avg_factor=None):
    """Asymmetric loss.

    Please refer to the `paper <https://arxiv.org/abs/2009.14119>`_ for
    details.

    Args:
        pred (torch.Tensor): The prediction with shape (N, *).
        target (torch.Tensor): The ground truth label of the prediction with
            shape (N, *).
        weight (torch.Tensor, optional): Sample-wise loss weight with shape
            (N, ). Defaults to None.
        gamma_pos (float, optional): Positive focusing parameter.
            Defaults to 1.0.
        gamma_neg (float, optional): Negative focusing parameter. We usually
            set gamma_neg > gamma_pos. Defaults to 4.0.
        clip (float, optional): Probability margin. Defaults to 0.05.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            the loss has the same shape as pred and target. Defaults to
            'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: Loss.
    """
    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'

    eps = 1e-8
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)

    if clip and clip > 0:
        # apply the probability margin to the negative term only
        pt = (1 - pred_sigmoid +
              clip).clamp(max=1) * (1 - target) + pred_sigmoid * target
    else:
        pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target
    # asymmetric focusing: gamma_pos for positives, gamma_neg for negatives
    asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg *
                                     (1 - target))
    loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            weight = weight.reshape(-1, 1)
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss


@LOSSES.register_module()
class AsymmetricLoss(nn.Module):
    """Asymmetric loss.

    Args:
        gamma_pos (float, optional): Positive focusing parameter.
            Defaults to 0.0.
        gamma_neg (float, optional): Negative focusing parameter. We
            usually set gamma_neg > gamma_pos. Defaults to 4.0.
        clip (float, optional): Probability margin. Defaults to 0.05.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Options are "none", "mean" and "sum". Defaults to
            'mean'.
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 gamma_pos=0.0,
                 gamma_neg=4.0,
                 clip=0.05,
                 reduction='mean',
                 loss_weight=1.0):
        super(AsymmetricLoss, self).__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Asymmetric loss."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * asymmetric_loss(
            pred,
            target,
            weight,
            gamma_pos=self.gamma_pos,
            gamma_neg=self.gamma_neg,
            clip=self.clip,
            reduction=reduction,
            avg_factor=avg_factor)
        return loss_cls
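
The registered class can be built from a config dict, as the new unit test does, or instantiated directly; `reduction_override` changes the reduction for a single call. A minimal sketch (values taken from the test below):

import torch
from mmcls.models import build_loss

criterion = build_loss(
    dict(type='AsymmetricLoss', gamma_pos=1.0, gamma_neg=4.0, clip=0.05))
pred = torch.Tensor([[5, -5, 0], [5, -5, 0]])
target = torch.Tensor([[1, 0, 1], [0, 1, 0]])
loss_mean = criterion(pred, target)  # scalar, reduction='mean'
loss_none = criterion(pred, target, reduction_override='none')  # shape (2, 3)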

tests/test_losses.py

@@ -3,6 +3,38 @@ import torch
 from mmcls.models import build_loss
 
 
+def test_asymmetric_loss():
+    # test asymmetric_loss
+    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
+    label = torch.Tensor([[1, 0, 1], [0, 1, 0]])
+    weight = torch.tensor([0.5, 0.5])
+
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=1.0,
+        gamma_neg=4.0,
+        clip=0.05,
+        reduction='mean',
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(3.80845 / 3))
+
+    # test asymmetric_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(3.80845 / 6))
+
+    # test asymmetric_loss without clip
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=1.0,
+        gamma_neg=4.0,
+        clip=None,
+        reduction='mean',
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(5.1186 / 3))
+
+
 def test_cross_entropy_loss():
     # test ce_loss