mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Added ignore_index param to forward(), also implemented one hot encoding to ensure the dims of target matches pred. Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Please describe the motivation of this PR and the goal you want to achieve through this PR. Attempted to solve the problems mentioned by #3172 ## Modification Please briefly describe what modification is made in this PR. Added ignore_index into forward function (although the dice loss itself does not actually take it into account for some reason). Added _expand_onehot_labels_dice, which takes the target with shape [N, H, W] into [N, num_classes, H, W]. ## BC-breaking (Optional) Does the modification introduce changes that break the backward-compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here, and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstring or example tutorials. This is my first time contributing to open-source code, so I might have made some stupid mistakes. Please don't hesitate to point them out.
97 lines
3.7 KiB
Python
97 lines
3.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.losses import DiceLoss
|
|
|
|
|
|
@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
    """Exercise DiceLoss forward paths.

    Covers: plain forward, per-sample weighting, reduction_override,
    avg_factor handling (valid and invalid combinations), activation
    settings, weight-shape validation, one-hot target expansion, the
    zero-loss case when pred matches target, and ignore_index behavior.
    """
    loss_class = DiceLoss
    pred = torch.rand((1, 10, 4, 4))
    target = torch.randint(0, 10, (1, 4, 4))
    weight = torch.rand(1)
    # Test loss forward
    loss = loss_class(naive_dice=naive_dice)(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class(naive_dice=naive_dice)(
        pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        # Fix: assign the result so the assert checks THIS call instead of
        # the stale `loss` left over from the previous section.
        loss = loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10,
            reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)

    # Test loss forward with activate=True for both use_sigmoid settings
    # (comment previously claimed use_sigmoid=False only, which was wrong).
    for use_sigmoid in [True, False]:
        # Fix: assign the result so the assert checks each iteration.
        loss = loss_class(
            use_sigmoid=use_sigmoid, activate=True,
            naive_dice=naive_dice)(pred, target)
        assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight.ndim != loss.ndim
    with pytest.raises(AssertionError):
        weight = torch.rand((2, 8))
        loss_class(naive_dice=naive_dice)(pred, target, weight)

    # Test loss forward with len(weight) != len(pred)
    with pytest.raises(AssertionError):
        weight = torch.rand(8)
        loss_class(naive_dice=naive_dice)(pred, target, weight)

    # Test _expand_onehot_labels_dice: an index-encoded target must produce
    # the same loss as its explicit one-hot encoding.
    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
    target = torch.tensor([[[0, 0], [0, 1]]])
    target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
    weight = torch.rand(1)
    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
    loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
                                                    weight)
    assert torch.equal(loss, loss_onehot)

    # Test Whether Loss is 0 when pred == target, eps == 0 and naive_dice=False
    # use_sigmoid=True makes the loss apply sigmoid to pred internally, so a
    # zero loss requires target == pred.sigmoid().
    target = torch.randint(0, 2, (1, 10, 4, 4))
    pred = target.float()
    # Fix: `target.sigmoid()` raised RuntimeError because `target` is an
    # int64 tensor (sigmoid is not implemented for integer dtypes); take
    # the sigmoid of the float `pred` instead, which is the intended value.
    target = pred.sigmoid()
    weight = torch.rand(1)
    loss = loss_class(
        naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
    assert loss.item() == 0

    # Test ignore_index when ignore_index is the only class
    with pytest.raises(AssertionError):
        pred = torch.ones((1, 1, 4, 4))
        target = torch.randint(0, 1, (1, 4, 4))
        weight = torch.rand(1)
        loss = loss_class(
            naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
            eps=0)(pred, target, weight)

    # Test ignore_index with naive_dice = False; only channel 0 contributes
    # (channel 1 is ignored) and it matches pred.sigmoid(), so loss is 0.
    pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
    # Fix: cast to float before sigmoid — torch.tensor of int literals is
    # int64 and integer sigmoid raises RuntimeError.
    target = torch.tensor([[[[1, 1], [1, 0]],
                            [[1, 0], [0, 1]]]]).float().sigmoid()
    weight = torch.rand(1)
    loss = loss_class(
        naive_dice=False, use_sigmoid=True, ignore_index=1,
        eps=0)(pred, target, weight)
    assert loss.item() == 0
|