# Copyright (c) Alibaba, Inc. and its affiliates.
import warnings

import mmcv
import torch
import torch.nn as nn

from easycv.models.detection.utils import bbox_overlaps
from easycv.models.loss.utils import weighted_loss
from ..registry import LOSSES


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def iou_loss(pred, target, linear=False, mode='log', eps=1e-6):
    """IoU loss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes. Depending on ``mode`` the loss is ``1 - IoU`` ("linear"),
    ``1 - IoU ** 2`` ("square") or ``-log(IoU)`` ("log").

    Args:
        pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
        linear (bool, optional): Deprecated. If True, ``mode`` is forced to
            "linear" and a warning is emitted. Default: False.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'
        eps (float): Lower bound the IoU is clamped to, which avoids
            log(0). Default: 1e-6.

    Return:
        torch.Tensor: Loss tensor.
    """
    assert mode in ['linear', 'square', 'log']
    if linear:
        mode = 'linear'
        warnings.warn('DeprecationWarning: Setting "linear=True" in '
                      'iou_loss is deprecated, please use "mode=`linear`" '
                      'instead.')
    # Clamp from below so that the log branch never sees an IoU of 0.
    ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
    if mode == 'linear':
        loss = 1 - ious
    elif mode == 'square':
        loss = 1 - ious**2
    elif mode == 'log':
        loss = -ious.log()
    else:
        # Only reachable when the assert above is stripped (python -O).
        raise NotImplementedError
    return loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def giou_loss(pred, target, eps=1e-7):
    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression <https://arxiv.org/abs/1902.09630>`_.

    The loss is ``1 - GIoU``.

    Args:
        pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
            shape (n, 4).
        target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
        eps (float): Small value forwarded to ``bbox_overlaps``
            (presumably for numerical stability of the divisions there).
            Default: 1e-7.

    Return:
        Tensor: Loss tensor.
    """
    gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
    loss = 1 - gious
    return loss


@LOSSES.register_module()
class IoULoss(nn.Module):
    """IoULoss.

    Computing the IoU loss between a set of predicted bboxes and target
    bboxes.

    Args:
        linear (bool): If True, use linear scale of loss else determined
            by mode. Default: False.
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'
    """

    def __init__(self,
                 linear=False,
                 eps=1e-6,
                 reduction='mean',
                 loss_weight=1.0,
                 mode='log'):
        super(IoULoss, self).__init__()
        assert mode in ['linear', 'square', 'log']
        if linear:
            mode = 'linear'
            warnings.warn('DeprecationWarning: Setting "linear=True" in '
                          'IOULoss is deprecated, please use "mode=`linear`" '
                          'instead.')
        self.mode = mode
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: ``loss_weight`` times the IoU loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # Shortcut when no weight is positive. It is skipped for
        # reduction='none' because the caller then expects a per-box loss,
        # not a scalar.
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * iou_loss(
            pred,
            target,
            weight,
            mode=self.mode,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss


@LOSSES.register_module()
class GIoULoss(nn.Module):
    """GIoULoss.

    Computing the GIoU loss (``1 - GIoU``) between a set of predicted bboxes
    and target bboxes.

    Args:
        eps (float): Small value passed through to ``giou_loss``.
            Default: 1e-6.
        reduction (str): Options are "none", "mean" and "sum".
            Default: 'mean'.
        loss_weight (float): Weight of loss. Default: 1.0.
    """

    def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
        super(GIoULoss, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: ``loss_weight`` times the GIoU loss.
        """
        # Validate and resolve the reduction before any early return so an
        # invalid override is always rejected.
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # Shortcut when no weight is positive. As in IoULoss it is skipped
        # for reduction='none', where the caller expects a per-box loss
        # rather than a scalar.
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # giou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * giou_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss