Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Fix] Add avg_non_ignore in cross entropy loss (#1409)
* [Fix] Add avg_non_ignore in cross entropy loss
* [Fix] Add avg_non_ignore in cross entropy loss
* add docstring
* fix ut
* fix docstring and comments
* fix
* fix bce
* fix avg_factor in BCE and add more ut
* add avg_non_ignore
* add more ut
* fix part of ut
* fix part of ut
* test avg_non_ignore would not affect ce/bce when reduction none/sum
* test avg_non_ignore would not affect ce/bce when reduction none/sum/mean
* re-organize ut
* re-organize ut
* re-organize ut
* re-organize hardcode case
* fix parts of comments
* fix another parts of comments
* fix
This commit is contained in:
parent 24f1563571
commit a82ebad0f6
@@ -68,3 +68,23 @@ model = dict(
 In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.

 Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.
+
+## Ignore specified label index in loss calculation
+
+In the default setting, `avg_non_ignore=False`, which means every pixel counts in the loss calculation, including pixels whose labels are the ignore index.
+
+For loss calculation, we support ignoring the index of a certain label via `avg_non_ignore` and `ignore_index`. The average loss is then computed only over non-ignored labels, which may achieve better performance; see the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config for training `unet` on the `Cityscapes` dataset: the loss calculation ignores label 0, the background, and the loss average is computed only over non-ignored labels:
+
+```python
+_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
+model = dict(
+    decode_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+    auxiliary_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)))
+```
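To make the averaging difference concrete, here is a minimal standalone sketch (not part of the commit; the shapes and the ignore value are made up):

```python
import torch
import torch.nn.functional as F

pred = torch.randn(4, 21, 8, 8)          # logits with shape (N, C, H, W)
label = torch.randint(0, 21, (4, 8, 8))
label[:, 0, :] = 255                     # mark one row of pixels as ignored

# Per-pixel losses; pixels hitting ignore_index contribute zero.
loss = F.cross_entropy(pred, label, reduction='none', ignore_index=255)

# avg_non_ignore=False: divide the summed loss by *every* pixel.
loss_all = loss.sum() / label.numel()

# avg_non_ignore=True: divide only by non-ignored pixels, which matches
# PyTorch's own reduction='mean' behaviour with ignore_index.
loss_non_ignore = loss.sum() / (label != 255).sum()
```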
@@ -68,3 +68,28 @@ model = dict(
 In this way, `loss_weight` and `loss_name` determine the weight of the loss during training and its name in the training log, respectively.

 Note: the name in `loss_name` must carry the `loss_` prefix so that the item can be included in the backward graph.
+
+## Ignore specified label index in loss calculation
+
+By default, `avg_non_ignore=False`, i.e. every pixel is used when computing the loss, even though some of those pixels belong to classes that should be ignored.
+
+For the loss computed during training, we currently support ignoring specific label classes via `avg_non_ignore` and `ignore_index`. The loss is then averaged only over non-ignored pixels, which can give better performance; see the [related PR](https://github.com/open-mmlab/mmsegmentation/pull/1409). Taking `unet` trained on the `Cityscapes` dataset as an example,
+the loss calculation ignores the background (label 0) and computes the mean only over non-ignored pixels. The config is written as:
+
+```python
+_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
+model = dict(
+    decode_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
+    auxiliary_head=dict(
+        ignore_index=0,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)))
+```
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,8 +15,31 @@ def cross_entropy(pred,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
-                  ignore_index=-100):
-    """The wrapper function for :func:`F.cross_entropy`"""
+                  ignore_index=-100,
+                  avg_non_ignore=False):
+    """cross_entropy. The wrapper function for :func:`F.cross_entropy`
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 1).
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+            Default: None.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Default: None.
+        ignore_index (int): Specifies a target value that is ignored and
+            does not contribute to the input gradients. When
+            ``avg_non_ignore`` is ``True``, and the ``reduction`` is
+            ``'mean'``, the loss is averaged over non-ignored targets.
+            Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
+    """
     # class_weight is a manual rescaling weight given to each class.
     # If given, has to be a Tensor of size C element-wise losses
     loss = F.cross_entropy(
@@ -25,6 +50,11 @@ def cross_entropy(pred,
         ignore_index=ignore_index)

     # apply weights and do the reduction
+    # average loss over non-ignored elements
+    # pytorch's official cross_entropy averages loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
     if weight is not None:
         weight = weight.float()
     loss = weight_reduce_loss(
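A worked example of the `avg_factor` computed above (standalone; the label tensor is made up):

```python
import torch

label = torch.tensor([[0, 1, 255],
                      [2, 255, 255]])
ignore_index = 255
# Count only the targets that are not ignored.
avg_factor = label.numel() - (label == ignore_index).sum().item()
assert avg_factor == 3  # 6 labels in total, 3 of them ignored
```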
@@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
     bin_labels[inds[0], labels[valid_mask]] = 1

     valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

     if label_weights is None:
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
         bin_label_weights *= valid_mask

-    return bin_labels, bin_label_weights
+    return bin_labels, bin_label_weights, valid_mask


 def binary_cross_entropy(pred,
@@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=255):
+                         ignore_index=-100,
+                         avg_non_ignore=False,
+                         **kwargs):
     """Calculate the binary CrossEntropy loss.

     Args:
         pred (torch.Tensor): The prediction with shape (N, 1).
         label (torch.Tensor): The learning label of the prediction.
+            Note: In bce loss, label < 0 is invalid.
         weight (torch.Tensor, optional): Sample-wise loss weight.
         reduction (str, optional): The method used to reduce the loss.
             Options are "none", "mean" and "sum".
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
         class_weight (list[float], optional): The weight for each class.
-        ignore_index (int | None): The label index to be ignored. Default: 255
+        ignore_index (int): The label index to be ignored. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`

     Returns:
         torch.Tensor: The calculated loss
@@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
             pred.dim() == 4 and label.dim() == 3), \
             'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
             'H, W], label shape [N, H, W] are supported'
-        label, weight = _expand_onehot_labels(label, weight, pred.shape,
-                                              ignore_index)
+        # `weight` returned from `_expand_onehot_labels`
+        # has been treated for valid (non-ignore) pixels
+        label, weight, valid_mask = _expand_onehot_labels(
+            label, weight, pred.shape, ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight *= valid_mask
+        else:
+            weight = valid_mask
+    # average loss over non-ignored and valid elements
+    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
+        avg_factor = valid_mask.sum().item()

-    # weighted element-wise losses
-    if weight is not None:
-        weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
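The `valid_mask` logic above is easy to check in isolation; a small sketch with made-up labels (255 standing in for `ignore_index`):

```python
import torch

label = torch.tensor([1., 0., -1., 255.])  # -1 and 255 are both invalid here
ignore_index = 255
valid_mask = ((label >= 0) & (label != ignore_index)).float()
# tensor([1., 1., 0., 0.]) -- used as the BCE weight, so invalid pixels
# contribute nothing; valid_mask.sum() becomes avg_factor when
# avg_non_ignore=True.
```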
@@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
                        reduction='mean',
                        avg_factor=None,
                        class_weight=None,
-                       ignore_index=None):
+                       ignore_index=None,
+                       **kwargs):
     """Calculate the CrossEntropy loss for masks.

     Args:
@@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
         loss_name (str, optional): Name of the loss item. If you want this loss
             item to be included into the backward graph, `loss_` must be the
             prefix of the name. Defaults to 'loss_ce'.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
     """

     def __init__(self,
@@ -161,7 +211,8 @@ class CrossEntropyLoss(nn.Module):
                  reduction='mean',
                  class_weight=None,
                  loss_weight=1.0,
-                 loss_name='loss_ce'):
+                 loss_name='loss_ce',
+                 avg_non_ignore=False):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
@@ -169,6 +220,13 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.class_weight = get_class_weight(class_weight)
+        self.avg_non_ignore = avg_non_ignore
+        if not self.avg_non_ignore and self.reduction == 'mean':
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore the certain label and average loss over non-ignore '
+                'labels, which is the same with PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')

         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -178,12 +236,18 @@ class CrossEntropyLoss(nn.Module):
             self.cls_criterion = cross_entropy
         self._loss_name = loss_name

+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
     def forward(self,
                 cls_score,
                 label,
                 weight=None,
                 avg_factor=None,
                 reduction_override=None,
+                ignore_index=-100,
                 **kwargs):
         """Forward function."""
         assert reduction_override in (None, 'none', 'mean', 'sum')
@@ -193,6 +257,7 @@ class CrossEntropyLoss(nn.Module):
             class_weight = cls_score.new_tensor(self.class_weight)
         else:
             class_weight = None
+        # Note: for BCE loss, label < 0 is invalid.
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
@@ -200,6 +265,8 @@ class CrossEntropyLoss(nn.Module):
             class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
+            avg_non_ignore=self.avg_non_ignore,
+            ignore_index=ignore_index,
             **kwargs)
         return loss_cls
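Put together, a hypothetical usage sketch of the updated module (shapes and the ignore value are made up; assumes an mmseg install containing this change):

```python
import torch
from mmseg.models.losses import CrossEntropyLoss

criterion = CrossEntropyLoss(use_sigmoid=False, avg_non_ignore=True)
logits = torch.randn(2, 19, 4, 4)        # (N, C, H, W) predictions
target = torch.randint(0, 19, (2, 4, 4))
target[0, 0, 0] = 255                    # one ignored pixel
# With avg_non_ignore=True the mean is taken over the 31 non-ignored pixels.
loss = criterion(logits, target, ignore_index=255)
```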
@@ -212,6 +279,7 @@ class CrossEntropyLoss(nn.Module):
         by simple sum operation. In addition, if you want this loss item to be
         included into the backward graph, `loss_` must be the prefix of the
         name.
+
         Returns:
             str: The name of this loss item.
         """
@@ -3,6 +3,7 @@ import functools

 import mmcv
 import numpy as np
+import torch
 import torch.nn.functional as F
@@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     else:
         # if reduction is mean, then average the loss by avg_factor
         if reduction == 'mean':
-            loss = loss.sum() / avg_factor
+            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
+            # i.e., all labels of an image belong to ignore index.
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
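The effect of the `eps` guard, sketched standalone with made-up values:

```python
import torch

loss = torch.zeros(4)    # e.g. every label in the image was ignored
avg_factor = 0.0         # valid_mask.sum().item() when nothing is valid
eps = torch.finfo(torch.float32).eps
safe = loss.sum() / (avg_factor + eps)   # finite 0.0 instead of dividing by zero
```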
@@ -2,8 +2,14 @@
 import pytest
 import torch

+from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels
+

-def test_ce_loss():
+@pytest.mark.parametrize('use_sigmoid', [True, False])
+@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('bce_input_same_dim', [True, False])
+def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
     from mmseg.models import build_loss

     # use_mask and use_sigmoid cannot be true at the same time
@@ -15,19 +21,73 @@ def test_ce_loss():
             loss_weight=1.0)
         build_loss(loss_cfg)

-    # test loss with class weights
-    loss_cls_cfg = dict(
-        type='CrossEntropyLoss',
-        use_sigmoid=False,
-        class_weight=[0.8, 0.2],
-        loss_weight=1.0,
-        loss_name='loss_ce')
-    loss_cls = build_loss(loss_cls_cfg)
+    # test loss with simple case for ce/bce
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    if use_sigmoid:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(100.))
+    else:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(200.))
+
+    # test loss with complicated case for ce/bce
+    # when avg_non_ignore is False, `avg_factor` would not be calculated
+    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+    fake_label = torch.ones(2, 8, 8).long()
+    fake_label[:, 0, 0] = 255
+    fake_weight = None
+    # extra test bce loss when pred.shape == label.shape
+    if use_sigmoid and bce_input_same_dim:
+        fake_pred = torch.randn(2, 10).float()
+        fake_label = torch.rand(2, 10).float()
+        fake_weight = torch.rand(2, 10)  # set weight in forward function
+        fake_label[0, [1, 2, 5, 7]] = 255  # set ignore_index
+        fake_label[1, [0, 5, 8, 9]] = 255
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+    if use_sigmoid:
+        if fake_pred.dim() != fake_label.dim():
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=255)
+        else:
+            # should mask out the ignored elements
+            valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+            weight = valid_mask
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='none',
+            weight=fake_weight)
+        if avg_non_ignore:
+            avg_factor = valid_mask.sum().item()
+            torch_loss = (torch_loss * weight).sum() / avg_factor
+        else:
+            torch_loss = (torch_loss * weight).mean()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='mean', ignore_index=255)
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='sum',
+                ignore_index=255) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)

     # test loss with class weights from file
+    fake_pred = torch.Tensor([[100, -100]])
+    fake_label = torch.Tensor([1]).long()
     import os
     import tempfile
@@ -63,27 +123,103 @@ def test_ce_loss():
     loss_cls = build_loss(loss_cls_cfg)
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

-    loss_cls_cfg = dict(
-        type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
-    loss_cls = build_loss(loss_cls_cfg)
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
-
-    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
-    fake_label = torch.ones(2, 8, 8).long()
+    # test `avg_non_ignore` without ignore index would not affect ce/bce loss
+    # when reduction='sum'/'none'/'mean'
+    loss_cls_cfg1 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=True)
+    loss_cls1 = build_loss(loss_cls_cfg1)
+    loss_cls_cfg2 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=False)
+    loss_cls2 = build_loss(loss_cls_cfg2)
     assert torch.allclose(
-        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
-    fake_label[:, 0, 0] = 255
-    assert torch.allclose(
-        loss_cls(fake_pred, fake_label, ignore_index=255),
-        torch.tensor(0.9354),
+        loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
+        loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
         atol=1e-4)

-    # test cross entropy loss has name `loss_ce`
+    # test ce/bce loss with ignore index and class weight
+    # in 5-way classification
+    if use_sigmoid:
+        # test bce loss when pred.shape == or != label.shape
+        if bce_input_same_dim:
+            fake_pred = torch.randn(2, 10).float()
+            fake_label = torch.rand(2, 10).float()
+            class_weight = torch.rand(2, 10)
+        else:
+            fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+            fake_label = torch.ones(2, 8, 8).long()
+            class_weight = torch.randn(2, 21, 8, 8)
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=-100)
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='mean',
+            pos_weight=class_weight)
+    else:
+        fake_pred = torch.randn(2, 5, 10).float()  # 5-way classification
+        fake_label = torch.randint(0, 5, (2, 10)).long()
+        class_weight = torch.rand(5)
+        class_weight /= class_weight.sum()
+        torch_loss = torch.nn.functional.cross_entropy(
+            fake_pred, fake_label, reduction='sum',
+            weight=class_weight) / fake_label.numel()
     loss_cls_cfg = dict(
         type='CrossEntropyLoss',
-        use_sigmoid=False,
+        use_sigmoid=use_sigmoid,
+        reduction='mean',
+        class_weight=class_weight,
         loss_weight=1.0,
-        loss_name='loss_ce')
+        avg_non_ignore=avg_non_ignore)
     loss_cls = build_loss(loss_cls_cfg)

+    # test cross entropy loss has name `loss_ce`
     assert loss_cls.loss_name == 'loss_ce'
-    # TODO test use_mask
+    # test avg_non_ignore is in extra_repr
+    assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
+
+    loss = loss_cls(fake_pred, fake_label)
+    assert torch.allclose(loss, torch_loss)
+
+    fake_label[0, [1, 2, 5, 7]] = 10  # set ignore_index
+    fake_label[1, [0, 5, 8, 9]] = 10
+    loss = loss_cls(fake_pred, fake_label, ignore_index=10)
+    if use_sigmoid:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='mean')
+        else:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='sum') / fake_label.numel()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label[fake_label != 10].numel()
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)