[Fix] Add avg_non_ignore in cross entropy loss (#1409)

* [Fix] Add avg_non_ignore in cross entropy loss

* [Fix] Add avg_non_ignore in cross entropy loss

* add docstring

* fix ut

* fix docstring and comments

* fix

* fix bce

* fix avg_factor in BCE and add more ut

* add avg_non_ignore

* add more ut

* fix part of ut

* fix part of ut

* test avg_non_ignore would not affect ce/bce when reduction none/sum

* test avg_non_ignore would not affect ce/bce when reduction none/sum/mean

* re-organize ut

* re-organize ut

* re-organize ut

* re-organize hardcode case

* fix parts of comments

* fix another parts of comments

* fix
Author: MengzhangLI, 2022-03-30 18:32:47 +08:00 (committed by GitHub)
parent 24f1563571
commit a82ebad0f6
5 changed files with 292 additions and 39 deletions

**File 1/5: English documentation on losses**

@@ -68,3 +68,23 @@ model = dict(

In this way, `loss_weight` and `loss_name` will be the weight and the name of the corresponding loss in the training log, respectively.

Note: If you want this loss item to be included in the backward graph, `loss_` must be the prefix of the name.

## Ignore specified label index in loss calculation

In the default setting, `avg_non_ignore=False`, which means every pixel counts in the loss calculation, even pixels whose labels belong to the ignore index.

For loss calculation, we support ignoring a certain label index via `avg_non_ignore` and `ignore_index`. In this way, the average loss is calculated only over non-ignored labels, which may achieve better performance; here is the [reference](https://github.com/open-mmlab/mmsegmentation/pull/1409). Here is an example config for training `unet` on the `Cityscapes` dataset: the loss calculation ignores label 0, which is the background class, and the loss average is calculated only over non-ignored labels:

```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
)
```
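The behavioral difference is easy to check with plain PyTorch. The sketch below (illustrative shapes, no MMSegmentation dependency) contrasts the default of averaging over all pixels with averaging over non-ignored pixels only, which is what `avg_non_ignore=True` selects:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(1, 3, 4, 4)           # (N, C, H, W) logits for 3 classes
label = torch.randint(0, 3, (1, 4, 4))   # (N, H, W) labels
label[0, 0, :] = 255                     # mark one row as ignored

# avg_non_ignore=False: sum the per-pixel losses, divide by ALL pixels
loss_all = F.cross_entropy(
    pred, label, ignore_index=255, reduction='sum') / label.numel()

# avg_non_ignore=True: divide by the non-ignored pixels only,
# which is exactly what PyTorch's reduction='mean' does
loss_non_ignore = F.cross_entropy(
    pred, label, ignore_index=255, reduction='mean')

num_valid = (label != 255).sum()
assert torch.allclose(loss_non_ignore, loss_all * label.numel() / num_valid)
```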

**File 2/5: Chinese documentation on losses**

@@ -68,3 +68,28 @@ model = dict(

In this way, `loss_weight` determines the weight of the loss during training and `loss_name` its name in the training log.

Note: the name in `loss_name` must carry the `loss_` prefix so that the loss is included in the backward graph.

## Ignore specified label index in loss calculation

By default, `avg_non_ignore=False`, i.e. every pixel is used to calculate the loss, even though some of those pixels belong to categories that should be ignored.

For the loss calculation during training, we currently support ignoring specific label categories with `avg_non_ignore` and `ignore_index`. The loss is then averaged only over the pixels of non-ignored categories, which can achieve better performance; here is the [related PR](https://github.com/open-mmlab/mmsegmentation/pull/1409). Taking `unet` trained on the `Cityscapes` dataset as an example, the loss calculation ignores label 0, the background, and the mean is computed only over the non-ignored pixels. The config reads:

```python
_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
model = dict(
    decode_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        ignore_index=0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
)
```

**File 3/5: mmseg/models/losses/cross_entropy_loss.py**

```diff
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -13,8 +15,31 @@ def cross_entropy(pred,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
-                  ignore_index=-100):
-    """The wrapper function for :func:`F.cross_entropy`"""
+                  ignore_index=-100,
+                  avg_non_ignore=False):
+    """The wrapper function for :func:`F.cross_entropy`.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), where C is
+            the number of classes.
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+            Default: None.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Default: None.
+        ignore_index (int): Specifies a target value that is ignored and
+            does not contribute to the input gradients. When
+            ``avg_non_ignore`` is ``True`` and ``reduction`` is ``'mean'``,
+            the loss is averaged over non-ignored targets. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
+    """
     # class_weight is a manual rescaling weight given to each class.
     # If given, has to be a Tensor of size C element-wise losses
     loss = F.cross_entropy(
@@ -25,6 +50,11 @@ def cross_entropy(pred,
         ignore_index=ignore_index)

     # apply weights and do the reduction
+    # average the loss over non-ignored elements; PyTorch's official
+    # cross_entropy also averages the loss over non-ignored elements, see
+    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
     if weight is not None:
         weight = weight.float()
     loss = weight_reduce_loss(
```
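So with `avg_non_ignore=True` and mean reduction, `avg_factor` becomes the count of non-ignored elements. A tiny illustration of that count (hypothetical labels):

```python
import torch

label = torch.tensor([[0, 1, 255], [2, 255, 1]])
ignore_index = 255

# avg_factor = number of elements that are NOT the ignore index
avg_factor = label.numel() - (label == ignore_index).sum().item()
assert avg_factor == 4  # six labels in total, two of them ignored
```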
```diff
@@ -46,13 +76,14 @@ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
     bin_labels[inds[0], labels[valid_mask]] = 1

     valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+
     if label_weights is None:
         bin_label_weights = valid_mask
     else:
         bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
         bin_label_weights *= valid_mask

-    return bin_labels, bin_label_weights
+    return bin_labels, bin_label_weights, valid_mask


 def binary_cross_entropy(pred,
```
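The `valid_mask` that `_expand_onehot_labels` now also returns marks the non-ignored positions; ignored pixels end up with an all-zero one-hot row. A self-contained sketch of that idea (illustrative values, not the library code):

```python
import torch

labels = torch.tensor([1, 0, 255, 2])      # 255 is the ignore index
num_classes, ignore_index = 3, 255

valid_mask = (labels >= 0) & (labels != ignore_index)
bin_labels = torch.zeros(labels.size(0), num_classes)
inds = torch.nonzero(valid_mask, as_tuple=True)[0]
bin_labels[inds, labels[valid_mask]] = 1   # one-hot only for valid pixels

print(bin_labels)
# tensor([[0., 1., 0.],
#         [1., 0., 0.],
#         [0., 0., 0.],   <- ignored pixel stays all-zero
#         [0., 0., 1.]])
```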
```diff
@@ -61,19 +92,25 @@ def binary_cross_entropy(pred,
                          reduction='mean',
                          avg_factor=None,
                          class_weight=None,
-                         ignore_index=255):
+                         ignore_index=-100,
+                         avg_non_ignore=False,
+                         **kwargs):
     """Calculate the binary CrossEntropy loss.

     Args:
         pred (torch.Tensor): The prediction with shape (N, 1).
         label (torch.Tensor): The learning label of the prediction.
+            Note: In bce loss, label < 0 is invalid.
         weight (torch.Tensor, optional): Sample-wise loss weight.
         reduction (str, optional): The method used to reduce the loss.
             Options are "none", "mean" and "sum".
         avg_factor (int, optional): Average factor that is used to average
             the loss. Defaults to None.
         class_weight (list[float], optional): The weight for each class.
-        ignore_index (int | None): The label index to be ignored. Default: 255
+        ignore_index (int): The label index to be ignored. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`

     Returns:
         torch.Tensor: The calculated loss
@@ -83,12 +120,21 @@ def binary_cross_entropy(pred,
             pred.dim() == 4 and label.dim() == 3), \
             'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
             'H, W], label shape [N, H, W] are supported'
-        label, weight = _expand_onehot_labels(label, weight, pred.shape,
-                                              ignore_index)
+        # `weight` returned from `_expand_onehot_labels`
+        # has been treated for valid (non-ignored) pixels
+        label, weight, valid_mask = _expand_onehot_labels(
+            label, weight, pred.shape, ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight *= valid_mask
+        else:
+            weight = valid_mask
+
+    # average the loss over non-ignored and valid elements
+    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
+        avg_factor = valid_mask.sum().item()

-    # weighted element-wise losses
-    if weight is not None:
-        weight = weight.float()
     loss = F.binary_cross_entropy_with_logits(
         pred, label.float(), pos_weight=class_weight, reduction='none')
     # do the reduction for the weighted loss
```
```diff
@@ -104,7 +150,8 @@ def mask_cross_entropy(pred,
                        reduction='mean',
                        avg_factor=None,
                        class_weight=None,
-                       ignore_index=None):
+                       ignore_index=None,
+                       **kwargs):
     """Calculate the CrossEntropy loss for masks.

     Args:
@@ -153,6 +200,9 @@ class CrossEntropyLoss(nn.Module):
         loss_name (str, optional): Name of the loss item. If you want this loss
             item to be included into the backward graph, `loss_` must be the
             prefix of the name. Defaults to 'loss_ce'.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
     """

     def __init__(self,
@@ -161,7 +211,8 @@ class CrossEntropyLoss(nn.Module):
                  reduction='mean',
                  class_weight=None,
                  loss_weight=1.0,
-                 loss_name='loss_ce'):
+                 loss_name='loss_ce',
+                 avg_non_ignore=False):
         super(CrossEntropyLoss, self).__init__()
         assert (use_sigmoid is False) or (use_mask is False)
         self.use_sigmoid = use_sigmoid
@@ -169,6 +220,13 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.loss_weight = loss_weight
         self.class_weight = get_class_weight(class_weight)
+        self.avg_non_ignore = avg_non_ignore
+        if not self.avg_non_ignore and self.reduction == 'mean':
+            warnings.warn(
+                'Default ``avg_non_ignore`` is False, if you would like to '
+                'ignore certain labels and average the loss over '
+                'non-ignored labels, which is the same as PyTorch official '
+                'cross_entropy, set ``avg_non_ignore=True``.')

         if self.use_sigmoid:
             self.cls_criterion = binary_cross_entropy
@@ -178,12 +236,18 @@ class CrossEntropyLoss(nn.Module):
             self.cls_criterion = cross_entropy
         self._loss_name = loss_name

+    def extra_repr(self):
+        """Extra repr."""
+        s = f'avg_non_ignore={self.avg_non_ignore}'
+        return s
+
     def forward(self,
                 cls_score,
                 label,
                 weight=None,
                 avg_factor=None,
                 reduction_override=None,
+                ignore_index=-100,
                 **kwargs):
         """Forward function."""
         assert reduction_override in (None, 'none', 'mean', 'sum')
@@ -193,6 +257,7 @@ class CrossEntropyLoss(nn.Module):
             class_weight = cls_score.new_tensor(self.class_weight)
         else:
             class_weight = None
+        # Note: for BCE loss, label < 0 is invalid.
         loss_cls = self.loss_weight * self.cls_criterion(
             cls_score,
             label,
@@ -200,6 +265,8 @@ class CrossEntropyLoss(nn.Module):
             class_weight=class_weight,
             reduction=reduction,
             avg_factor=avg_factor,
+            avg_non_ignore=self.avg_non_ignore,
+            ignore_index=ignore_index,
             **kwargs)
         return loss_cls

@@ -212,6 +279,7 @@ class CrossEntropyLoss(nn.Module):
         by simple sum operation. In addition, if you want this loss item to be
         included into the backward graph, `loss_` must be the prefix of the
         name.
+
         Returns:
             str: The name of this loss item.
         """
```

**File 4/5: mmseg/models/losses/utils.py**

```diff
@@ -3,6 +3,7 @@ import functools

 import mmcv
 import numpy as np
+import torch
 import torch.nn.functional as F
@@ -69,7 +70,10 @@ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
     else:
         # if reduction is mean, then average the loss by avg_factor
         if reduction == 'mean':
-            loss = loss.sum() / avg_factor
+            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
+            # i.e., all labels of an image belong to the ignore index.
+            eps = torch.finfo(torch.float32).eps
+            loss = loss.sum() / (avg_factor + eps)
         # if reduction is 'none', then do nothing, otherwise raise an error
         elif reduction != 'none':
             raise ValueError('avg_factor can not be used with reduction="sum"')
```
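The `eps` term guards the degenerate batch in which every label is ignored, so `avg_factor` is zero. A quick sketch of the failure it prevents (illustrative values):

```python
import torch

loss = torch.zeros(4)   # every per-pixel loss was masked out
avg_factor = 0.0        # no non-ignored pixels in the batch

eps = torch.finfo(torch.float32).eps
safe = loss.sum() / (avg_factor + eps)   # tensor(0.), finite
naive = loss.sum() / avg_factor          # 0 / 0 -> tensor(nan)
assert torch.isfinite(safe) and torch.isnan(naive)
```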

**File 5/5: CrossEntropyLoss unit tests**

```diff
@@ -2,8 +2,14 @@
 import pytest
 import torch

+from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels

-def test_ce_loss():
+
+@pytest.mark.parametrize('use_sigmoid', [True, False])
+@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
+@pytest.mark.parametrize('avg_non_ignore', [True, False])
+@pytest.mark.parametrize('bce_input_same_dim', [True, False])
+def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
     from mmseg.models import build_loss

     # use_mask and use_sigmoid cannot be true at the same time
@@ -15,19 +21,73 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
         loss_weight=1.0)
     build_loss(loss_cfg)

-    # test loss with class weights
-    loss_cls_cfg = dict(
-        type='CrossEntropyLoss',
-        use_sigmoid=False,
-        class_weight=[0.8, 0.2],
-        loss_weight=1.0,
-        loss_name='loss_ce')
-    loss_cls = build_loss(loss_cls_cfg)
+    # test loss with simple case for ce/bce
     fake_pred = torch.Tensor([[100, -100]])
     fake_label = torch.Tensor([1]).long()
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
+    loss_cls_cfg = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        loss_weight=1.0,
+        avg_non_ignore=avg_non_ignore,
+        loss_name='loss_ce')
+    loss_cls = build_loss(loss_cls_cfg)
+    if use_sigmoid:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(100.))
+    else:
+        assert torch.allclose(
+            loss_cls(fake_pred, fake_label), torch.tensor(200.))
+
+    # test loss with complicated case for ce/bce
+    # when avg_non_ignore is False, `avg_factor` would not be calculated
+    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+    fake_label = torch.ones(2, 8, 8).long()
+    fake_label[:, 0, 0] = 255
+    fake_weight = None
+    # extra test bce loss when pred.shape == label.shape
+    if use_sigmoid and bce_input_same_dim:
+        fake_pred = torch.randn(2, 10).float()
+        fake_label = torch.rand(2, 10).float()
+        fake_weight = torch.rand(2, 10)  # set weight in forward function
+        fake_label[0, [1, 2, 5, 7]] = 255  # set ignore_index
+        fake_label[1, [0, 5, 8, 9]] = 255
+    loss_cls = build_loss(loss_cls_cfg)
+    loss = loss_cls(
+        fake_pred, fake_label, weight=fake_weight, ignore_index=255)
+    if use_sigmoid:
+        if fake_pred.dim() != fake_label.dim():
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=255)
+        else:
+            # should mask out the ignored elements
+            valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
+            weight = valid_mask
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='none',
+            weight=fake_weight)
+        if avg_non_ignore:
+            avg_factor = valid_mask.sum().item()
+            torch_loss = (torch_loss * weight).sum() / avg_factor
+        else:
+            torch_loss = (torch_loss * weight).mean()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='mean', ignore_index=255)
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred, fake_label, reduction='sum',
+                ignore_index=255) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)

     # test loss with class weights from file
+    fake_pred = torch.Tensor([[100, -100]])
+    fake_label = torch.Tensor([1]).long()
     import os
     import tempfile
@@ -63,27 +123,103 @@ def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
     loss_cls = build_loss(loss_cls_cfg)
     assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

-    loss_cls_cfg = dict(
-        type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)
-    loss_cls = build_loss(loss_cls_cfg)
-    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(100.))
-
-    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
-    fake_label = torch.ones(2, 8, 8).long()
-    assert torch.allclose(
-        loss_cls(fake_pred, fake_label), torch.tensor(0.9503), atol=1e-4)
-    fake_label[:, 0, 0] = 255
-    assert torch.allclose(
-        loss_cls(fake_pred, fake_label, ignore_index=255),
-        torch.tensor(0.9354),
-        atol=1e-4)
+    # test `avg_non_ignore` without ignore index would not affect ce/bce loss
+    # when reduction='sum'/'none'/'mean'
+    loss_cls_cfg1 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=True)
+    loss_cls1 = build_loss(loss_cls_cfg1)
+    loss_cls_cfg2 = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=use_sigmoid,
+        reduction=reduction,
+        loss_weight=1.0,
+        avg_non_ignore=False)
+    loss_cls2 = build_loss(loss_cls_cfg2)
+    assert torch.allclose(
+        loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
+        loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
+        atol=1e-4)

-    # test cross entropy loss has name `loss_ce`
+    # test ce/bce loss with ignore index and class weight
+    # in 5-way classification
+    if use_sigmoid:
+        # test bce loss when pred.shape == or != label.shape
+        if bce_input_same_dim:
+            fake_pred = torch.randn(2, 10).float()
+            fake_label = torch.rand(2, 10).float()
+            class_weight = torch.rand(2, 10)
+        else:
+            fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
+            fake_label = torch.ones(2, 8, 8).long()
+            class_weight = torch.randn(2, 21, 8, 8)
+            fake_label, weight, valid_mask = _expand_onehot_labels(
+                labels=fake_label,
+                label_weights=None,
+                target_shape=fake_pred.shape,
+                ignore_index=-100)
+        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            fake_pred,
+            fake_label.float(),
+            reduction='mean',
+            pos_weight=class_weight)
+    else:
+        fake_pred = torch.randn(2, 5, 10).float()  # 5-way classification
+        fake_label = torch.randint(0, 5, (2, 10)).long()
+        class_weight = torch.rand(5)
+        class_weight /= class_weight.sum()
+        torch_loss = torch.nn.functional.cross_entropy(
+            fake_pred, fake_label, reduction='sum',
+            weight=class_weight) / fake_label.numel()
     loss_cls_cfg = dict(
         type='CrossEntropyLoss',
-        use_sigmoid=False,
+        use_sigmoid=use_sigmoid,
+        reduction='mean',
+        class_weight=class_weight,
         loss_weight=1.0,
-        loss_name='loss_ce')
+        avg_non_ignore=avg_non_ignore)
     loss_cls = build_loss(loss_cls_cfg)
+    # test cross entropy loss has name `loss_ce`
     assert loss_cls.loss_name == 'loss_ce'
-    # TODO test use_mask
+    # test avg_non_ignore is in extra_repr
+    assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'
+    loss = loss_cls(fake_pred, fake_label)
+    assert torch.allclose(loss, torch_loss)
+
+    fake_label[0, [1, 2, 5, 7]] = 10  # set ignore_index
+    fake_label[1, [0, 5, 8, 9]] = 10
+    loss = loss_cls(fake_pred, fake_label, ignore_index=10)
+    if use_sigmoid:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='mean')
+        else:
+            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+                fake_pred[fake_label != 10],
+                fake_label[fake_label != 10].float(),
+                pos_weight=class_weight[fake_label != 10],
+                reduction='sum') / fake_label.numel()
+    else:
+        if avg_non_ignore:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label[fake_label != 10].numel()
+        else:
+            torch_loss = torch.nn.functional.cross_entropy(
+                fake_pred,
+                fake_label,
+                ignore_index=10,
+                reduction='sum',
+                weight=class_weight) / fake_label.numel()
+    assert torch.allclose(loss, torch_loss)
```