diff --git a/mmocr/models/textdet/losses/pan_loss.py b/mmocr/models/textdet/losses/pan_loss.py index 28ca6f2c..c0751f48 100644 --- a/mmocr/models/textdet/losses/pan_loss.py +++ b/mmocr/models/textdet/losses/pan_loss.py @@ -261,7 +261,7 @@ class PANLoss(nn.Module): pred = pred * mask target = target * mask - a = torch.sum(pred * target, 1) + a = torch.sum(pred * target, 1) + smooth b = torch.sum(pred * pred, 1) + smooth c = torch.sum(target * target, 1) + smooth d = (2 * a) / (b + c) diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py index da99c00e..5235b673 100644 --- a/tests/test_models/test_loss.py +++ b/tests/test_models/test_loss.py @@ -143,3 +143,16 @@ def test_drrgloss(): target_maps, target_maps, target_maps, target_maps) assert isinstance(loss_dict, dict) + + +def test_dice_loss(): + pred = torch.Tensor([[[-1000, -1000, -1000], [-1000, -1000, -1000], + [-1000, -1000, -1000]]]) + target = torch.Tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]]]) + mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) + + pan_loss = losses.PANLoss() + + dice_loss = pan_loss.dice_loss_with_logits(pred, target, mask) + + assert np.allclose(dice_loss.item(), 0)