From 6916f33d5653ea7a3bd997363528bcf017b24a02 Mon Sep 17 00:00:00 2001
From: LXXXXR <73265258+LXXXXR@users.noreply.github.com>
Date: Mon, 11 Jan 2021 11:22:22 +0800
Subject: [PATCH] [Feature] Add asymmetric loss for multilabel task (#132)

* add asymmetric loss

* minor change

* fix docstring

* do not apply sum over classes and fix docstring

* fix docstring

* fix weight shape

* fix weight shape

* add reference

* fix linting issue

Co-authored-by: Y. Xiong
---
 mmcls/models/losses/__init__.py        |   9 +-
 mmcls/models/losses/asymmetric_loss.py | 113 +++++++++++++++++++++++++
 tests/test_losses.py                   |  32 +++++++
 3 files changed, 150 insertions(+), 4 deletions(-)
 create mode 100644 mmcls/models/losses/asymmetric_loss.py

diff --git a/mmcls/models/losses/__init__.py b/mmcls/models/losses/__init__.py
index 12631b89e..774b4917d 100644
--- a/mmcls/models/losses/__init__.py
+++ b/mmcls/models/losses/__init__.py
@@ -1,4 +1,5 @@
 from .accuracy import Accuracy, accuracy
+from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy)
 from .eval_metrics import f1_score, precision, recall
@@ -7,8 +8,8 @@ from .label_smooth_loss import LabelSmoothLoss, label_smooth
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
 
 __all__ = [
-    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
-    'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth',
-    'LabelSmoothLoss', 'weighted_loss', 'precision', 'recall', 'f1_score',
-    'FocalLoss', 'sigmoid_focal_loss'
+    'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
+    'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+    'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
+    'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
 ]
diff --git a/mmcls/models/losses/asymmetric_loss.py b/mmcls/models/losses/asymmetric_loss.py
new file mode 100644
index 000000000..da27e3f36
--- /dev/null
+++ b/mmcls/models/losses/asymmetric_loss.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def asymmetric_loss(pred,
+                    target,
+                    weight=None,
+                    gamma_pos=1.0,
+                    gamma_neg=4.0,
+                    clip=0.05,
+                    reduction='mean',
+                    avg_factor=None):
+    """Asymmetric loss.
+
+    Please refer to the `paper <https://arxiv.org/abs/2009.14119>`_ for
+    details.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, *).
+        target (torch.Tensor): The ground truth label of the prediction with
+            shape (N, *).
+        weight (torch.Tensor, optional): Sample-wise loss weight with shape
+            (N, ). Defaults to None.
+        gamma_pos (float, optional): Positive focusing parameter.
+            Defaults to 1.0.
+        gamma_neg (float, optional): Negative focusing parameter. We usually
+            set gamma_neg > gamma_pos. Defaults to 4.0.
+        clip (float, optional): Probability margin. Defaults to 0.05.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". If reduction is 'none',
+            the loss has the same shape as pred and target.
+            Defaults to 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+
+    Returns:
+        torch.Tensor: Loss.
+    """
+    assert pred.shape == \
+        target.shape, 'pred and target should be in the same shape.'
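+
+    # What follows implements the asymmetric focal mechanism: pt is the
+    # probability of each label's ground-truth state (pred_sigmoid for
+    # positives, 1 - pred_sigmoid for negatives). The probability margin
+    # `clip` shifts the negative term up, so negatives already predicted
+    # below the margin reach pt = 1 and contribute no loss, and the
+    # focusing exponent is gamma_pos for positives but gamma_neg for
+    # negatives, down-weighting easy examples asymmetrically.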
+
+    eps = 1e-8
+    pred_sigmoid = pred.sigmoid()
+    target = target.type_as(pred)
+
+    if clip and clip > 0:
+        pt = (1 - pred_sigmoid +
+              clip).clamp(max=1) * (1 - target) + pred_sigmoid * target
+    else:
+        pt = (1 - pred_sigmoid) * (1 - target) + pred_sigmoid * target
+    asymmetric_weight = (1 - pt).pow(gamma_pos * target + gamma_neg *
+                                     (1 - target))
+    loss = -torch.log(pt.clamp(min=eps)) * asymmetric_weight
+    if weight is not None:
+        assert weight.dim() == 1
+        weight = weight.float()
+        if pred.dim() > 1:
+            weight = weight.reshape(-1, 1)
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+@LOSSES.register_module()
+class AsymmetricLoss(nn.Module):
+    """Asymmetric loss.
+
+    Args:
+        gamma_pos (float, optional): Positive focusing parameter.
+            Defaults to 0.0.
+        gamma_neg (float, optional): Negative focusing parameter. We
+            usually set gamma_neg > gamma_pos. Defaults to 4.0.
+        clip (float, optional): Probability margin. Defaults to 0.05.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 gamma_pos=0.0,
+                 gamma_neg=4.0,
+                 clip=0.05,
+                 reduction='mean',
+                 loss_weight=1.0):
+        super(AsymmetricLoss, self).__init__()
+        self.gamma_pos = gamma_pos
+        self.gamma_neg = gamma_neg
+        self.clip = clip
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
+        """Asymmetric loss."""
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        loss_cls = self.loss_weight * asymmetric_loss(
+            pred,
+            target,
+            weight,
+            gamma_pos=self.gamma_pos,
+            gamma_neg=self.gamma_neg,
+            clip=self.clip,
+            reduction=reduction,
+            avg_factor=avg_factor)
+        return loss_cls
diff --git a/tests/test_losses.py b/tests/test_losses.py
index dcd4f629a..6195327b8 100644
--- a/tests/test_losses.py
+++ b/tests/test_losses.py
@@ -3,6 +3,38 @@ import torch
 
 from mmcls.models import build_loss
 
 
+def test_asymmetric_loss():
+    # test asymmetric_loss
+    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
+    label = torch.Tensor([[1, 0, 1], [0, 1, 0]])
+    weight = torch.tensor([0.5, 0.5])
+
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=1.0,
+        gamma_neg=4.0,
+        clip=0.05,
+        reduction='mean',
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(3.80845 / 3))
+
+    # test asymmetric_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(3.80845 / 6))
+
+    # test asymmetric_loss without clip
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=1.0,
+        gamma_neg=4.0,
+        clip=None,
+        reduction='mean',
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(5.1186 / 3))
+
+
 def test_cross_entropy_loss():
     # test ce_loss
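
Below is a minimal usage sketch, not part of the patch itself: it assumes mmcls with this patch applied and mirrors the config pattern exercised by the tests above; the batch and label counts are illustrative.

    import torch

    from mmcls.models import build_loss

    # Build the registered loss from a config dict, as in test_asymmetric_loss.
    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=0.0,
        gamma_neg=4.0,
        clip=0.05,
        reduction='mean',
        loss_weight=1.0)
    criterion = build_loss(loss_cfg)

    # Raw logits for 8 samples over 20 labels, with multi-hot targets.
    cls_score = torch.randn(8, 20)
    label = torch.randint(0, 2, (8, 20)).float()

    loss = criterion(cls_score, label)  # scalar tensor under 'mean' reduction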