fix bug of dice loss: loss always 1 if empty target (#408)

pull/395/head^2
Jiaqi Duan 2021-08-04 14:05:31 +08:00 committed by GitHub
parent 3707d67106
commit f24be6c614
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletions

View File

@ -261,7 +261,7 @@ class PANLoss(nn.Module):
pred = pred * mask
target = target * mask
a = torch.sum(pred * target, 1)
a = torch.sum(pred * target, 1) + smooth
b = torch.sum(pred * pred, 1) + smooth
c = torch.sum(target * target, 1) + smooth
d = (2 * a) / (b + c)

View File

@ -143,3 +143,16 @@ def test_drrgloss():
target_maps, target_maps, target_maps, target_maps)
assert isinstance(loss_dict, dict)
def test_dice_loss():
pred = torch.Tensor([[[-1000, -1000, -1000], [-1000, -1000, -1000],
[-1000, -1000, -1000]]])
target = torch.Tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]]])
mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
pan_loss = losses.PANLoss()
dice_loss = pan_loss.dice_loss_with_logits(pred, target, mask)
assert np.allclose(dice_loss.item(), 0)