[Feature] Support classwise weight in losses (#388)
* Add classwise weight in losses: CE, BCE, soft CE
* Update unit tests
* Remove some extra code
* Fix broadcast
* Use new_tensor
* Fix lint
parent 6a0a76af0c
commit 192b79eea0
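For context before the diff: the new `class_weight` option rides through the normal loss config. A minimal, hypothetical usage sketch (the keys mirror the unit tests in this commit; the two-class weights are made up):

```python
# Hypothetical config snippet: enable per-class weighting when building
# the loss. `class_weight` is a plain list with one entry per class.
loss_cfg = dict(
    type='CrossEntropyLoss',
    use_sigmoid=False,      # standard CE; True switches to BCE-with-logits
    reduction='mean',
    loss_weight=1.0,
    class_weight=[0.3, 0.7])
```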
@@ -6,7 +6,12 @@ from ..builder import LOSSES
 from .utils import weight_reduce_loss
 
 
-def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
+def cross_entropy(pred,
+                  label,
+                  weight=None,
+                  reduction='mean',
+                  avg_factor=None,
+                  class_weight=None):
     """Calculate the CrossEntropy loss.
 
     Args:
@@ -17,12 +22,14 @@ def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
         reduction (str): The method used to reduce the loss.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Default None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     # element-wise losses
-    loss = F.cross_entropy(pred, label, reduction='none')
+    loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')
 
     # apply weights and do the reduction
     if weight is not None:
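Worth noting about this hunk: with `reduction='none'`, `F.cross_entropy` returns per-sample losses already scaled by the weight of each sample's target class, and the later `weight_reduce_loss` mean divides by N rather than by the sum of class weights (which is what PyTorch's own `reduction='mean'` would do with a `weight` argument). A minimal sketch with the values used in the unit tests below:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[-1000., 1000.], [100., -100.]])
label = torch.tensor([0, 1])
class_weight = torch.tensor([0.3, 0.7])

# Per-sample losses, each scaled by its target class's weight:
# [2000 * 0.3, 200 * 0.7] = [600., 140.]
loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')

print(loss.mean())  # tensor(370.) -- mean over N=2 samples, as in this PR
# PyTorch's built-in weighted mean would instead compute
# loss.sum() / class_weight[label].sum() = 740.
```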
@@ -37,6 +44,7 @@ def soft_cross_entropy(pred,
                        label,
                        weight=None,
                        reduction='mean',
+                       class_weight=None,
                        avg_factor=None):
     """Calculate the Soft CrossEntropy loss. The label can be float.
 
@@ -49,12 +57,16 @@ def soft_cross_entropy(pred,
         reduction (str): The method used to reduce the loss.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Default None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     # element-wise losses
     loss = -label * F.log_softmax(pred, dim=-1)
+    if class_weight is not None:
+        loss *= class_weight
     loss = loss.sum(dim=-1)
 
     # apply weights and do the reduction
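The soft-label path applies the same weights by broadcasting the (C,) tensor over the class dimension before the sum. A sketch, reusing the test values:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[-1000., 1000.], [100., -100.]])
soft_label = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
class_weight = torch.tensor([0.3, 0.7])

loss = -soft_label * F.log_softmax(pred, dim=-1)  # (N, C) element-wise loss
loss = loss * class_weight                        # (C,) broadcasts over rows
loss = loss.sum(dim=-1)                           # per-sample loss: [600., 140.]
print(loss.mean())  # tensor(370.) -- matches hard-label CE for one-hot labels
```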
@@ -70,7 +82,8 @@ def binary_cross_entropy(pred,
                          label,
                          weight=None,
                          reduction='mean',
-                         avg_factor=None):
+                         avg_factor=None,
+                         class_weight=None):
     r"""Calculate the binary CrossEntropy loss with logits.
 
     Args:
@@ -83,13 +96,20 @@ def binary_cross_entropy(pred,
             is same shape as pred and label. Defaults to 'mean'.
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
+        class_weight (torch.Tensor, optional): The weight for each class with
+            shape (C), C is the number of classes. Default None.
 
     Returns:
         torch.Tensor: The calculated loss
     """
     assert pred.dim() == label.dim()
-    loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+    # Ensure that the size of class_weight is consistent with pred and label
+    # to avoid automatic broadcasting.
+    if class_weight is not None:
+        N = pred.size()[0]
+        class_weight = class_weight.repeat(N, 1)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label, weight=class_weight, reduction='none')
 
     # apply weights and do the reduction
     if weight is not None:
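In `F.binary_cross_entropy_with_logits`, the `weight` argument is multiplied element-wise into the (N, C) loss map (it is distinct from `pos_weight`), so the (C,) class weight is tiled to (N, C) first; the `repeat` mainly makes the shape match explicit rather than relying on broadcasting. A sketch with the test values:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[-200., 100.], [500., -1000.], [300., -300.]])
label = torch.tensor([[1., 0.], [0., 1.], [1., 0.]])
class_weight = torch.tensor([0.1, 0.9]).repeat(pred.size(0), 1)  # (3, 2)

loss = F.binary_cross_entropy_with_logits(
    pred, label, weight=class_weight, reduction='none')
# Element losses ~[[200, 100], [500, 1000], [0, 0]] scaled column-wise by
# [0.1, 0.9] -> sum 1060, mean over 6 elements:
print(loss.mean())  # tensor(176.6667), the value asserted in the test below
```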
@@ -114,13 +134,16 @@ class CrossEntropyLoss(nn.Module):
         reduction (str): The method used to reduce the loss.
             Options are "none", "mean" and "sum". Defaults to 'mean'.
         loss_weight (float): Weight of the loss. Defaults to 1.0.
+        class_weight (List[float], optional): The weight for each class with
+            shape (C), C is the number of classes. Default None.
     """
 
     def __init__(self,
                  use_sigmoid=False,
                  use_soft=False,
                  reduction='mean',
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 class_weight=None):
         super(CrossEntropyLoss, self).__init__()
         self.use_sigmoid = use_sigmoid
         self.use_soft = use_soft
@@ -130,6 +153,7 @@ class CrossEntropyLoss(nn.Module):
 
         self.reduction = reduction
         self.loss_weight = loss_weight
+        self.class_weight = class_weight
 
         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -148,10 +172,17 @@ class CrossEntropyLoss(nn.Module):
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
+
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
             weight,
+            class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
             **kwargs)
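`new_tensor` is what lets `class_weight` stay a plain list in configs: it creates the tensor on the same device and with the same dtype as the logits on every forward pass, so no manual `.to(device)` is needed. A quick sketch:

```python
import torch

cls_score = torch.randn(4, 2)                    # stand-in logits
class_weight = cls_score.new_tensor([0.3, 0.7])  # inherits device and dtype
assert class_weight.device == cls_score.device
assert class_weight.dtype == cls_score.dtype
```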
@@ -45,48 +45,92 @@ def test_cross_entropy_loss():
     loss = build_loss(loss_cfg)
 
     # test ce_loss
-    cls_score = torch.Tensor([[100, -100]])
-    label = torch.Tensor([1]).long()
-    weight = torch.tensor(0.5)
+    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
+    label = torch.Tensor([0, 1]).long()
+    class_weight = [0.3, 0.7]  # class 0: 0.3, class 1: 0.7
+    weight = torch.tensor([0.6, 0.4])
+
+    # test ce_loss without class weight
     loss_cfg = dict(type='CrossEntropyLoss', reduction='mean', loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(200.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
     # test ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(100.))
+        loss(cls_score, label, weight=weight), torch.tensor(640.))
+
+    # test ce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
+    # test ce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(208.))
 
     # test bce_loss
-    cls_score = torch.Tensor([[100, -100], [100, -100]])
-    label = torch.Tensor([[1, 0], [0, 1]])
-    weight = torch.Tensor([0.5, 0.5])
+    cls_score = torch.Tensor([[-200, 100], [500, -1000], [300, -300]])
+    label = torch.Tensor([[1, 0], [0, 1], [1, 0]])
+    weight = torch.Tensor([0.6, 0.4, 0.5])
+    class_weight = [0.1, 0.9]  # class 0: 0.1, class 1: 0.9
+
+    # test bce_loss without class weight
     loss_cfg = dict(
         type='CrossEntropyLoss',
         use_sigmoid=True,
         reduction='mean',
         loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(50.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(300.))
     # test bce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(25.))
+        loss(cls_score, label, weight=weight), torch.tensor(130.))
+
+    # test bce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(176.667))
+    # test bce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(74.333))
 
     # test soft_ce_loss
-    cls_score = torch.Tensor([[100, -100]])
-    label = torch.Tensor([[1, 0], [0, 1]])
-    weight = torch.tensor(0.5)
+    cls_score = torch.Tensor([[-1000, 1000], [100, -100]])
+    label = torch.Tensor([[1.0, 0.0], [0.0, 1.0]])
+    class_weight = [0.3, 0.7]  # class 0: 0.3, class 1: 0.7
+    weight = torch.tensor([0.6, 0.4])
+
+    # test soft_ce_loss without class weight
     loss_cfg = dict(
         type='CrossEntropyLoss',
         use_soft=True,
         reduction='mean',
         loss_weight=1.0)
     loss = build_loss(loss_cfg)
-    assert torch.allclose(loss(cls_score, label), torch.tensor(100.))
+    assert torch.allclose(loss(cls_score, label), torch.tensor(1100.))
     # test soft_ce_loss with weight
     assert torch.allclose(
-        loss(cls_score, label, weight=weight), torch.tensor(50.))
+        loss(cls_score, label, weight=weight), torch.tensor(640.))
+
+    # test soft_ce_loss with class weight
+    loss_cfg = dict(
+        type='CrossEntropyLoss',
+        use_soft=True,
+        reduction='mean',
+        loss_weight=1.0,
+        class_weight=class_weight)
+    loss = build_loss(loss_cfg)
+    assert torch.allclose(loss(cls_score, label), torch.tensor(370.))
+    # test soft_ce_loss with weight
+    assert torch.allclose(
+        loss(cls_score, label, weight=weight), torch.tensor(208.))
 
 
 def test_focal_loss():
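For readers checking the asserted constants: with logits this saturated, the per-sample CE losses are effectively exact, so the expected values reduce to simple arithmetic. Shown here for the class-weighted CE case; the BCE and soft-CE numbers follow the same pattern:

```python
import torch

per_sample = torch.tensor([2000., 200.])  # raw CE loss of each sample
class_w = torch.tensor([0.3, 0.7])        # weight of each sample's target class
sample_w = torch.tensor([0.6, 0.4])

weighted = per_sample * class_w           # [600., 140.]
print(weighted.mean())                    # 370., as asserted above
print((weighted * sample_w).sum() / 2)    # 208., as asserted above
```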