[Fix] Add avg_non_ignore in cross entropy loss (#1409)

* [Fix] Add avg_non_ignore in cross entropy loss

* [Fix] Add avg_non_ignore in cross entropy loss

* add docstring

* fix ut

* fix docstring and comments

* fix

* fix bce

* fix avg_factor in BCE and add more ut

* add avg_non_ignore

* add more ut

* fix part of ut

* fix part of ut

* test avg_non_ignore would not affect ce/bce when reduction none/sum

* test avg_non_ignore would not affect ce/bce when reduction none/sum/mean

* re-organize ut

* re-organize ut

* re-organize ut

* re-organize hardcode case

* fix parts of comments

* fix another parts of comments

* fix
MengzhangLI 2022-03-30 18:32:47 +08:00 committed by GitHub
parent 24f1563571
commit a82ebad0f6
5 changed files with 292 additions and 39 deletions


@@ -68,3 +68,23 @@ model = dict(
In this way, `loss_weight` and `loss_name` will be the weight and the name of the corresponding loss in the training log, respectively.
Note: If you want this loss item to be included in the backward graph, `loss_` must be the prefix of the name.
## Ignore specified label index in loss calculation
By default, `avg_non_ignore=False`, which means each pixel counts toward the loss average even when some pixels belong to ignore-index labels.
For loss calculation, ignoring pixels of a certain label is supported through `avg_non_ignore` and `ignore_index`. The average loss is then computed only over non-ignored labels, which may achieve better performance; see the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config for training `unet` on the `Cityscapes` dataset: during loss calculation it ignores label 0 (the background), and the loss average is computed only over non-ignored labels:
```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)))
```
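
For intuition, here is a minimal standalone sketch (shapes and the ignored region are made up for illustration) of what `avg_non_ignore` changes: with it off, the summed loss is divided by the total number of pixels; with it on, only non-ignored pixels enter the denominator, matching PyTorch's default `'mean'` reduction:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(4, 19, 8, 8)           # logits for 19 classes
label = torch.randint(0, 19, (4, 8, 8))   # ground-truth labels
label[:, 0, :] = 255                      # mark one row of pixels as ignored

# avg_non_ignore=False: sum over non-ignored pixels, divide by ALL pixels
loss_all = F.cross_entropy(
    pred, label, ignore_index=255, reduction='sum') / label.numel()

# avg_non_ignore=True: average over non-ignored pixels only, i.e.
# PyTorch's default 'mean' reduction combined with ignore_index
loss_non_ignore = F.cross_entropy(
    pred, label, ignore_index=255, reduction='mean')

# same summed loss, smaller denominator -> larger average
assert loss_non_ignore > loss_all
```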


@@ -68,3 +68,28 @@ model = dict(
In this way, `loss_weight` and `loss_name` will be the weight and the name of the corresponding loss in the training log, respectively.
Note: `loss_name` must have the `loss_` prefix so that the loss item can be included in the backward graph.
## Ignore specified label index in loss calculation
By default, `avg_non_ignore=False`, i.e. every pixel counts toward the loss, even though some of those pixels belong to classes that should be ignored.
For loss calculation during training, we currently support ignoring pixels of certain label classes through `avg_non_ignore` and `ignore_index`. The loss is then averaged only over non-ignored pixels, which may achieve better performance; see the [related PR](https://github.com/open-mmlab/mmsegmentation/pull/1409). Taking `unet` trained on the `Cityscapes` dataset as an example: during loss calculation, label 0 (the background) is ignored and the mean is computed only over non-ignored pixels. The config is written as:
```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=1.0,
            avg_non_ignore=True)))
```
In this way, `loss_weight` and `loss_name` will be the weight and the name of the corresponding loss in the training log, respectively.
Note: `loss_name` must have the `loss_` prefix so that the loss item can be included in the backward graph.


@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -13,8 +15,31 @@ def cross_entropy(pred,
class_weight=None,
reduction='mean',
avg_factor=None,
ignore_index=-100):
"""The wrapper function for :func:`F.cross_entropy`"""
ignore_index=-100,
avg_non_ignore=False):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
Args:
pred (torch.Tensor): The prediction with shape (N, C), where C is the number of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore`` is ``True`` and the ``reduction`` is
``'mean'``, the loss is averaged over non-ignored targets.
Default: -100.
avg_non_ignore (bool): If True, the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
# class_weight is a manual rescaling weight given to each class.
# If given, it has to be a Tensor of size C that is applied to the
# element-wise losses.
loss = F.cross_entropy(
@@ -25,6 +50,11 @@ def cross_entropy(pred,
ignore_index=ignore_index)
# apply weights and do the reduction
# average the loss over non-ignored elements; PyTorch's official
# cross_entropy averages over non-ignored elements by default, refer to
# https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
@@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights
return bin_labels, bin_label_weights, valid_mask
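
For intuition, a simplified 1-D sketch of the expansion logic above (the real helper also expands the weights and the valid mask to the target shape):

```python
import torch

labels = torch.tensor([0, 2, 255])  # 255 plays the role of ignore_index
num_classes, ignore_index = 3, 255

valid_mask = (labels >= 0) & (labels != ignore_index)
bin_labels = torch.zeros(labels.numel(), num_classes)
inds = torch.nonzero(valid_mask, as_tuple=True)[0]
bin_labels[inds, labels[valid_mask]] = 1
# bin_labels: [[1, 0, 0], [0, 0, 1], [0, 0, 0]] -- the ignored row stays zero
```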
def binary_cross_entropy(pred,
@@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=255):
ignore_index=-100,
avg_non_ignore=False,
**kwargs):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored. Default: 255
ignore_index (int): The label index to be ignored. Default: -100.
avg_non_ignore (bool): If True, the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
Returns:
torch.Tensor: The calculated loss
@@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
pred.dim() == 4 and label.dim() == 3), \
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
'H, W], label shape [N, H, W] are supported'
label, weight = _expand_onehot_labels(label, weight, pred.shape,
ignore_index)
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.shape, ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()
# weighted element-wise losses
if weight is not None:
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
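
To see the masking in action, a tiny standalone sketch (made-up labels) of how `valid_mask` zeroes out ignored elements when `pred` and `label` already share the same shape:

```python
import torch

label = torch.tensor([0., 1., -1., 255., 1.])
ignore_index = 255

# label < 0 and label == ignore_index do not contribute to the loss
valid_mask = ((label >= 0) & (label != ignore_index)).float()
print(valid_mask)        # tensor([1., 1., 0., 0., 1.])
print(valid_mask.sum())  # tensor(3.) -> avg_factor when avg_non_ignore=True
```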
@@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None):
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
@@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): If True, the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
def __init__(self,
@@ -161,7 +211,8 @@ class CrossEntropyLoss(nn.Module):
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_ce'):
loss_name='loss_ce',
avg_non_ignore=False):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
@@ -169,6 +220,13 @@ class CrossEntropyLoss(nn.Module):
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == 'mean':
warnings.warn(
'``avg_non_ignore`` defaults to False, so the loss is averaged '
'over all labels. If you would like to ignore a certain label '
'and average the loss only over non-ignored labels, which '
'matches the official PyTorch cross_entropy, set '
'``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
@@ -178,12 +236,18 @@ class CrossEntropyLoss(nn.Module):
self.cls_criterion = cross_entropy
self._loss_name = loss_name
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=-100,
**kwargs):
"""Forward function."""
assert reduction_override in (None, 'none', 'mean', 'sum')
@@ -193,6 +257,7 @@ class CrossEntropyLoss(nn.Module):
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
@@ -200,6 +265,8 @@ class CrossEntropyLoss(nn.Module):
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs)
return loss_cls
@@ -212,6 +279,7 @@ class CrossEntropyLoss(nn.Module):
by simple sum operation. In addition, if you want this loss item to be
included in the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""


@@ -3,6 +3,7 @@ import functools
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
@@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
# i.e., all labels of an image belong to ignore index.
eps = torch.finfo(torch.float32).eps
loss = loss.sum() / (avg_factor + eps)
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
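
To see why the `eps` guard matters, a minimal standalone sketch (the helper name `safe_mean` is made up for illustration):

```python
import torch

def safe_mean(loss_sum: torch.Tensor, avg_factor: float) -> torch.Tensor:
    # Mirrors the guarded division above: when every label of an image is
    # ignored, avg_factor is 0 and eps keeps the result finite (0.0).
    eps = torch.finfo(torch.float32).eps
    return loss_sum / (avg_factor + eps)

print(safe_mean(torch.tensor(4.2), 6))  # ~0.7, eps is negligible
print(safe_mean(torch.tensor(0.0), 0))  # tensor(0.) instead of nan
```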


@@ -2,8 +2,14 @@
import pytest
import torch
from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels
def test_ce_loss():
@pytest.mark.parametrize('use_sigmoid', [True, False])
@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('bce_input_same_dim', [True, False])
def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
from mmseg.models import build_loss
# use_mask and use_sigmoid cannot be true at the same time
@@ -15,19 +21,73 @@
loss_weight=1.0)
build_loss(loss_cfg)
# test loss with class weights
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=[0.8, 0.2],
loss_weight=1.0,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
# test loss with simple case for ce/bce
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
loss_weight=1.0,
avg_non_ignore=avg_non_ignore,
loss_name='loss_ce')
loss_cls = build_loss(loss_cls_cfg)
if use_sigmoid:
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(100.))
else:
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(200.))
# test loss with complicated case for ce/bce
# when avg_non_ignore is False, `avg_factor` would not be calculated
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
fake_label[:, 0, 0] = 255
fake_weight = None
# extra test bce loss when pred.shape == label.shape
if use_sigmoid and bce_input_same_dim:
fake_pred = torch.randn(2, 10).float()
fake_label = torch.rand(2, 10).float()
fake_weight = torch.rand(2, 10) # set weight in forward function
fake_label[0, [1, 2, 5, 7]] = 255 # set ignore_index
fake_label[1, [0, 5, 8, 9]] = 255
loss_cls = build_loss(loss_cls_cfg)
loss = loss_cls(
fake_pred, fake_label, weight=fake_weight, ignore_index=255)
if use_sigmoid:
if fake_pred.dim() != fake_label.dim():
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=255)
else:
# should mask out the ignored elements
valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
weight = valid_mask
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='none',
weight=fake_weight)
if avg_non_ignore:
avg_factor = valid_mask.sum().item()
torch_loss = (torch_loss * weight).sum() / avg_factor
else:
torch_loss = (torch_loss * weight).mean()
else:
if avg_non_ignore:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='mean', ignore_index=255)
else:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='sum',
ignore_index=255) / fake_label.numel()
assert torch.allclose(loss, torch_loss)
# test loss with class weights from file
fake_pred = torch.Tensor([[100, -100]])
fake_label = torch.Tensor([1]).long()
import os
import tempfile
@@ -63,27 +123,103 @@ def test_ce_loss():
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
loss_cls_cfg = dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
loss_cls = build_loss(loss_cls_cfg)
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
# test that `avg_non_ignore` without an ignore index does not affect
# ce/bce loss when reduction is 'sum'/'none'/'mean'
loss_cls_cfg1 = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
reduction=reduction,
loss_weight=1.0,
avg_non_ignore=True)
loss_cls1 = build_loss(loss_cls_cfg1)
loss_cls_cfg2 = dict(
type='CrossEntropyLoss',
use_sigmoid=use_sigmoid,
reduction=reduction,
loss_weight=1.0,
avg_non_ignore=False)
loss_cls2 = build_loss(loss_cls_cfg2)
assert torch.allclose(
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
fake_label[:, 0, 0] = 255
assert torch.allclose(
loss_cls(fake_pred, fake_label, ignore_index=255),
torch.tensor(0.9354),
loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
atol=1e-4)
# test cross entropy loss has name `loss_ce`
# test ce/bce loss with ignore index and class weight
# in 5-way classification
if use_sigmoid:
# test bce loss when pred.shape == or != label.shape
if bce_input_same_dim:
fake_pred = torch.randn(2, 10).float()
fake_label = torch.rand(2, 10).float()
class_weight = torch.rand(2, 10)
else:
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
fake_label = torch.ones(2, 8, 8).long()
class_weight = torch.randn(2, 21, 8, 8)
fake_label, weight, valid_mask = _expand_onehot_labels(
labels=fake_label,
label_weights=None,
target_shape=fake_pred.shape,
ignore_index=-100)
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred,
fake_label.float(),
reduction='mean',
pos_weight=class_weight)
else:
fake_pred = torch.randn(2, 5, 10).float() # 5-way classification
fake_label = torch.randint(0, 5, (2, 10)).long()
class_weight = torch.rand(5)
class_weight /= class_weight.sum()
torch_loss = torch.nn.functional.cross_entropy(
fake_pred, fake_label, reduction='sum',
weight=class_weight) / fake_label.numel()
loss_cls_cfg = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
use_sigmoid=use_sigmoid,
reduction='mean',
class_weight=class_weight,
loss_weight=1.0,
loss_name='loss_ce')
avg_non_ignore=avg_non_ignore)
loss_cls = build_loss(loss_cls_cfg)
# test cross entropy loss has name `loss_ce`
assert loss_cls.loss_name == 'loss_ce'
# TODO test use_mask
# test avg_non_ignore is in extra_repr
assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
loss = loss_cls(fake_pred, fake_label)
assert torch.allclose(loss, torch_loss)
fake_label[0, [1, 2, 5, 7]] = 10 # set ignore_index
fake_label[1, [0, 5, 8, 9]] = 10
loss = loss_cls(fake_pred, fake_label, ignore_index=10)
if use_sigmoid:
if avg_non_ignore:
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred[fake_label != 10],
fake_label[fake_label != 10].float(),
pos_weight=class_weight[fake_label != 10],
reduction='mean')
else:
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
fake_pred[fake_label != 10],
fake_label[fake_label != 10].float(),
pos_weight=class_weight[fake_label != 10],
reduction='sum') / fake_label.numel()
else:
if avg_non_ignore:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred,
fake_label,
ignore_index=10,
reduction='sum',
weight=class_weight) / fake_label[fake_label != 10].numel()
else:
torch_loss = torch.nn.functional.cross_entropy(
fake_pred,
fake_label,
ignore_index=10,
reduction='sum',
weight=class_weight) / fake_label.numel()
assert torch.allclose(loss, torch_loss)