[Fix] Support single channel `pred` for Binary Cross Entropy Loss (#1454)
* [Fix] Fix the bug that binary cross entropy loss doesn't support single channel input
* increase coverage
* modify implementation
* increase coverage
* add assert
* modify implementation
* simplify the nested condition check
* fix
parent 23ae1ebab6
commit cd18b6d479
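After this change, `CrossEntropyLoss` with `use_sigmoid=True` accepts a single-channel `pred` of shape [N, 1, H, W] together with a label map of shape [N, H, W]. A minimal usage sketch (shapes and random inputs are illustrative, not from the commit):

import torch
from mmseg.models import build_loss

# Single-channel logits [N, 1, H, W] and a binary label map [N, H, W];
# before this fix the shape mismatch tripped the dim assertion inside
# binary_cross_entropy.
fake_pred = torch.rand(2, 1, 16, 16)
fake_label = torch.randint(0, 2, (2, 16, 16))

loss_cls = build_loss(
    dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
loss = loss_cls(fake_pred, fake_label, ignore_index=255)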
@@ -81,7 +81,7 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
-        bin_label_weights *= valid_mask
+        bin_label_weights = bin_label_weights * valid_mask
 
     return bin_labels, bin_label_weights, valid_mask
 
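Why the in-place `*=` had to go: in the `else` branch, `bin_label_weights` comes from `expand`, which returns a broadcast view whose elements share memory, and PyTorch rejects in-place writes to such views. A standalone illustration of the failure mode:

import torch

label_weights = torch.ones(4)
expanded = label_weights.unsqueeze(1).expand(4, 3)  # broadcast view, no copy

try:
    expanded *= 0.5  # in-place write to an expanded view raises
except RuntimeError as err:
    print('in-place multiply failed:', err)

expanded = expanded * 0.5  # out-of-place allocates a fresh tensor and works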
@@ -115,6 +115,13 @@ def binary_cross_entropy(pred,
     Returns:
         torch.Tensor: The calculated loss
     """
+    if pred.size(1) == 1:
+        # For binary class segmentation, the shape of pred is
+        # [N, 1, H, W] and that of label is [N, H, W].
+        assert label.max() <= 1, \
+            'For pred with shape [N, 1, H, W], its label must have at ' \
+            'most 2 classes'
+        pred = pred.squeeze(1)
     if pred.dim() != label.dim():
         assert (pred.dim() == 2 and label.dim() == 1) or (
             pred.dim() == 4 and label.dim() == 3), \
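A quick shape trace of the new branch (standalone sketch, names illustrative). Note that squeezing only dim 1, rather than a bare `squeeze()`, keeps the batch dim intact when N == 1:

import torch

pred = torch.rand(2, 1, 4, 4)            # [N, 1, H, W] single-channel logits
label = torch.randint(0, 2, (2, 4, 4))   # [N, H, W] binary labels

assert label.max() <= 1                  # at most 2 classes
pred = pred.squeeze(1)                   # -> [N, H, W], now matches label
assert pred.dim() == label.dim()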
@@ -128,7 +135,7 @@ def binary_cross_entropy(pred,
     # should mask out the ignored elements
     valid_mask = ((label >= 0) & (label != ignore_index)).float()
     if weight is not None:
-        weight *= valid_mask
+        weight = weight * valid_mask
     else:
         weight = valid_mask
     # average loss over non-ignored and valid elements
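Here `weight` may be a tensor the caller passed in, so the out-of-place form also keeps the function from silently mutating the caller's data. An illustration of the aliasing hazard:

import torch

user_weight = torch.ones(2, 4, 4)
valid_mask = torch.zeros(2, 4, 4)

weight = user_weight
weight *= valid_mask          # in-place: user_weight is zeroed too
print(user_weight.sum())      # tensor(0.)

user_weight = torch.ones(2, 4, 4)
weight = user_weight
weight = weight * valid_mask  # out-of-place: user_weight stays intact
print(user_weight.sum())      # tensor(32.)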
@@ -85,6 +85,35 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
         ignore_index=255) / fake_label.numel()
     assert torch.allclose(loss, torch_loss)
 
+    if use_sigmoid:
+        # test loss with complicated case for ce/bce
+        # when avg_non_ignore is False, `avg_factor` would not be calculated
+        fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+        fake_label = torch.ones(2, 8, 8).long()
+        fake_label[:, 0, 0] = 255
+        fake_weight = torch.rand(2, 8, 8)
+
+        loss_cls = build_loss(loss_cls_cfg)
+        loss = loss_cls(
+            fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+        if use_sigmoid:
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=255)
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred,
+                fake_label.float(),
+                reduction='none',
+                weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
+            if avg_non_ignore:
+                avg_factor = valid_mask.sum().item()
+                torch_loss = (torch_loss * weight).sum() / avg_factor
+            else:
+                torch_loss = (torch_loss * weight).mean()
+        assert torch.allclose(loss, torch_loss)
+
     # test loss with class weights from file
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
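The `avg_non_ignore` branch normalizes by the count of non-ignored elements instead of taking a plain mean over everything. The difference in a compact standalone sketch (a hypothetical per-element loss stands in for the BCE output):

import torch

per_elem_loss = torch.rand(2, 8, 8)
valid_mask = torch.ones(2, 8, 8)
valid_mask[:, 0, 0] = 0                        # two "ignored" pixels

masked = per_elem_loss * valid_mask
mean_all = masked.mean()                       # divides by all 128 elements
mean_valid = masked.sum() / valid_mask.sum()   # divides by the 126 valid ones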
@@ -223,3 +252,43 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
         reduction='sum',
         weight=class_weight) / fake_label.numel()
     assert torch.allclose(loss, torch_loss)
+
+
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('with_weight', [True, False])
+def test_binary_class_ce_loss(avg_non_ignore, with_weight):
+    from mmseg.models import build_loss
+
+    fake_pred = torch.rand(3, 1, 10, 10)
+    fake_label = torch.randint(0, 2, (3, 10, 10))
+    fake_weight = torch.rand(3, 10, 10)
+    valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+    weight = valid_mask
+
+    torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+        fake_pred,
+        fake_label.unsqueeze(1).float(),
+        reduction='none',
+        weight=fake_weight.unsqueeze(1).float() if with_weight else None)
+    if avg_non_ignore:
+        eps = torch.finfo(torch.float32).eps
+        avg_factor = valid_mask.sum().item()
+        torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
+            avg_factor + eps)
+    else:
+        torch_loss = (torch_loss * weight.unsqueeze(1)).mean()
+
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=True,
+        loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
+        reduction='mean',
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred,
+        fake_label,
+        weight=fake_weight if with_weight else None,
+        ignore_index=255)
+    assert torch.allclose(loss, torch_loss)
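For a sanity check without an mmseg install: because BCE-with-logits is elementwise, squeezing the prediction's channel dim and unsqueezing the label's channel dim are equivalent, which is the core equivalence the new test asserts. A minimal sketch in plain PyTorch:

import torch
import torch.nn.functional as F

pred = torch.rand(3, 1, 10, 10)            # single-channel logits
label = torch.randint(0, 2, (3, 10, 10))   # binary labels

loss_squeezed = F.binary_cross_entropy_with_logits(
    pred.squeeze(1), label.float(), reduction='mean')
loss_channel = F.binary_cross_entropy_with_logits(
    pred, label.unsqueeze(1).float(), reduction='mean')

assert torch.allclose(loss_squeezed, loss_channel)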