[Feature] Add asymmetric loss for multilabel task (#132)
* add asymmetric loss * minor change * fix docstring * do not apply sum over classes and fix docstring * fix docstring * fix weight shape * fix weight shape * add reference * fix linkting issue Co-authored-by: Y. Xiong <xiongyuxy@gmail.com>pull/138/head
parent
d062cdef0a
commit
6916f33d56
|
@ -1,4 +1,5 @@
|
|||
from .accuracy import Accuracy, accuracy
|
||||
from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
|
||||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
||||
cross_entropy)
|
||||
from .eval_metrics import f1_score, precision, recall
|
||||
|
@ -7,8 +8,8 @@ from .label_smooth_loss import LabelSmoothLoss, label_smooth
|
|||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
||||
|
||||
__all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
||||
'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth',
|
||||
'LabelSmoothLoss', 'weighted_loss', 'precision', 'recall', 'f1_score',
|
||||
'FocalLoss', 'sigmoid_focal_loss'
|
||||
'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
|
||||
'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
|
||||
'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import weight_reduce_loss
|
||||
|
||||
|
||||
def asymmetric_loss(pred,
|
||||
target,
|
||||
weight=None,
|
||||
gamma_pos=1.0,
|
||||
gamma_neg=4.0,
|
||||
clip=0.05,
|
||||
reduction='mean',
|
||||
avg_factor=None):
|
||||
"""asymmetric loss
|
||||
|
||||
Please refer to the `paper <https://arxiv.org/abs/2009.14119>`_ for
|
||||
details.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, *).
|
||||
target (torch.Tensor): The ground truth label of the prediction with
|
||||
shape (N, *).
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight with shape
|
||||
(N, ). Dafaults to None.
|
||||
gamma_pos (float, optional): positive focusing parameter.
|
||||
Defaults to 0.0.
|
||||
gamma_neg (float, optional): Negative focusing parameter. We usually
|
||||
set gamma_neg > gamma_pos. Defaults to 4.0.
|
||||
clip (float, optional): Probability margin. Defaults to 0.05.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum". If reduction is 'none' , loss
|
||||
is same shape as pred and label. Defaults to 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loss.
|
||||
"""
|
||||
assert pred.shape == \
|
||||
target.shape, 'pred and target should be in the same shape.'
|
||||
|
||||
eps = 1e-8
|
||||
pred_sigmoid = pred.sigmoid()
|
||||
target = target.type_as(pred)
|
||||
|
||||
if clip and clip > 0:
|
||||
pt = (1 - pred_sigmoid +
|
||||
clip).clamp(max=1) * (1 - target) + pred_sigmoid * target
|
||||
else:
|
||||
pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target
|
||||
asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg *
|
||||
(1 - target))
|
||||
loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight
|
||||
if weight is not None:
|
||||
assert weight.dim() == 1
|
||||
weight = weight.float()
|
||||
if pred.dim() > 1:
|
||||
weight = weight.reshape(-1, 1)
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class AsymmetricLoss(nn.Module):
|
||||
"""asymmetric loss
|
||||
|
||||
Args:
|
||||
gamma_pos (float, optional): positive focusing parameter.
|
||||
Defaults to 0.0.
|
||||
gamma_neg (float, optional): Negative focusing parameter. We
|
||||
usually set gamma_neg > gamma_pos. Defaults to 4.0.
|
||||
clip (float, optional): Probability margin. Defaults to 0.05.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar.
|
||||
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
gamma_pos=0.0,
|
||||
gamma_neg=4.0,
|
||||
clip=0.05,
|
||||
reduction='mean',
|
||||
loss_weight=1.0):
|
||||
super(AsymmetricLoss, self).__init__()
|
||||
self.gamma_pos = gamma_pos
|
||||
self.gamma_neg = gamma_neg
|
||||
self.clip = clip
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(self,
|
||||
pred,
|
||||
target,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None):
|
||||
"""asymmetric loss
|
||||
"""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
loss_cls = self.loss_weight * asymmetric_loss(
|
||||
pred,
|
||||
target,
|
||||
weight,
|
||||
gamma_pos=self.gamma_pos,
|
||||
gamma_neg=self.gamma_neg,
|
||||
clip=self.clip,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor)
|
||||
return loss_cls
|
|
@ -3,6 +3,38 @@ import torch
|
|||
from mmcls.models import build_loss
|
||||
|
||||
|
||||
def test_asymmetric_loss():
|
||||
# test asymmetric_loss
|
||||
cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
|
||||
label = torch.Tensor([[1, 0, 1], [0, 1, 0]])
|
||||
weight = torch.tensor([0.5, 0.5])
|
||||
|
||||
loss_cfg = dict(
|
||||
type='AsymmetricLoss',
|
||||
gamma_pos=1.0,
|
||||
gamma_neg=4.0,
|
||||
clip=0.05,
|
||||
reduction='mean',
|
||||
loss_weight=1.0)
|
||||
loss = build_loss(loss_cfg)
|
||||
assert torch.allclose(loss(cls_score, label), torch.tensor(3.80845 / 3))
|
||||
|
||||
# test asymmetric_loss with weight
|
||||
assert torch.allclose(
|
||||
loss(cls_score, label, weight=weight), torch.tensor(3.80845 / 6))
|
||||
|
||||
# test asymmetric_loss without clip
|
||||
loss_cfg = dict(
|
||||
type='AsymmetricLoss',
|
||||
gamma_pos=1.0,
|
||||
gamma_neg=4.0,
|
||||
clip=None,
|
||||
reduction='mean',
|
||||
loss_weight=1.0)
|
||||
loss = build_loss(loss_cfg)
|
||||
assert torch.allclose(loss(cls_score, label), torch.tensor(5.1186 / 3))
|
||||
|
||||
|
||||
def test_cross_entropy_loss():
|
||||
|
||||
# test ce_loss
|
||||
|
|
Loading…
Reference in New Issue