diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py
index 27da861f9..b94ece3a2 100644
--- a/mmseg/models/losses/dice_loss.py
+++ b/mmseg/models/losses/dice_loss.py
@@ -15,7 +15,7 @@ def dice_loss(pred,
               smooth=1,
               exponent=2,
               class_weight=None,
-              ignore_index=-1):
+              ignore_index=255):
     assert pred.shape[0] == target.shape[0]
     total_loss = 0
     num_classes = pred.shape[1]
@@ -36,9 +36,9 @@ def dice_loss(pred,
 @weighted_loss
 def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
     assert pred.shape[0] == target.shape[0]
-    pred = pred.contiguous().view(pred.shape[0], -1)
-    target = target.contiguous().view(target.shape[0], -1)
-    valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)
+    pred = pred.reshape(pred.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
 
     num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
     den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
@@ -70,19 +70,14 @@ class DiceLoss(nn.Module):
     """
 
     def __init__(self,
-                 loss_type='multi_class',
                  smooth=1,
                  exponent=2,
                  reduction='mean',
                  class_weight=None,
                  loss_weight=1.0,
-                 ignore_index=255):
+                 ignore_index=255,
+                 **kwards):
         super(DiceLoss, self).__init__()
-        assert loss_type in ['multi_class', 'binary']
-        if loss_type == 'multi_class':
-            self.cls_criterion = dice_loss
-        else:
-            self.cls_criterion = binary_dice_loss
         self.smooth = smooth
         self.exponent = exponent
         self.reduction = reduction
@@ -90,7 +85,12 @@ class DiceLoss(nn.Module):
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index
 
-    def forward(self, pred, target, avg_factor=None, reduction_override=None):
+    def forward(self,
+                pred,
+                target,
+                avg_factor=None,
+                reduction_override=None,
+                **kwards):
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
@@ -100,10 +100,13 @@ class DiceLoss(nn.Module):
             class_weight = None
 
         pred = F.softmax(pred, dim=1)
-        one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
+        num_classes = pred.shape[1]
+        one_hot_target = F.one_hot(
+            torch.clamp(target.long(), 0, num_classes - 1),
+            num_classes=num_classes)
         valid_mask = (target != self.ignore_index).long()
 
-        loss = self.loss_weight * self.cls_criterion(
+        loss = self.loss_weight * dice_loss(
             pred,
             one_hot_target,
             valid_mask=valid_mask,
diff --git a/tests/test_models/test_losses.py b/tests/test_models/test_losses.py
index 481a8e92c..c58e6a505 100644
--- a/tests/test_models/test_losses.py
+++ b/tests/test_models/test_losses.py
@@ -207,19 +207,9 @@ def test_lovasz_loss():
 def test_dice_lose():
     from mmseg.models import build_loss
 
-    # loss_type should be 'binary' or 'multi_class'
-    with pytest.raises(AssertionError):
-        loss_cfg = dict(
-            type='DiceLoss',
-            loss_type='Binary',
-            reduction='none',
-            loss_weight=1.0)
-        build_loss(loss_cfg)
-
     # test dice loss with loss_type = 'multi_class'
     loss_cfg = dict(
         type='DiceLoss',
-        loss_type='multi_class',
         reduction='none',
         class_weight=[1.0, 2.0, 3.0],
         loss_weight=1.0,
@@ -232,13 +222,12 @@ def test_dice_lose():
     # test dice loss with loss_type = 'binary'
     loss_cfg = dict(
         type='DiceLoss',
-        loss_type='binary',
         smooth=2,
         exponent=3,
         reduction='sum',
         loss_weight=1.0,
         ignore_index=0)
     dice_loss = build_loss(loss_cfg)
-    logits = torch.rand(16, 4, 4)
-    labels = (torch.rand(16, 4, 4)).long()
+    logits = torch.rand(8, 2, 4, 4)
+    labels = (torch.rand(8, 4, 4) * 2).long()
     dice_loss(logits, labels)
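For reference, a minimal usage sketch of the refactored loss, mirroring the updated test: the `loss_type` option is gone, and the binary case is simply the two-channel path through the same `DiceLoss`. The config keys come from `tests/test_models/test_losses.py`; the three-class tensor shapes in the first example are illustrative, not copied from the test.

```python
import torch

from mmseg.models import build_loss

# Multi-class: no loss_type key anymore; unrecognized keys are absorbed by **kwards.
loss_cfg = dict(
    type='DiceLoss',
    reduction='none',
    class_weight=[1.0, 2.0, 3.0],
    loss_weight=1.0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 3, 4, 4)            # (N, num_classes, H, W), illustrative shape
labels = (torch.rand(8, 4, 4) * 3).long()  # labels in [0, num_classes)
dice_loss(logits, labels)

# "Binary" is just the two-class case of the same loss; labels equal to
# ignore_index are excluded through the valid_mask.
loss_cfg = dict(
    type='DiceLoss',
    smooth=2,
    exponent=3,
    reduction='sum',
    loss_weight=1.0,
    ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels)
```

Because both `__init__` and `forward` now accept `**kwards`, older configs that still pass `loss_type` keep building without the removed assertion; the key is silently ignored.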