[Enhance] Support single-label, softmax, custom eps by asymmetric loss (#609)
Co-authored-by: Minyus <Minyus@users.noreply.github.com>

parent d29037e8d1
commit e694269c59
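The diff below adds three things to the asymmetric loss: acceptance of single-label (class-index) targets, a softmax prediction path (use_sigmoid=False), and a configurable eps floor for the logarithm. For orientation, a minimal usage sketch, assuming build_loss is importable from mmcls.models as the tests below use it; the shapes and hyperparameter values here are illustrative, not taken from this commit:

    import torch
    from mmcls.models import build_loss

    # Hypothetical config exercising the new options.
    loss_cfg = dict(
        type='AsymmetricLoss',
        gamma_pos=0.0,
        gamma_neg=4.0,
        clip=0.05,
        use_sigmoid=False,  # new: softmax over classes instead of sigmoid
        eps=1e-8)           # new: floor for the argument of the log
    criterion = build_loss(loss_cfg)

    cls_score = torch.randn(4, 10)       # (N, num_classes) logits
    label = torch.randint(0, 10, (4, ))  # new: class indices also accepted
    loss = criterion(cls_score, label)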
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 
 from ..builder import LOSSES
-from .utils import weight_reduce_loss
+from .utils import convert_to_one_hot, weight_reduce_loss
 
 
 def asymmetric_loss(pred,
@@ -13,7 +13,9 @@ def asymmetric_loss(pred,
                     gamma_neg=4.0,
                     clip=0.05,
                     reduction='mean',
-                    avg_factor=None):
+                    avg_factor=None,
+                    use_sigmoid=True,
+                    eps=1e-8):
     r"""asymmetric loss.
 
     Please refer to the `paper <https://arxiv.org/abs/2009.14119>`__ for
@@ -34,6 +36,10 @@ def asymmetric_loss(pred,
             is same shape as pred and label. Defaults to 'mean'.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        use_sigmoid (bool): Whether the prediction uses sigmoid instead
+            of softmax. Defaults to True.
+        eps (float): The minimum value of the argument of logarithm. Defaults
+            to 1e-8.
 
     Returns:
         torch.Tensor: Loss.
@@ -41,8 +47,11 @@ def asymmetric_loss(pred,
     assert pred.shape == \
         target.shape, 'pred and target should be in the same shape.'
 
-    eps = 1e-8
-    pred_sigmoid = pred.sigmoid()
+    if use_sigmoid:
+        pred_sigmoid = pred.sigmoid()
+    else:
+        pred_sigmoid = nn.functional.softmax(pred, dim=-1)
+
     target = target.type_as(pred)
 
     if clip and clip > 0:
@@ -75,6 +84,10 @@ class AsymmetricLoss(nn.Module):
         reduction (str): The method used to reduce the loss into
             a scalar.
         loss_weight (float): Weight of loss. Defaults to 1.0.
+        use_sigmoid (bool): Whether the prediction uses sigmoid instead
+            of softmax. Defaults to True.
+        eps (float): The minimum value of the argument of logarithm. Defaults
+            to 1e-8.
     """
 
     def __init__(self,
@@ -82,13 +95,17 @@ class AsymmetricLoss(nn.Module):
                  gamma_neg=4.0,
                  clip=0.05,
                  reduction='mean',
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 use_sigmoid=True,
+                 eps=1e-8):
         super(AsymmetricLoss, self).__init__()
         self.gamma_pos = gamma_pos
         self.gamma_neg = gamma_neg
         self.clip = clip
         self.reduction = reduction
         self.loss_weight = loss_weight
+        self.use_sigmoid = use_sigmoid
+        self.eps = eps
 
     def forward(self,
                 pred,
@@ -96,10 +113,28 @@ class AsymmetricLoss(nn.Module):
                 weight=None,
                 avg_factor=None,
                 reduction_override=None):
-        """asymmetric loss."""
+        r"""asymmetric loss.
+
+        Args:
+            pred (torch.Tensor): The prediction with shape (N, \*).
+            target (torch.Tensor): The ground truth label of the prediction
+                with shape (N, \*), N or (N,1).
+            weight (torch.Tensor, optional): Sample-wise loss weight with shape
+                (N, \*). Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The method used to reduce the
+                loss into a scalar. Options are "none", "mean" and "sum".
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: Loss.
+        """
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
+        if target.dim() == 1 or (target.dim() == 2 and target.shape[1] == 1):
+            target = convert_to_one_hot(target.view(-1, 1), pred.shape[-1])
         loss_cls = self.loss_weight * asymmetric_loss(
             pred,
             target,
@@ -108,5 +143,7 @@ class AsymmetricLoss(nn.Module):
             gamma_neg=self.gamma_neg,
             clip=self.clip,
             reduction=reduction,
-            avg_factor=avg_factor)
+            avg_factor=avg_factor,
+            use_sigmoid=self.use_sigmoid,
+            eps=self.eps)
         return loss_cls
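The new single-label branch in forward() calls convert_to_one_hot from .utils. For reference, a minimal sketch of what such a helper does; the name convert_to_one_hot_sketch and the exact checks are illustrative stand-ins, not the mmcls implementation itself:

    import torch

    def convert_to_one_hot_sketch(targets, classes):
        # Scatter class indices of shape (N, 1) into a one-hot (N, classes)
        # matrix; illustrative stand-in for mmcls's convert_to_one_hot.
        assert targets.max().item() < classes, \
            'target index must be smaller than the number of classes'
        one_hot = torch.zeros((targets.shape[0], classes),
                              dtype=targets.dtype, device=targets.device)
        one_hot.scatter_(1, targets.long(), 1)
        return one_hot

    # e.g. indices [[0], [1]] with 3 classes -> [[1, 0, 0], [0, 1, 0]]
    print(convert_to_one_hot_sketch(torch.tensor([[0], [1]]), 3))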
@@ -36,6 +36,46 @@ def test_asymmetric_loss():
     loss = build_loss(loss_cfg)
     assert torch.allclose(loss(cls_score, label), torch.tensor(5.1186 / 3))
 
+    # test asymmetric_loss with softmax for single label task
+    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
+    label = torch.Tensor([0, 1])
+    weight = torch.tensor([0.5, 0.5])
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=0.0,
+        gamma_neg=0.0,
+        clip=None,
+        reduction='mean',
+        loss_weight=1.0,
+        use_sigmoid=False,
+        eps=1e-8)
+    loss = build_loss(loss_cfg)
+    # test asymmetric_loss for single label task without weight
+    assert torch.allclose(loss(cls_score, label), torch.tensor(2.5045))
+    # test asymmetric_loss for single label task with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(2.5045 * 0.5))
+
+    # test soft asymmetric_loss with softmax
+    cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
+    label = torch.Tensor([[1, 0, 0], [0, 1, 0]])
+    weight = torch.tensor([0.5, 0.5])
+    loss_cfg = dict(
+        type='AsymmetricLoss',
+        gamma_pos=0.0,
+        gamma_neg=0.0,
+        clip=None,
+        reduction='mean',
+        loss_weight=1.0,
+        use_sigmoid=False,
+        eps=1e-8)
+    loss = build_loss(loss_cfg)
+    # test soft asymmetric_loss with softmax without weight
+    assert torch.allclose(loss(cls_score, label), torch.tensor(2.5045))
+    # test soft asymmetric_loss with softmax with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(2.5045 * 0.5))
+
 
 def test_cross_entropy_loss():
     with pytest.raises(AssertionError):
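With gamma_pos=gamma_neg=0 and clip=None, the asymmetric weighting term drops out and the loss degenerates to a plain -log(pt) averaged over all elements, so the expected value 2.5045 in the new tests can be checked by hand. A standalone re-derivation, assuming the usual pt construction for this loss (the predicted probability for positive entries, one minus it for negative entries):

    import torch

    cls_score = torch.tensor([[5., -5., 0.], [5., -5., 0.]])
    target = torch.tensor([[1., 0., 0.], [0., 1., 0.]])

    prob = cls_score.softmax(dim=-1)
    # pt: probability assigned to the "correct" event per element:
    # p for positive entries, 1 - p for negative entries.
    pt = prob * target + (1 - prob) * (1 - target)
    loss = -torch.log(pt.clamp(min=1e-8)).mean()
    print(loss)  # ~2.5045, matching the expectation in the tests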