[Fix] Support single cahnnel `pred` for Binary Cross Entropy Loss (#1454)

* [Fix] Fix the bug that binary cross entropy loss doesn't support single channel input

* imcrease coverage

* modify implementation

* increase coverage

* add assert

* modify implementation

* enshollow condition judge

* fix
pull/1477/head
Rockey 2022-04-14 11:26:02 +08:00 committed by GitHub
parent 23ae1ebab6
commit cd18b6d479
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 2 deletions

View File

@ -81,7 +81,7 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
@ -115,6 +115,13 @@ def binary_cross_entropy(pred,
Returns:
torch.Tensor: The calculated loss
"""
if pred.size(1) == 1:
# For binary class segmentation, the shape of pred is
# [N, 1, H, W] and that of label is [N, H, W].
assert label.max() <= 1, \
'For pred with shape [N, 1, H, W], its label must have at ' \
'most 2 classes'
pred = pred.squeeze()
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
@ -128,7 +135,7 @@ def binary_cross_entropy(pred,
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements

View File

@ -85,6 +85,35 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
ignore_index=255) / fake_label.numel()
assert torch.allclose(loss, torch_loss)
if use_sigmoid:
# test loss with complicated case for ce/bce
# when avg_non_ignore is False, `avg_factor` would not be calculated
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
fake_label[:, 0, 0] = 255
fake_weight = torch.rand(2, 8, 8)
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred, fake_label, weight=fake_weight, ignore_index=255)
if use_sigmoid:
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=255)
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='none',
weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
if avg_non_ignore:
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight).sum() / avg_factor
else:
torch_loss = (torch_loss * weight).mean()
assert torch.allclose(loss, torch_loss)
# test loss with class weights from file
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
@ -223,3 +252,43 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
reduction='sum',
weight=class_weight) / fake_label.numel()
assert torch.allclose(loss, torch_loss)
@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('with_weight', [True, False])
def test_binary_class_ce_loss(avg_non_ignore, with_weight):
from mmseg.models import build_loss
fake_pred = torch.rand(3, 1, 10, 10)
fake_label = torch.randint(0, 2, (3, 10, 10))
fake_weight = torch.rand(3, 10, 10)
valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
weight = valid_mask
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.unsqueeze(1).float(),
reduction='none',
weight=fake_weight.unsqueeze(1).float() if with_weight else None)
if avg_non_ignore:
eps = torch.finfo(torch.float32).eps
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
avg_factor + eps)
else:
torch_loss = (torch_loss * weight.unsqueeze(1)).mean()
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0,
avg_non_ignore=avg_non_ignore,
reduction='mean',
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred,
fake_label,
weight=fake_weight if with_weight else None,
ignore_index=255)
assert torch.allclose(loss, torch_loss)