import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100):
    """The wrapper function for :func:`F.cross_entropy`."""
    # class_weight is a manual rescaling weight given to each class.
    # If given, it has to be a Tensor of size C.
    # Compute element-wise losses first; reduction is applied below.
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss
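
# A minimal usage sketch (illustrative, not part of the module API): how
# `cross_entropy` combines a per-sample `weight` with `avg_factor`. The
# shapes and values below are assumptions for the example.
#
#   >>> pred = torch.randn(4, 3)                 # (N, C) raw logits
#   >>> label = torch.tensor([0, 2, 1, 2])       # (N,) class indices
#   >>> weight = torch.tensor([1., 1., 0., 1.])  # mask out the third sample
#   >>> # with avg_factor, the weighted sum is divided by 3 instead of N=4
#   >>> loss = cross_entropy(pred, label, weight=weight, avg_factor=3)
#   >>> loss.dim()  # reduced to a scalar
#   0
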
def _expand_onehot_labels(labels, label_weights, label_channels):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1, as_tuple=False).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    if label_weights is None:
        bin_label_weights = None
    else:
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights
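
# A small sketch of the expansion above (values are assumptions). Note the
# convention here: label 0 is treated as background, so only labels >= 1
# produce a positive channel, shifted down by one.
#
#   >>> labels = torch.tensor([0, 2, 1])
#   >>> bin_labels, _ = _expand_onehot_labels(labels, None, 3)
#   >>> bin_labels
#   tensor([[0, 0, 0],
#           [0, 1, 0],
#           [1, 0, 0]])
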
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, C).
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if pred.dim() != label.dim():
        label, weight = _expand_onehot_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss
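
# A minimal sketch of the sigmoid branch (assumed shapes): sparse class
# labels are expanded to one-hot targets whenever their dimensionality
# differs from `pred`, then a per-channel binary loss is computed.
#
#   >>> pred = torch.randn(2, 3)       # (N, C) per-class logits
#   >>> label = torch.tensor([0, 2])   # sparse labels; 0 means background
#   >>> loss = binary_cross_entropy(pred, label)
#   >>> loss.dim()
#   0
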
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It is used to select the mask of the
            class the object belongs to when the mask prediction is not
            class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.

    Returns:
        torch.Tensor: The calculated loss
    """
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]
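
# A minimal sketch of the mask branch (assumed shapes): each ROI's predicted
# mask is first selected by its object's class label, then scored against
# the binary mask target.
#
#   >>> num_rois, num_classes, h, w = 2, 4, 7, 7
#   >>> pred = torch.randn(num_rois, num_classes, h, w)
#   >>> target = torch.rand(num_rois, h, w)   # per-pixel binary targets
#   >>> label = torch.tensor([1, 3])          # class of each ROI's object
#   >>> mask_cross_entropy(pred, target, label).shape
#   torch.Size([1])
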
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            instead of softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
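
# A minimal usage sketch (assumed shapes and values). In a training config
# the module is typically built from a dict through the LOSSES registry,
# e.g. dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
# but it can also be instantiated directly:
#
#   >>> criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
#   >>> cls_score = torch.randn(8, 5)       # (N, C) classification logits
#   >>> label = torch.randint(0, 5, (8,))   # (N,) ground-truth classes
#   >>> loss = criterion(cls_score, label)
#   >>> loss.dim()
#   0
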