mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Add avg_non_ignore in cross entropy loss (#1409)
* [Fix] Add avg_non_ignore in cross entropy loss * [Fix] Add avg_non_ignore in cross entropy loss * add docstring * fix ut * fix docstring and comments * fix * fix bce * fix avg_factor in BCE and add more ut * add avg_non_ignore * add more ut * fix part of ut * fix part of ut * test avg_non_ignore would not affect ce/bce when reduction none/sum * test avg_non_ignore would not affect ce/bce when reduction none/sum/mean * re-organize ut * re-organize ut * re-organize ut * re-organize hardcode case * fix parts of comments * fix another parts of comments * fix
This commit is contained in:
parent
24f1563571
commit
a82ebad0f6
@ -68,3 +68,23 @@ model = dict(
|
||||
In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.
|
||||
|
||||
Note: If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name.
|
||||
|
||||
## Ignore specified label index in loss calculation
|
||||
|
||||
In default setting, `avg_non_ignore=False` which means each pixel counts for loss calculation although some of them belong to ignore-index labels.
|
||||
|
||||
For loss calculation, we support ignore index of certain label by `avg_non_ignore` and `ignore_index`. In this way, the average loss would only be calculated in non-ignored labels which may achieve better performance, and here is the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config of training `unet` on `Cityscapes` dataset: in loss calculation it would ignore label 0 which is background and loss average is only calculated on non-ignore labels:
|
||||
|
||||
```python
|
||||
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
|
||||
model = dict(
|
||||
decode_head=dict(
|
||||
ignore_index=0,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
|
||||
auxiliary_head=dict(
|
||||
ignore_index=0,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
|
||||
))
|
||||
```
|
||||
|
@ -68,3 +68,28 @@ model = dict(
|
||||
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。
|
||||
|
||||
注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
|
||||
|
||||
## 在损失函数中忽略特定的 label 类别
|
||||
|
||||
默认设置 `avg_non_ignore=False`, 即每个像素都用来计算损失函数。尽管其中的一些像素属于需要被忽略的类别。
|
||||
|
||||
对于训练时损失函数的计算,我们目前支持使用 `avg_non_ignore` 和 `ignore_index` 来忽略 label 特定的类别。 这样损失函数将只在非忽略类别像素中求平均值,会获得更好的表现。这里是[相关 PR](https://github.com/open-mmlab/mmsegmentation/pull/1409)。以 `unet` 使用 `Cityscapes` 数据集训练为例,
|
||||
在计算损失函数时,忽略 label 为0的背景,并且仅在不被忽略的像素上计算均值。配置文件写为:
|
||||
|
||||
```python
|
||||
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
|
||||
model = dict(
|
||||
decode_head=dict(
|
||||
ignore_index=0,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
|
||||
auxiliary_head=dict(
|
||||
ignore_index=0,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
|
||||
))
|
||||
```
|
||||
|
||||
通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。
|
||||
|
||||
注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -13,8 +15,31 @@ def cross_entropy(pred,
|
||||
class_weight=None,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
ignore_index=-100):
|
||||
"""The wrapper function for :func:`F.cross_entropy`"""
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False):
|
||||
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, 1).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
Default: None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Default: None.
|
||||
ignore_index (int): Specifies a target value that is ignored and
|
||||
does not contribute to the input gradients. When
|
||||
``avg_non_ignore `` is ``True``, and the ``reduction`` is
|
||||
``''mean''``, the loss is averaged over non-ignored targets.
|
||||
Defaults: -100.
|
||||
avg_non_ignore (bool): The flag decides to whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
"""
|
||||
|
||||
# class_weight is a manual rescaling weight given to each class.
|
||||
# If given, has to be a Tensor of size C element-wise losses
|
||||
loss = F.cross_entropy(
|
||||
@ -25,6 +50,11 @@ def cross_entropy(pred,
|
||||
ignore_index=ignore_index)
|
||||
|
||||
# apply weights and do the reduction
|
||||
# average loss over non-ignored elements
|
||||
# pytorch's official cross_entropy average loss over non-ignored elements
|
||||
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
|
||||
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
|
||||
avg_factor = label.numel() - (label == ignore_index).sum().item()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
loss = weight_reduce_loss(
|
||||
@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
||||
bin_labels[inds[0], labels[valid_mask]] = 1
|
||||
|
||||
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
||||
|
||||
if label_weights is None:
|
||||
bin_label_weights = valid_mask
|
||||
else:
|
||||
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
||||
bin_label_weights *= valid_mask
|
||||
|
||||
return bin_labels, bin_label_weights
|
||||
return bin_labels, bin_label_weights, valid_mask
|
||||
|
||||
|
||||
def binary_cross_entropy(pred,
|
||||
@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None,
|
||||
ignore_index=255):
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False,
|
||||
**kwargs):
|
||||
"""Calculate the binary CrossEntropy loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, 1).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
Note: In bce loss, label < 0 is invalid.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255
|
||||
ignore_index (int): The label index to be ignored. Default: -100.
|
||||
avg_non_ignore (bool): The flag decides to whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
|
||||
pred.dim() == 4 and label.dim() == 3), \
|
||||
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
|
||||
'H, W], label shape [N, H, W] are supported'
|
||||
label, weight = _expand_onehot_labels(label, weight, pred.shape,
|
||||
ignore_index)
|
||||
# `weight` returned from `_expand_onehot_labels`
|
||||
# has been treated for valid (non-ignore) pixels
|
||||
label, weight, valid_mask = _expand_onehot_labels(
|
||||
label, weight, pred.shape, ignore_index)
|
||||
else:
|
||||
# should mask out the ignored elements
|
||||
valid_mask = ((label >= 0) & (label != ignore_index)).float()
|
||||
if weight is not None:
|
||||
weight *= valid_mask
|
||||
else:
|
||||
weight = valid_mask
|
||||
# average loss over non-ignored and valid elements
|
||||
if reduction == 'mean' and avg_factor is None and avg_non_ignore:
|
||||
avg_factor = valid_mask.sum().item()
|
||||
|
||||
# weighted element-wise losses
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred, label.float(), pos_weight=class_weight, reduction='none')
|
||||
# do the reduction for the weighted loss
|
||||
@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
|
||||
reduction='mean',
|
||||
avg_factor=None,
|
||||
class_weight=None,
|
||||
ignore_index=None):
|
||||
ignore_index=None,
|
||||
**kwargs):
|
||||
"""Calculate the CrossEntropy loss for masks.
|
||||
|
||||
Args:
|
||||
@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
|
||||
loss_name (str, optional): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
avg_non_ignore (bool): The flag decides to whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -161,7 +211,8 @@ class CrossEntropyLoss(nn.Module):
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_ce'):
|
||||
loss_name='loss_ce',
|
||||
avg_non_ignore=False):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
assert (use_sigmoid is False) or (use_mask is False)
|
||||
self.use_sigmoid = use_sigmoid
|
||||
@ -169,6 +220,13 @@ class CrossEntropyLoss(nn.Module):
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self.avg_non_ignore = avg_non_ignore
|
||||
if not self.avg_non_ignore and self.reduction == 'mean':
|
||||
warnings.warn(
|
||||
'Default ``avg_non_ignore`` is False, if you would like to '
|
||||
'ignore the certain label and average loss over non-ignore '
|
||||
'labels, which is the same with PyTorch official '
|
||||
'cross_entropy, set ``avg_non_ignore=True``.')
|
||||
|
||||
if self.use_sigmoid:
|
||||
self.cls_criterion = binary_cross_entropy
|
||||
@ -178,12 +236,18 @@ class CrossEntropyLoss(nn.Module):
|
||||
self.cls_criterion = cross_entropy
|
||||
self._loss_name = loss_name
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'avg_non_ignore={self.avg_non_ignore}'
|
||||
return s
|
||||
|
||||
def forward(self,
|
||||
cls_score,
|
||||
label,
|
||||
weight=None,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
ignore_index=-100,
|
||||
**kwargs):
|
||||
"""Forward function."""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
@ -193,6 +257,7 @@ class CrossEntropyLoss(nn.Module):
|
||||
class_weight = cls_score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
# Note: for BCE loss, label < 0 is invalid.
|
||||
loss_cls = self.loss_weight * self.cls_criterion(
|
||||
cls_score,
|
||||
label,
|
||||
@ -200,6 +265,8 @@ class CrossEntropyLoss(nn.Module):
|
||||
class_weight=class_weight,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
avg_non_ignore=self.avg_non_ignore,
|
||||
ignore_index=ignore_index,
|
||||
**kwargs)
|
||||
return loss_cls
|
||||
|
||||
@ -212,6 +279,7 @@ class CrossEntropyLoss(nn.Module):
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
|
@ -3,6 +3,7 @@ import functools
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
|
||||
else:
|
||||
# if reduction is mean, then average the loss by avg_factor
|
||||
if reduction == 'mean':
|
||||
loss = loss.sum() / avg_factor
|
||||
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
|
||||
# i.e., all labels of an image belong to ignore index.
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
loss = loss.sum() / (avg_factor + eps)
|
||||
# if reduction is 'none', then do nothing, otherwise raise an error
|
||||
elif reduction != 'none':
|
||||
raise ValueError('avg_factor can not be used with reduction="sum"')
|
||||
|
@ -2,8 +2,14 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels
|
||||
|
||||
def test_ce_loss():
|
||||
|
||||
@pytest.mark.parametrize('use_sigmoid', [True, False])
|
||||
@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
|
||||
@pytest.mark.parametrize('avg_non_ignore', [True, False])
|
||||
@pytest.mark.parametrize('bce_input_same_dim', [True, False])
|
||||
def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
|
||||
from mmseg.models import build_loss
|
||||
|
||||
# use_mask and use_sigmoid cannot be true at the same time
|
||||
@ -15,19 +21,73 @@ def test_ce_loss():
|
||||
loss_weight=1.0)
|
||||
build_loss(loss_cfg)
|
||||
|
||||
# test loss with class weights
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
class_weight=[0.8, 0.2],
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_ce')
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
# test loss with simple case for ce/bce
|
||||
fake_pred = torch.Tensor([[100, -100]])
|
||||
fake_label = torch.Tensor([1]).long()
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=use_sigmoid,
|
||||
loss_weight=1.0,
|
||||
avg_non_ignore=avg_non_ignore,
|
||||
loss_name='loss_ce')
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
if use_sigmoid:
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label), torch.tensor(100.))
|
||||
else:
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label), torch.tensor(200.))
|
||||
|
||||
# test loss with complicated case for ce/bce
|
||||
# when avg_non_ignore is False, `avg_factor` would not be calculated
|
||||
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
|
||||
fake_label = torch.ones(2, 8, 8).long()
|
||||
fake_label[:, 0, 0] = 255
|
||||
fake_weight = None
|
||||
# extra test bce loss when pred.shape == label.shape
|
||||
if use_sigmoid and bce_input_same_dim:
|
||||
fake_pred = torch.randn(2, 10).float()
|
||||
fake_label = torch.rand(2, 10).float()
|
||||
fake_weight = torch.rand(2, 10) # set weight in forward function
|
||||
fake_label[0, [1, 2, 5, 7]] = 255 # set ignore_index
|
||||
fake_label[1, [0, 5, 8, 9]] = 255
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
loss = loss_cls(
|
||||
fake_pred, fake_label, weight=fake_weight, ignore_index=255)
|
||||
if use_sigmoid:
|
||||
if fake_pred.dim() != fake_label.dim():
|
||||
fake_label, weight, valid_mask = _expand_onehot_labels(
|
||||
labels=fake_label,
|
||||
label_weights=None,
|
||||
target_shape=fake_pred.shape,
|
||||
ignore_index=255)
|
||||
else:
|
||||
# should mask out the ignored elements
|
||||
valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
|
||||
weight = valid_mask
|
||||
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
fake_pred,
|
||||
fake_label.float(),
|
||||
reduction='none',
|
||||
weight=fake_weight)
|
||||
if avg_non_ignore:
|
||||
avg_factor = valid_mask.sum().item()
|
||||
torch_loss = (torch_loss * weight).sum() / avg_factor
|
||||
else:
|
||||
torch_loss = (torch_loss * weight).mean()
|
||||
else:
|
||||
if avg_non_ignore:
|
||||
torch_loss = torch.nn.functional.cross_entropy(
|
||||
fake_pred, fake_label, reduction='mean', ignore_index=255)
|
||||
else:
|
||||
torch_loss = torch.nn.functional.cross_entropy(
|
||||
fake_pred, fake_label, reduction='sum',
|
||||
ignore_index=255) / fake_label.numel()
|
||||
assert torch.allclose(loss, torch_loss)
|
||||
|
||||
# test loss with class weights from file
|
||||
fake_pred = torch.Tensor([[100, -100]])
|
||||
fake_label = torch.Tensor([1]).long()
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
@ -63,27 +123,103 @@ def test_ce_loss():
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
|
||||
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
|
||||
|
||||
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
|
||||
fake_label = torch.ones(2, 8, 8).long()
|
||||
# test `avg_non_ignore` without ignore index would not affect ce/bce loss
|
||||
# when reduction='sum'/'none'/'mean'
|
||||
loss_cls_cfg1 = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=use_sigmoid,
|
||||
reduction=reduction,
|
||||
loss_weight=1.0,
|
||||
avg_non_ignore=True)
|
||||
loss_cls1 = build_loss(loss_cls_cfg1)
|
||||
loss_cls_cfg2 = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=use_sigmoid,
|
||||
reduction=reduction,
|
||||
loss_weight=1.0,
|
||||
avg_non_ignore=False)
|
||||
loss_cls2 = build_loss(loss_cls_cfg2)
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
|
||||
fake_label[:, 0, 0] = 255
|
||||
assert torch.allclose(
|
||||
loss_cls(fake_pred, fake_label, ignore_index=255),
|
||||
torch.tensor(0.9354),
|
||||
loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
|
||||
loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
|
||||
atol=1e-4)
|
||||
|
||||
# test cross entropy loss has name `loss_ce`
|
||||
# test ce/bce loss with ignore index and class weight
|
||||
# in 5-way classification
|
||||
if use_sigmoid:
|
||||
# test bce loss when pred.shape == or != label.shape
|
||||
if bce_input_same_dim:
|
||||
fake_pred = torch.randn(2, 10).float()
|
||||
fake_label = torch.rand(2, 10).float()
|
||||
class_weight = torch.rand(2, 10)
|
||||
else:
|
||||
fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
|
||||
fake_label = torch.ones(2, 8, 8).long()
|
||||
class_weight = torch.randn(2, 21, 8, 8)
|
||||
fake_label, weight, valid_mask = _expand_onehot_labels(
|
||||
labels=fake_label,
|
||||
label_weights=None,
|
||||
target_shape=fake_pred.shape,
|
||||
ignore_index=-100)
|
||||
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
fake_pred,
|
||||
fake_label.float(),
|
||||
reduction='mean',
|
||||
pos_weight=class_weight)
|
||||
else:
|
||||
fake_pred = torch.randn(2, 5, 10).float() # 5-way classification
|
||||
fake_label = torch.randint(0, 5, (2, 10)).long()
|
||||
class_weight = torch.rand(5)
|
||||
class_weight /= class_weight.sum()
|
||||
torch_loss = torch.nn.functional.cross_entropy(
|
||||
fake_pred, fake_label, reduction='sum',
|
||||
weight=class_weight) / fake_label.numel()
|
||||
loss_cls_cfg = dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
use_sigmoid=use_sigmoid,
|
||||
reduction='mean',
|
||||
class_weight=class_weight,
|
||||
loss_weight=1.0,
|
||||
loss_name='loss_ce')
|
||||
avg_non_ignore=avg_non_ignore)
|
||||
loss_cls = build_loss(loss_cls_cfg)
|
||||
|
||||
# test cross entropy loss has name `loss_ce`
|
||||
assert loss_cls.loss_name == 'loss_ce'
|
||||
# TODO test use_mask
|
||||
# test avg_non_ignore is in extra_repr
|
||||
assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
|
||||
|
||||
loss = loss_cls(fake_pred, fake_label)
|
||||
assert torch.allclose(loss, torch_loss)
|
||||
|
||||
fake_label[0, [1, 2, 5, 7]] = 10 # set ignore_index
|
||||
fake_label[1, [0, 5, 8, 9]] = 10
|
||||
loss = loss_cls(fake_pred, fake_label, ignore_index=10)
|
||||
if use_sigmoid:
|
||||
if avg_non_ignore:
|
||||
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
fake_pred[fake_label != 10],
|
||||
fake_label[fake_label != 10].float(),
|
||||
pos_weight=class_weight[fake_label != 10],
|
||||
reduction='mean')
|
||||
else:
|
||||
torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
fake_pred[fake_label != 10],
|
||||
fake_label[fake_label != 10].float(),
|
||||
pos_weight=class_weight[fake_label != 10],
|
||||
reduction='sum') / fake_label.numel()
|
||||
else:
|
||||
if avg_non_ignore:
|
||||
torch_loss = torch.nn.functional.cross_entropy(
|
||||
fake_pred,
|
||||
fake_label,
|
||||
ignore_index=10,
|
||||
reduction='sum',
|
||||
weight=class_weight) / fake_label[fake_label != 10].numel()
|
||||
else:
|
||||
torch_loss = torch.nn.functional.cross_entropy(
|
||||
fake_pred,
|
||||
fake_label,
|
||||
ignore_index=10,
|
||||
reduction='sum',
|
||||
weight=class_weight) / fake_label.numel()
|
||||
assert torch.allclose(loss, torch_loss)
|
||||
|
Loading…
x
Reference in New Issue
Block a user