dice loss (#396)
* dice loss * format code, add docstring and calculate denominator without valid_mask * minor change * restorepull/425/head
parent
d0a71c1509
commit
7e1b24dd32
|
@ -1,11 +1,12 @@
|
|||
from .accuracy import Accuracy, accuracy
|
||||
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
||||
cross_entropy, mask_cross_entropy)
|
||||
from .dice_loss import DiceLoss
|
||||
from .lovasz_loss import LovaszLoss
|
||||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
||||
|
||||
__all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
||||
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
|
||||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
|
||||
segmentron/solver/loss.py (Apache-2.0 License)"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from .utils import weighted_loss
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def dice_loss(pred,
|
||||
target,
|
||||
valid_mask,
|
||||
smooth=1,
|
||||
exponent=2,
|
||||
class_weight=None,
|
||||
ignore_index=-1):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
total_loss = 0
|
||||
num_classes = pred.shape[1]
|
||||
for i in range(num_classes):
|
||||
if i != ignore_index:
|
||||
dice_loss = binary_dice_loss(
|
||||
pred[:, i],
|
||||
target[..., i],
|
||||
valid_mask=valid_mask,
|
||||
smooth=smooth,
|
||||
exponent=exponent)
|
||||
if class_weight is not None:
|
||||
dice_loss *= class_weight[i]
|
||||
total_loss += dice_loss
|
||||
return total_loss / num_classes
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
pred = pred.contiguous().view(pred.shape[0], -1)
|
||||
target = target.contiguous().view(target.shape[0], -1)
|
||||
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)
|
||||
|
||||
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
|
||||
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
|
||||
|
||||
return 1 - num / den
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
class DiceLoss(nn.Module):
|
||||
"""DiceLoss.
|
||||
|
||||
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
|
||||
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
|
||||
|
||||
Args:
|
||||
loss_type (str, optional): Binary or multi-class loss.
|
||||
Default: 'multi_class'. Options are "binary" and "multi_class".
|
||||
smooth (float): A float number to smooth loss, and avoid NaN error.
|
||||
Default: 1
|
||||
exponent (float): An float number to calculate denominator
|
||||
value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
|
||||
reduction (str, optional): The method used to reduce the loss. Options
|
||||
are "none", "mean" and "sum". This parameter only works when
|
||||
per_image is True. Default: 'mean'.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
loss_weight (float, optional): Weight of the loss. Default to 1.0.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss_type='multi_class',
|
||||
smooth=1,
|
||||
exponent=2,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
ignore_index=255):
|
||||
super(DiceLoss, self).__init__()
|
||||
assert loss_type in ['multi_class', 'binary']
|
||||
if loss_type == 'multi_class':
|
||||
self.cls_criterion = dice_loss
|
||||
else:
|
||||
self.cls_criterion = binary_dice_loss
|
||||
self.smooth = smooth
|
||||
self.exponent = exponent
|
||||
self.reduction = reduction
|
||||
self.class_weight = class_weight
|
||||
self.loss_weight = loss_weight
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self, pred, target, avg_factor=None, reduction_override=None):
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.class_weight is not None:
|
||||
class_weight = pred.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pred = F.softmax(pred, dim=1)
|
||||
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
|
||||
valid_mask = (target != self.ignore_index).long()
|
||||
|
||||
loss = self.loss_weight * self.cls_criterion(
|
||||
pred,
|
||||
one_hot_target,
|
||||
valid_mask=valid_mask,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
smooth=self.smooth,
|
||||
exponent=self.exponent,
|
||||
class_weight=class_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
return loss
|
|
@ -202,3 +202,43 @@ def test_lovasz_loss():
|
|||
logits = torch.rand(2, 4, 4)
|
||||
labels = (torch.rand(2, 4, 4)).long()
|
||||
lovasz_loss(logits, labels, ignore_index=None)
|
||||
|
||||
|
||||
def test_dice_lose():
|
||||
from mmseg.models import build_loss
|
||||
|
||||
# loss_type should be 'binary' or 'multi_class'
|
||||
with pytest.raises(AssertionError):
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
loss_type='Binary',
|
||||
reduction='none',
|
||||
loss_weight=1.0)
|
||||
build_loss(loss_cfg)
|
||||
|
||||
# test dice loss with loss_type = 'multi_class'
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
loss_type='multi_class',
|
||||
reduction='none',
|
||||
class_weight=[1.0, 2.0, 3.0],
|
||||
loss_weight=1.0,
|
||||
ignore_index=1)
|
||||
dice_loss = build_loss(loss_cfg)
|
||||
logits = torch.rand(8, 3, 4, 4)
|
||||
labels = (torch.rand(8, 4, 4) * 3).long()
|
||||
dice_loss(logits, labels)
|
||||
|
||||
# test dice loss with loss_type = 'binary'
|
||||
loss_cfg = dict(
|
||||
type='DiceLoss',
|
||||
loss_type='binary',
|
||||
smooth=2,
|
||||
exponent=3,
|
||||
reduction='sum',
|
||||
loss_weight=1.0,
|
||||
ignore_index=0)
|
||||
dice_loss = build_loss(loss_cfg)
|
||||
logits = torch.rand(16, 4, 4)
|
||||
labels = (torch.rand(16, 4, 4)).long()
|
||||
dice_loss(logits, labels)
|
||||
|
|
Loading…
Reference in New Issue