diff --git a/mmcls/models/losses/cross_entropy_loss.py b/mmcls/models/losses/cross_entropy_loss.py
index e8de4b1f..be9b8f09 100644
--- a/mmcls/models/losses/cross_entropy_loss.py
+++ b/mmcls/models/losses/cross_entropy_loss.py
@@ -6,7 +6,12 @@ from ..builder import LOSSES
 from .utils import weight_reduce_loss
 
 
-def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+def cross_entropy(pred,
+                  label,
+                  weight=None,
+                  reduction='mean',
+                  avg_factor=None,
+                  class_weight=None):
     """Calculate the CrossEntropy loss.
 
     Args:
@@ -17,12 +22,14 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
         reduction (str): The method used to reduce the loss.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Defaults to None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     # element-wise losses
-    loss = F.cross_entropy(pred, label, reduction='none')
+    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')
 
     # apply weights and do the reduction
     if weight is not None:
@@ -37,6 +44,7 @@ def soft_cross_entropy(pred,
                        label,
                        weight=None,
                        reduction='mean',
+                       class_weight=None,
                        avg_factor=None):
     """Calculate the Soft CrossEntropy loss. The label can be float.
 
@@ -49,12 +57,16 @@ def soft_cross_entropy(pred,
         reduction (str): The method used to reduce the loss.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Defaults to None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     # element-wise losses
     loss = -label * F.log_softmax(pred, dim=-1)
+    if class_weight is not None:
+        loss *= class_weight
     loss = loss.sum(dim=-1)
 
     # apply weights and do the reduction
@@ -70,7 +82,8 @@ def binary_cross_entropy(pred,
                          label,
                          weight=None,
                          reduction='mean',
-                         avg_factor=None):
+                         avg_factor=None,
+                         class_weight=None):
     r"""Calculate the binary CrossEntropy loss with logits.
 
     Args:
@@ -83,13 +96,20 @@ def binary_cross_entropy(pred,
             is same shape as pred and label. Defaults to 'mean'.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Defaults to None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     assert pred.dim() == label.dim()
-
-    loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+    # Ensure that the size of class_weight is consistent with pred and label
+    # to avoid automatic broadcasting.
+    if class_weight is not None:
+        N = pred.size()[0]
+        class_weight = class_weight.repeat(N, 1)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label, weight=class_weight, reduction='none')
 
     # apply weights and do the reduction
     if weight is not None:
@@ -114,13 +134,16 @@ class CrossEntropyLoss(nn.Module):
         reduction (str): The method used to reduce the loss.
             Options are "none", "mean" and "sum". Defaults to 'mean'.
         loss_weight (float): Weight of the loss. Defaults to 1.0.
+        class_weight (List[float], optional): The weight for each class with
+            shape (C), C is the number of classes. Defaults to None.
     """
 
     def __init__(self,
                  use_sigmoid=False,
                  use_soft=False,
                  reduction='mean',
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 class_weight=None):
         super(CrossEntropyLoss, self).__init__()
         self.use_sigmoid = use_sigmoid
         self.use_soft = use_soft
@@ -130,6 +153,7 @@ class CrossEntropyLoss(nn.Module):
 
         self.reduction = reduction
         self.loss_weight = loss_weight
+        self.class_weight = class_weight
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -148,10 +172,17 @@ class CrossEntropyLoss(nn.Module):
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
+
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
+            class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
diff --git a/tests/test_metrics/test_losses.py b/tests/test_metrics/test_losses.py
index bf474407..c156e687 100644
--- a/tests/test_metrics/test_losses.py
+++ b/tests/test_metrics/test_losses.py
@@ -45,48 +45,92 @@ def test_cross_entropy_loss():
     loss = build_loss(loss_cfg)
 
     # test ce_loss
-    cls_score = torch.Tensor([[100, -100]])
-    label = torch.Tensor([1]).long()
-    weight = torch.tensor(0.5)
+    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
+    label = torch.Tensor([0, 1]).long()
+    class_weight = [0.3, 0.7]  # class 0: 0.3, class 1: 0.7
+    weight = torch.tensor([0.6, 0.4])
 
+    # test ce_loss without class weight
     loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(200.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
     # test ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(100.))
+        loss(cls_score, label, weight=weight), torch.tensor(640.))
+
+    # test ce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
+    # test ce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(208.))
 
     # test bce_loss
-    cls_score = torch.Tensor([[100, -100], [100, -100]])
-    label = torch.Tensor([[1, 0], [0, 1]])
-    weight = torch.Tensor([0.5, 0.5])
+    cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]])
+    label = torch.Tensor([[1, 0], [0, 1], [1, 0]])
+    weight = torch.Tensor([0.6, 0.4, 0.5])
+    class_weight = [0.1, 0.9]  # class 0: 0.1, class 1: 0.9
 
+    # test bce_loss without class weight
     loss_cfg = dict(
         type='CrossEntropyLoss',
         use_sigmoid=True,
         reduction='mean',
         loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(300.))
     # test ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(25.))
+        loss(cls_score, label, weight=weight), torch.tensor(130.))
+
+    # test bce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(176.667))
+    # test bce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(74.333))
 
     # test soft_ce_loss
-    cls_score = torch.Tensor([[100, -100]])
-    label = torch.Tensor([[1, 0], [0, 1]])
-    weight = torch.tensor(0.5)
+    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
+    label = torch.Tensor([[1.0, 0.0], [0.0, 1.0]])
+    class_weight = [0.3, 0.7]  # class 0: 0.3, class 1: 0.7
+    weight = torch.tensor([0.6, 0.4])
 
+    # test soft_ce_loss without class weight
     loss_cfg = dict(
         type='CrossEntropyLoss',
         use_soft=True,
         reduction='mean',
         loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(100.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
     # test soft_ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(50.))
+        loss(cls_score, label, weight=weight), torch.tensor(640.))
+
+    # test soft_ce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        use_soft=True,
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
+    # test soft_ce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(208.))
 
 
 def test_focal_loss():
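As a sanity check on the new option, here is a minimal usage sketch with the arithmetic behind the `370.` expected in the new test worked out by hand. It assumes this patch is applied and mmcls is importable; `build_loss` and the config keys are the ones used by the test file above.

```python
# Sketch only: assumes this patch is applied and mmcls is installed.
import torch

from mmcls.models import build_loss

loss_cfg = dict(
    type='CrossEntropyLoss',
    reduction='mean',
    loss_weight=1.0,
    class_weight=[0.3, 0.7])  # per-class weights, shape (C,)
loss = build_loss(loss_cfg)

cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
label = torch.Tensor([0, 1]).long()

# Per-sample CE losses are ~2000 and ~200. The class weights scale them to
# 0.3 * 2000 = 600 and 0.7 * 200 = 140, and the mean over the 2 samples is
# (600 + 140) / 2 = 370, matching the new test above.
assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
```

Note the reduction convention: because the per-sample losses are computed with `reduction='none'` and then averaged by `weight_reduce_loss`, the mean divides by the number of samples, not by the sum of the class weights as `F.cross_entropy(..., weight=..., reduction='mean')` would. All the new expected values (370, 208, 176.667, 74.333) follow that convention.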