132 lines
5.1 KiB
Python
132 lines
5.1 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ..builder import LOSSES
|
|
from .utils import weight_reduce_loss
|
|
|
|
|
|
def varifocal_loss(pred,
                   target,
                   weight=None,
                   alpha=0.75,
                   gamma=2.0,
                   iou_weighted=True,
                   reduction='mean',
                   avg_factor=None):
    """Compute the `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    The loss is an element-wise binary cross-entropy on logits, re-weighted
    asymmetrically:

    * entries with a positive target (``target > 0``) are scaled by their
      target value, or by 1 when ``iou_weighted`` is False;
    * entries with a non-positive target are scaled by
      ``alpha * |sigmoid(pred) - target| ** gamma``.

    Args:
        pred (torch.Tensor): Raw logits with shape (N, C), where C is the
            number of classes. Must have exactly the same size as
            ``target``.
        target (torch.Tensor): IoU-aware classification targets with shape
            (N, C). Cast to the dtype of ``pred`` before use.
        weight (torch.Tensor, optional): Per-prediction loss weight,
            forwarded to ``weight_reduce_loss``. Defaults to None.
        alpha (float, optional): Scale applied only to the negative part
            of the loss (unlike the alpha of Focal Loss). Defaults to 0.75.
        gamma (float, optional): Exponent of the modulating factor applied
            to negatives. Defaults to 2.0.
        iou_weighted (bool, optional): If True, positives are weighted by
            their target value; otherwise every positive gets weight 1.
            Defaults to True.
        reduction (str, optional): Reduction mode forwarded to
            ``weight_reduce_loss``; one of "none", "mean" or "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor forwarded to
            ``weight_reduce_loss``. Defaults to None.

    Returns:
        torch.Tensor: The re-weighted loss after ``weight_reduce_loss``.
    """
    # The element-wise product below requires identical shapes.
    assert pred.size() == target.size()
    prob = pred.sigmoid()
    target = target.type_as(pred)

    is_pos = (target > 0.0).float()
    is_neg = (target <= 0.0).float()

    # Positives: weighted by their (IoU) target score, or uniformly by 1.
    pos_weight = target * is_pos if iou_weighted else is_pos
    # Negatives: scaled by alpha * |p - target|^gamma, so confident
    # (low-probability) negatives contribute little.
    neg_weight = alpha * (prob - target).abs().pow(gamma) * is_neg
    focal_weight = pos_weight + neg_weight

    elementwise = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none')
    return weight_reduce_loss(elementwise * focal_weight, weight, reduction,
                              avg_factor)
|
|
|
|
|
|
@LOSSES.register_module()
class VarifocalLoss(nn.Module):
    """Module wrapper around ``varifocal_loss``.

    See `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    Args:
        use_sigmoid (bool, optional): Whether predictions are treated with
            sigmoid (True) or softmax. Only True is accepted.
            Defaults to True.
        alpha (float, optional): Scale applied to the negative part of
            Varifocal Loss; differs from the alpha of Focal Loss. Must be
            non-negative. Defaults to 0.75.
        gamma (float, optional): Exponent of the modulating factor.
            Defaults to 2.0.
        iou_weighted (bool, optional): Whether positive examples are
            weighted by their IoU target. Defaults to True.
        reduction (str, optional): Default reduction: "none", "mean" or
            "sum". Defaults to 'mean'.
        loss_weight (float, optional): Multiplier applied to the final
            loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=True,
                 alpha=0.75,
                 gamma=2.0,
                 iou_weighted=True,
                 reduction='mean',
                 loss_weight=1.0):
        """Validate and store the loss hyper-parameters."""
        super().__init__()
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0

        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted Varifocal Loss.

        Args:
            pred (torch.Tensor): The prediction logits.
            target (torch.Tensor): The learning target, same size as
                ``pred``.
            weight (torch.Tensor, optional): Per-prediction loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor forwarded to the
                reduction. Defaults to None.
            reduction_override (str, optional): If given, replaces the
                reduction chosen at construction time. One of "none",
                "mean" or "sum".

        Returns:
            torch.Tensor: ``loss_weight`` times the Varifocal Loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Only the sigmoid variant exists.
        if not self.use_sigmoid:
            raise NotImplementedError

        return self.loss_weight * varifocal_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            iou_weighted=self.iou_weighted,
            reduction=reduction,
            avg_factor=avg_factor)
|