[Feature] Add bce loss for multilabel task (#130)

* add bce loss for multilabel task

* minor change

* apply class wise sum

* fix docstring

* do not apply sum over classes and fix docstring

* fix docstring

* fix weight shape

* fix weight shape

* fix docstring

* fix linting issue

Co-authored-by: Y. Xiong <xiongyuxy@gmail.com>
LXXXXR 2021-01-11 11:05:24 +08:00 committed by GitHub
parent 9578bfa0f1
commit 194ab7efda
4 changed files with 103 additions and 7 deletions

docs/stats.py

@@ -20,7 +20,6 @@ for f in files:
     title = content.split('\n')[0].replace('# ', '')
     ckpts = set(x.lower().strip()
                 for x in re.findall(r'\[model\]\((https?.*)\)', content))

mmcls/models/losses/__init__.py

@@ -1,12 +1,14 @@
 from .accuracy import Accuracy, accuracy
-from .cross_entropy_loss import CrossEntropyLoss, cross_entropy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+                                 cross_entropy)
 from .eval_metrics import f1_score, precision, recall
 from .focal_loss import FocalLoss, sigmoid_focal_loss
 from .label_smooth_loss import LabelSmoothLoss, label_smooth
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss

 __all__ = [
-    'accuracy', 'Accuracy', 'cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
-    'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
+    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+    'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth',
+    'LabelSmoothLoss', 'weighted_loss', 'precision', 'recall', 'f1_score',
+    'FocalLoss', 'sigmoid_focal_loss'
 ]
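
With the expanded __all__, the new loss function is importable from the losses package. A trivial check (assuming mmcls with this patch installed):

    # both names now resolve from the package root
    from mmcls.models.losses import CrossEntropyLoss, binary_cross_entropy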

mmcls/models/losses/cross_entropy_loss.py

@@ -6,6 +6,20 @@ from .utils import weight_reduce_loss

 def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+    """Calculate the CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), where C is the
+            number of classes.
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
     # element-wise losses
     loss = F.cross_entropy(pred, label, reduction='none')
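
The documented shapes can be sanity-checked with plain PyTorch (a sketch, not part of the diff):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(4, 10)           # (N, C) logits
    label = torch.randint(0, 10, (4,))  # class indices
    # per-sample losses, shape (N,), matching the docstring above
    print(F.cross_entropy(pred, label, reduction='none').shape)  # torch.Size([4])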
@@ -18,15 +32,65 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
     return loss


+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None):
+    """Calculate the binary CrossEntropy loss with logits.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, *).
+        label (torch.Tensor): The learning label with shape (N, *).
+        weight (torch.Tensor, optional): Sample-wise loss weight with shape
+            (N, ). Defaults to None.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". If reduction is 'none',
+            the loss is the same shape as pred and label. Defaults to 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    assert pred.dim() == label.dim()
+    loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+
+    # apply weights and do the reduction
+    if weight is not None:
+        assert weight.dim() == 1
+        weight = weight.float()
+        if pred.dim() > 1:
+            weight = weight.reshape(-1, 1)
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
 @LOSSES.register_module()
 class CrossEntropyLoss(nn.Module):
+    """Cross entropy loss.
+
+    Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            instead of softmax. Defaults to False.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        loss_weight (float, optional): Weight of the loss.
+            Defaults to 1.0.
+    """

-    def __init__(self, reduction='mean', loss_weight=1.0):
+    def __init__(self, use_sigmoid=False, reduction='mean', loss_weight=1.0):
         super(CrossEntropyLoss, self).__init__()
+        self.use_sigmoid = use_sigmoid
         self.reduction = reduction
         self.loss_weight = loss_weight
-        self.cls_criterion = cross_entropy
+
+        if self.use_sigmoid:
+            self.cls_criterion = binary_cross_entropy
+        else:
+            self.cls_criterion = cross_entropy

     def forward(self,
                 cls_score,
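
As a usage sketch (assuming mmcls with this patch installed), the sigmoid path is selected purely through the config, so a head can switch to multilabel BCE without code changes:

    import torch
    from mmcls.models import build_loss

    # use_sigmoid=True routes forward() to binary_cross_entropy above
    loss_cfg = dict(type='CrossEntropyLoss', use_sigmoid=True, reduction='mean')
    bce = build_loss(loss_cfg)

    cls_score = torch.Tensor([[100, -100], [100, -100]])  # raw logits
    label = torch.Tensor([[1, 0], [0, 1]])                # multilabel targets
    print(bce(cls_score, label))                          # tensor(50.), cf. the test below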

tests/test_losses.py

@@ -3,6 +3,37 @@ import torch
 from mmcls.models import build_loss


+def test_cross_entropy_loss():
+    # test ce_loss
+    cls_score = torch.Tensor([[100, -100]])
+    label = torch.Tensor([1]).long()
+    weight = torch.tensor(0.5)
+
+    loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(200.))
+
+    # test ce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(100.))
+
+    # test bce_loss
+    cls_score = torch.Tensor([[100, -100], [100, -100]])
+    label = torch.Tensor([[1, 0], [0, 1]])
+    weight = torch.Tensor([0.5, 0.5])
+
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        reduction='mean',
+        loss_weight=1.0)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
+
+    # test bce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(25.))
+
+
 def test_focal_loss():
     # test focal_loss
     cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
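
The expected constants follow directly from the saturated logits, and halving every weight to 0.5 halves the means (200 to 100, 50 to 25). A quick check with plain PyTorch, independent of mmcls:

    import torch
    import torch.nn.functional as F

    # softmax CE: -log(e^-100 / (e^100 + e^-100)) ~= 200 for true class 1
    print(F.cross_entropy(torch.Tensor([[100., -100.]]), torch.tensor([1])))  # tensor(200.)

    # BCE with logits: element-wise losses ~= [[0, 0], [100, 100]],
    # so the mean over all four elements is 50
    pred = torch.Tensor([[100, -100], [100, -100]])
    target = torch.Tensor([[1, 0], [0, 1]])
    print(F.binary_cross_entropy_with_logits(pred, target))  # tensor(50.)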