mirror of https://github.com/open-mmlab/mmocr.git
fix bug of dice loss: loss always 1 if empty target (#408)
parent
3707d67106
commit
f24be6c614
|
@ -261,7 +261,7 @@ class PANLoss(nn.Module):
|
|||
pred = pred * mask
|
||||
target = target * mask
|
||||
|
||||
a = torch.sum(pred * target, 1)
|
||||
a = torch.sum(pred * target, 1) + smooth
|
||||
b = torch.sum(pred * pred, 1) + smooth
|
||||
c = torch.sum(target * target, 1) + smooth
|
||||
d = (2 * a) / (b + c)
|
||||
|
|
|
@ -143,3 +143,16 @@ def test_drrgloss():
|
|||
target_maps, target_maps, target_maps, target_maps)
|
||||
|
||||
assert isinstance(loss_dict, dict)
|
||||
|
||||
|
||||
def test_dice_loss():
|
||||
pred = torch.Tensor([[[-1000, -1000, -1000], [-1000, -1000, -1000],
|
||||
[-1000, -1000, -1000]]])
|
||||
target = torch.Tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]]])
|
||||
mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
|
||||
|
||||
pan_loss = losses.PANLoss()
|
||||
|
||||
dice_loss = pan_loss.dice_loss_with_logits(pred, target, mask)
|
||||
|
||||
assert np.allclose(dice_loss.item(), 0)
|
||||
|
|
Loading…
Reference in New Issue