mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.framework.errors import ValueError
from easycv.models.builder import LOSSES
from easycv.models.loss.utils import weight_reduce_loss


def get_class_weight(class_weight):
    """Get class weight for loss function.

    Args:
        class_weight (list[float] | str | None): If class_weight is a str,
            take it as a file name and read from it.
    """
    if isinstance(class_weight, str):
        # take it as a file path
        if class_weight.endswith('.npy'):
            class_weight = np.load(class_weight)
        else:
            # pkl, json or yaml
            class_weight = mmcv.load(class_weight)

    return class_weight
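
# Illustrative usage of ``get_class_weight`` (editor's sketch; the file names
# below are hypothetical, not files shipped with EasyCV):
#
#     >>> get_class_weight([1.0, 2.0, 0.5])      # list is returned unchanged
#     [1.0, 2.0, 0.5]
#     >>> get_class_weight('class_weight.npy')   # loaded with np.load
#     >>> get_class_weight('class_weight.json')  # loaded with mmcv.load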


def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100,
                  avg_non_ignore=False):
    """The wrapper function for :func:`F.cross_entropy` that supports
    sample-wise loss weights and averaging the loss over non-ignored
    elements.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
            Default: None.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Default: None.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradients. When
            ``avg_non_ignore`` is ``True`` and the ``reduction`` is
            ``'mean'``, the loss is averaged over non-ignored targets.
            Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
    """
    # class_weight is a manual rescaling weight given to each class.
    # If given, it has to be a Tensor of size C.
    # Compute element-wise losses first; the reduction is applied below.
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # apply weights and do the reduction;
    # PyTorch's official cross_entropy averages the loss over non-ignored
    # elements, refer to
    # https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660  # noqa
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
        avg_factor = label.numel() - (label == ignore_index).sum().item()
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
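
# Illustrative usage of ``cross_entropy`` (editor's sketch; shapes and values
# are assumptions for demonstration only):
#
#     >>> import torch
#     >>> pred = torch.randn(4, 3)               # logits: 4 samples, 3 classes
#     >>> label = torch.tensor([0, 2, 1, -100])  # the last sample is ignored
#     >>> # default: the mean is taken over all 4 elements
#     >>> loss_all = cross_entropy(pred, label)
#     >>> # average over the 3 non-ignored samples, matching F.cross_entropy
#     >>> loss_valid = cross_entropy(pred, label, avg_non_ignore=True)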


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()

    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights = bin_label_weights * valid_mask

    return bin_labels, bin_label_weights, valid_mask
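
# Illustrative behavior of ``_expand_onehot_labels`` (editor's sketch; the
# ignore index 255 is just an example value):
#
#     >>> import torch
#     >>> labels = torch.tensor([0, 2, 255])
#     >>> bin_labels, bin_weights, valid = _expand_onehot_labels(
#     ...     labels, None, (3, 3), ignore_index=255)
#     >>> bin_labels
#     tensor([[1, 0, 0],
#             [0, 0, 1],
#             [0, 0, 0]])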


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100,
                         avg_non_ignore=False,
                         label_ceil=False,
                         **kwargs):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
            Note: In bce loss, label < 0 is invalid.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
        label_ceil (bool): If True, label values in (0, 1] are rounded up
            to 1. Only valid when using bce. Default: False.

    Returns:
        torch.Tensor: The calculated loss
    """
    if len(pred.shape) > 1 and pred.shape[1] == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        # As ignore_index is often set to 255, the binary class label
        # check should mask out ignore_index.
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        # squeeze only the channel dim so that a batch of size 1 keeps its
        # batch dim (a bare pred.squeeze() would also drop it)
        pred = pred.squeeze(1)
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        # `weight` returned from `_expand_onehot_labels`
        # has been treated for valid (non-ignore) pixels
        label, weight, valid_mask = _expand_onehot_labels(
            label, weight, pred.shape, ignore_index)
    else:
        # should mask out the ignored elements
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask
    if label_ceil:
        label = label.gt(0.0).type(label.dtype)
    # average loss over non-ignored and valid elements
    if reduction == 'mean' and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss
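
# Illustrative usage of ``binary_cross_entropy`` (editor's sketch; shapes and
# values are assumptions for demonstration only):
#
#     >>> import torch
#     >>> pred = torch.randn(4, 2)               # logits for 2 binary classes
#     >>> label = torch.tensor([0, 1, 1, -100])  # -100 is masked out
#     >>> # hard labels are expanded to one-hot internally
#     >>> loss = binary_cross_entropy(pred, label, avg_non_ignore=True)
#     >>> # soft labels in (0, 1] can be rounded up to 1 with label_ceil
#     >>> soft = torch.tensor([0.0, 0.3, 1.0, 0.7])
#     >>> loss = binary_cross_entropy(torch.randn(4), soft, label_ceil=True)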


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None,
                       **kwargs):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It is used to select the mask of
            the class the object belongs to when the mask prediction is
            not class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other losses.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]
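
# Illustrative usage of ``mask_cross_entropy`` (editor's sketch; shapes are
# assumptions for demonstration only):
#
#     >>> import torch
#     >>> num_rois, num_classes, H, W = 2, 3, 7, 7
#     >>> pred = torch.randn(num_rois, num_classes, H, W)  # per-class logits
#     >>> target = torch.randint(0, 2, (num_rois, H, W)).float()
#     >>> label = torch.tensor([0, 2])  # class of each ROI's object
#     >>> loss = mask_cross_entropy(pred, target, label)   # shape: (1,)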


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If
            in str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item. If you want this
            loss item to be included into the backward graph, `loss_` must be
            the prefix of the name. Defaults to 'loss_ce'.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
            `New in version 0.23.0.`
        label_ceil (bool): If True, label values in (0, 1] are rounded up
            to 1 when using bce. Default: False.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 loss_name='loss_ce',
                 avg_non_ignore=False,
                 label_ceil=False):
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        if label_ceil:
            if not use_sigmoid:
                raise ValueError(
                    '`label_ceil` is supported only when `use_sigmoid` is '
                    'True. If not using bce, please set `label_ceil=False`.')
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        if not self.avg_non_ignore and self.reduction == 'mean':
            warnings.warn(
                'Default ``avg_non_ignore`` is False. If you would like to '
                'ignore a certain label and average the loss over '
                'non-ignored labels, which matches the behavior of the '
                'official PyTorch cross_entropy, set '
                '``avg_non_ignore=True``.')

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy
        self._loss_name = loss_name
        self.label_ceil = label_ceil

    def extra_repr(self):
        """Extra repr."""
        s = f'avg_non_ignore={self.avg_non_ignore}'
        return s

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=-100,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        # Note: for BCE loss, label < 0 is invalid.
        if self.use_sigmoid:
            loss_cls = self.loss_weight * self.cls_criterion(
                cls_score,
                label,
                weight,
                class_weight=class_weight,
                reduction=reduction,
                avg_factor=avg_factor,
                avg_non_ignore=self.avg_non_ignore,
                ignore_index=ignore_index,
                label_ceil=self.label_ceil,
                **kwargs)
        else:
            loss_cls = self.loss_weight * self.cls_criterion(
                cls_score,
                label,
                weight,
                class_weight=class_weight,
                reduction=reduction,
                avg_factor=avg_factor,
                avg_non_ignore=self.avg_non_ignore,
                ignore_index=ignore_index,
                **kwargs)
        return loss_cls

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
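
# Illustrative usage of ``CrossEntropyLoss`` (editor's sketch; the values are
# assumptions for demonstration only):
#
#     >>> import torch
#     >>> criterion = CrossEntropyLoss(avg_non_ignore=True)
#     >>> cls_score = torch.randn(4, 3)  # logits: 4 samples, 3 classes
#     >>> label = torch.tensor([0, 1, 2, 1])
#     >>> loss = criterion(cls_score, label)
#
# In a config, the loss is typically built through the LOSSES registry, e.g.:
#
#     loss_cls = dict(type='CrossEntropyLoss', use_sigmoid=True,
#                     loss_weight=1.0)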