[Feature] Add bce loss for multilabel task (#130)
* add bce loss for multilabel task * minor change * apply class wise sum * fix docstring * do not apply sum over classes and fix docstring * fix docstring * fix weight shape * fix weight shape * fix docstring * fix linting issue Co-authored-by: Y. Xiong <xiongyuxy@gmail.com>pull/132/head
parent
9578bfa0f1
commit
194ab7efda
|
@ -20,7 +20,6 @@ for f in files:
|
|||
|
||||
title = content.split('\n')[0].replace('# ', '')
|
||||
|
||||
|
||||
ckpts = set(x.lower().strip()
|
||||
for x in re.findall(r'\[model\]\((https?.*)\)', content))
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
from .accuracy import Accuracy, accuracy
|
||||
from .cross_entropy_loss import CrossEntropyLoss, cross_entropy
|
||||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
||||
cross_entropy)
|
||||
from .eval_metrics import f1_score, precision, recall
|
||||
from .focal_loss import FocalLoss, sigmoid_focal_loss
|
||||
from .label_smooth_loss import LabelSmoothLoss, label_smooth
|
||||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
||||
|
||||
__all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
|
||||
'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
||||
'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth',
|
||||
'LabelSmoothLoss', 'weighted_loss', 'precision', 'recall', 'f1_score',
|
||||
'FocalLoss', 'sigmoid_focal_loss'
|
||||
]
|
||||
|
|
|
@ -6,6 +6,20 @@ from .utils import weight_reduce_loss
|
|||
|
||||
|
||||
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
|
||||
"""Calculate the CrossEntropy loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
||||
of classes.
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
# element-wise losses
|
||||
loss = F.cross_entropy(pred, label, reduction='none')
|
||||
|
||||
|
@ -18,15 +32,65 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
|
|||
return loss
|
||||
|
||||
|
||||
def binary_cross_entropy(pred,
|
||||
label,
|
||||
weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None):
|
||||
"""Calculate the binary CrossEntropy loss with logits.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, *).
|
||||
label (torch.Tensor): The learning label with shape (N, *).
|
||||
weight (torch.Tensor, optional): Element-wise weight of loss with shape
|
||||
(N, ). Defaults to None.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum". If reduction is 'none' , loss
|
||||
is same shape as pred and label. Defaults to 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
assert pred.dim() == label.dim()
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
|
||||
|
||||
# apply weights and do the reduction
|
||||
if weight is not None:
|
||||
assert weight.dim() == 1
|
||||
weight = weight.float()
|
||||
if pred.dim() > 1:
|
||||
weight = weight.reshape(-1, 1)
|
||||
loss = weight_reduce_loss(
|
||||
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
"""Cross entropy loss
|
||||
|
||||
def __init__(self, reduction='mean', loss_weight=1.0):
|
||||
Args:
|
||||
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
||||
of softmax. Defaults to False.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum". Defaults to 'mean'.
|
||||
loss_weight (float, optional): Weight of the loss.
|
||||
Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, use_sigmoid=False, reduction='mean', loss_weight=1.0):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
self.cls_criterion = cross_entropy
|
||||
if self.use_sigmoid:
|
||||
self.cls_criterion = binary_cross_entropy
|
||||
else:
|
||||
self.cls_criterion = cross_entropy
|
||||
|
||||
def forward(self,
|
||||
cls_score,
|
||||
|
|
|
@ -3,6 +3,37 @@ import torch
|
|||
from mmcls.models import build_loss
|
||||
|
||||
|
||||
def test_cross_entropy_loss():
|
||||
|
||||
# test ce_loss
|
||||
cls_score = torch.Tensor([[100, -100]])
|
||||
label = torch.Tensor([1]).long()
|
||||
weight = torch.tensor(0.5)
|
||||
|
||||
loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
|
||||
loss = build_loss(loss_cfg)
|
||||
assert torch.allclose(loss(cls_score, label), torch.tensor(200.))
|
||||
# test ce_loss with weight
|
||||
assert torch.allclose(
|
||||
loss(cls_score, label, weight=weight), torch.tensor(100.))
|
||||
|
||||
# test bce_loss
|
||||
cls_score = torch.Tensor([[100, -100], [100, -100]])
|
||||
label = torch.Tensor([[1, 0], [0, 1]])
|
||||
weight = torch.Tensor([0.5, 0.5])
|
||||
|
||||
loss_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='mean',
|
||||
loss_weight=1.0)
|
||||
loss = build_loss(loss_cfg)
|
||||
assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
|
||||
# test ce_loss with weight
|
||||
assert torch.allclose(
|
||||
loss(cls_score, label, weight=weight), torch.tensor(25.))
|
||||
|
||||
|
||||
def test_focal_loss():
|
||||
# test focal_loss
|
||||
cls_score = torch.Tensor([[5, -5, 0], [5, -5, 0]])
|
||||
|
|
Loading…
Reference in New Issue