Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Bug fixed] Fix dice_loss errors (#417)

* fix training bugs
* fix unit test error
* fix error in the num_classes == 2 case
* delete comments
parent d474cfde4b
commit 71be1c2793
@@ -15,7 +15,7 @@ def dice_loss(pred,
               smooth=1,
               exponent=2,
               class_weight=None,
-              ignore_index=-1):
+              ignore_index=255):
     assert pred.shape[0] == target.shape[0]
     total_loss = 0
     num_classes = pred.shape[1]
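Context for the default change: 255 is the label value mmsegmentation datasets conventionally use for pixels that should not contribute to the loss. A minimal sketch on hypothetical toy labels of how such a label ends up masked out, mirroring the valid_mask logic that appears later in DiceLoss.forward:

import torch

ignore_index = 255
target = torch.tensor([0, 1, 255, 1])         # toy labels; 255 marks "ignore"
valid_mask = (target != ignore_index).long()
print(valid_mask)                             # tensor([1, 1, 0, 1])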
@@ -36,9 +36,9 @@ def dice_loss(pred,
 @weighted_loss
 def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
     assert pred.shape[0] == target.shape[0]
-    pred = pred.contiguous().view(pred.shape[0], -1)
-    target = target.contiguous().view(target.shape[0], -1)
-    valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)
+    pred = pred.reshape(pred.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

     num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
     den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
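The switch from .contiguous().view() to .reshape() does not change the arithmetic; .reshape() additionally accepts non-contiguous inputs (copying when necessary), where .view() would raise. A minimal standalone sketch of the smoothed dice terms above, using hypothetical toy tensors:

import torch

# Hypothetical toy example: N=2 samples flattened to 4 pixels each.
pred = torch.tensor([[0.9, 0.1, 0.8, 0.2],
                     [0.3, 0.7, 0.4, 0.6]])
target = torch.tensor([[1., 0., 1., 0.],
                       [0., 1., 0., 1.]])
valid_mask = torch.ones_like(target)
smooth, exponent = 1, 2

# Same numerator/denominator as binary_dice_loss above.
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
print(1 - num / den)    # per-sample dice loss, shape (2,)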
@@ -70,19 +70,14 @@ class DiceLoss(nn.Module):
     """

     def __init__(self,
-                 loss_type='multi_class',
                  smooth=1,
                  exponent=2,
                  reduction='mean',
                  class_weight=None,
                  loss_weight=1.0,
-                 ignore_index=255):
+                 ignore_index=255,
+                 **kwards):
         super(DiceLoss, self).__init__()
-        assert loss_type in ['multi_class', 'binary']
-        if loss_type == 'multi_class':
-            self.cls_criterion = dice_loss
-        else:
-            self.cls_criterion = binary_dice_loss
         self.smooth = smooth
         self.exponent = exponent
         self.reduction = reduction
@@ -90,7 +85,12 @@ class DiceLoss(nn.Module):
         self.loss_weight = loss_weight
         self.ignore_index = ignore_index

-    def forward(self, pred, target, avg_factor=None, reduction_override=None):
+    def forward(self,
+                pred,
+                target,
+                avg_factor=None,
+                reduction_override=None,
+                **kwards):
         assert reduction_override in (None, 'none', 'mean', 'sum')
         reduction = (
             reduction_override if reduction_override else self.reduction)
@@ -100,10 +100,13 @@ class DiceLoss(nn.Module):
             class_weight = None

         pred = F.softmax(pred, dim=1)
-        one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
+        num_classes = pred.shape[1]
+        one_hot_target = F.one_hot(
+            torch.clamp(target.long(), 0, num_classes - 1),
+            num_classes=num_classes)
         valid_mask = (target != self.ignore_index).long()

-        loss = self.loss_weight * self.cls_criterion(
+        loss = self.loss_weight * dice_loss(
             pred,
             one_hot_target,
             valid_mask=valid_mask,
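Why the explicit num_classes matters (the "num_classes == 2" bug from the commit message): without it, F.one_hot infers the width from the largest label in the batch, so a two-class batch containing only class 0, or one containing the 255 ignore label, produces a one-hot target whose channel count does not match pred. A hedged before/after sketch on hypothetical labels:

import torch
import torch.nn.functional as F

num_classes = 2
target = torch.tensor([0, 0, 255])   # only class 0 present, plus an ignored pixel

# Old behaviour: width inferred from max(target) -> 256 channels here
# (and only 1 channel if the 255 were absent), mismatching pred.shape[1].
old = F.one_hot(torch.clamp_min(target.long(), 0))
print(old.shape)    # torch.Size([3, 256])

# New behaviour: clamp labels into [0, num_classes - 1] and fix the width,
# so the one-hot target always has num_classes channels; the clamped
# ignore pixel is zeroed out later by valid_mask.
new = F.one_hot(
    torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes)
print(new.shape)    # torch.Size([3, 2])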
@@ -207,19 +207,9 @@ def test_lovasz_loss():
 def test_dice_lose():
     from mmseg.models import build_loss

-    # loss_type should be 'binary' or 'multi_class'
-    with pytest.raises(AssertionError):
-        loss_cfg = dict(
-            type='DiceLoss',
-            loss_type='Binary',
-            reduction='none',
-            loss_weight=1.0)
-        build_loss(loss_cfg)
-
     # test dice loss with loss_type = 'multi_class'
     loss_cfg = dict(
         type='DiceLoss',
-        loss_type='multi_class',
         reduction='none',
         class_weight=[1.0, 2.0, 3.0],
         loss_weight=1.0,
@@ -232,13 +222,12 @@ def test_dice_lose():
     # test dice loss with loss_type = 'binary'
     loss_cfg = dict(
         type='DiceLoss',
-        loss_type='binary',
         smooth=2,
         exponent=3,
         reduction='sum',
         loss_weight=1.0,
         ignore_index=0)
     dice_loss = build_loss(loss_cfg)
-    logits = torch.rand(16, 4, 4)
-    labels = (torch.rand(16, 4, 4)).long()
+    logits = torch.rand(8, 2, 4, 4)
+    labels = (torch.rand(8, 4, 4) * 2).long()
     dice_loss(logits, labels)
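For reference, a usage sketch in the same shape as the updated test (it assumes an environment where mmsegmentation and its registry are importable): logits now carry an explicit two-channel class dimension and labels take values in {0, 1}.

import torch
from mmseg.models import build_loss

loss_cfg = dict(
    type='DiceLoss',
    smooth=2,
    exponent=3,
    reduction='sum',
    loss_weight=1.0,
    ignore_index=0)
dice = build_loss(loss_cfg)

logits = torch.rand(8, 2, 4, 4)            # (N, C=2, H, W) class scores
labels = (torch.rand(8, 4, 4) * 2).long()  # per-pixel labels in {0, 1}
print(dice(logits, labels))                # scalar loss (reduction='sum')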