From bcc6695235a9661ebd02c11262d23bc5ea45e79a Mon Sep 17 00:00:00 2001 From: Zhaoyan Fang <52028100+satuoqaq@users.noreply.github.com> Date: Sun, 18 Sep 2022 11:13:41 +0800 Subject: [PATCH] add losses (#9) --- mmyolo/losses/__init__.py | 4 + mmyolo/losses/iou_loss.py | 175 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 mmyolo/losses/__init__.py create mode 100644 mmyolo/losses/iou_loss.py diff --git a/mmyolo/losses/__init__.py b/mmyolo/losses/__init__.py new file mode 100644 index 00000000..311e73be --- /dev/null +++ b/mmyolo/losses/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .iou_loss import IoULoss, bbox_overlaps + +__all__ = ['IoULoss', 'bbox_overlaps'] \ No newline at end of file diff --git a/mmyolo/losses/iou_loss.py b/mmyolo/losses/iou_loss.py new file mode 100644 index 00000000..286dc827 --- /dev/null +++ b/mmyolo/losses/iou_loss.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from mmdet.models.losses.utils import weight_reduce_loss +from mmdet.structures.bbox import HorizontalBoxes +from mmyolo.registry import MODELS + + +# TODO: unify all code +def bbox_overlaps(pred: torch.Tensor, + target: torch.Tensor, + iou_mode: str = 'ciou', + bbox_format: str = 'xywh', + is_aligned: bool = False, + eps: float = 1e-7) -> torch.Tensor: + r"""Calculate overlap between two set of bboxes. + `Implementation of paper `Enhancing Geometric Factors into + Model Learning and Inference for Object Detection and Instance + Segmentation `_. + In the CIoU implementation of YOLOv5 and mmdetection, there is a slight + difference in the way the alpha parameter is computed. + mmdet version: + alpha = (ious > 0.5).float() * v / (1 - ious + v) + YOLOv5 version: + alpha = v / (v - ious + (1 + eps) + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2) + or (x, y, w, h),shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + iou_mode (str): Options are "ciou". + Defaults to "ciou". + bbox_format (str): Options are "xywh" and "xyxy". + Defaults to "xywh". + is_aligned (bool): + eps (float): Eps to avoid log(0). + Returns: + Tensor: shape (n,). + """ + assert iou_mode in ('ciou', ) + assert bbox_format in ('xyxy', 'xywh') + if bbox_format == 'xywh': + pred = HorizontalBoxes.cxcywh_to_xyxy(pred) + target = HorizontalBoxes.cxcywh_to_xyxy(target) + + # overlap + lt = torch.max(pred[:, :2], target[:, :2]) + rb = torch.min(pred[:, 2:], target[:, 2:]) + wh = (rb - lt).clamp(min=0) + overlap = wh[:, 0] * wh[:, 1] + + # union + ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1]) + ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = ap + ag - overlap + eps + + # IoU + ious = overlap / union + + # enclose area + enclose_x1y1 = torch.min(pred[:, :2], target[:, :2]) + enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:]) + enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0) + + cw = enclose_wh[:, 0] + ch = enclose_wh[:, 1] + + c2 = cw**2 + ch**2 + eps + + b1_x1, b1_y1 = pred[:, 0], pred[:, 1] + b1_x2, b1_y2 = pred[:, 2], pred[:, 3] + b2_x1, b2_y1 = target[:, 0], target[:, 1] + b2_x2, b2_y2 = target[:, 2], target[:, 3] + + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4 + right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4 + rho2 = left + right + + factor = 4 / math.pi**2 + v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + + with torch.no_grad(): + alpha = v / (v - ious + (1 + eps)) + + # CIoU + cious = ious - (rho2 / c2 + alpha * v) + return cious.clamp(min=-1.0, max=1.0) + + +@MODELS.register_module() +class IoULoss(nn.Module): + """IoULoss. + Computing the IoU loss between a set of predicted bboxes and target bboxes. + Args: + iou_mode (str): Options are "ciou". + Defaults to "ciou". + bbox_format (str): Options are "xywh" and "xyxy". + Defaults to "xywh". + eps (float): Eps to avoid log(0). + reduction (str): Options are "none", "mean" and "sum". + loss_weight (float): Weight of loss. + return_iou (bool): If True, return loss and iou. + """ + + def __init__(self, + iou_mode: str = 'ciou', + bbox_format: str = 'xywh', + eps: float = 1e-7, + reduction: str = 'mean', + loss_weight: float = 1.0, + return_iou: bool = True): + super().__init__() + assert bbox_format in ('xywh', 'xyxy') + assert iou_mode in ('ciou', ) + self.iou_mode = iou_mode + self.bbox_format = bbox_format + self.eps = eps + self.reduction = reduction + self.loss_weight = loss_weight + self.return_iou = return_iou + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.Tensor] = None, + avg_factor: Optional[str] = None, + reduction_override: Optional[Union[str, bool]] = None + ) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]: + """Forward function. + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2) + or (x, y, w, h),shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + weight (Tensor, optional): Element-wise weights. + avg_factor (float, optional): Average factor when computing the + mean of losses. + reduction_override (str, bool, optional): Same as built-in losses + of PyTorch. Defaults to None. + Returns: + loss or tuple(loss, iou): + """ + if weight is not None and not torch.any(weight > 0): + if pred.dim() == weight.dim() + 1: + weight = weight.unsqueeze(1) + return (pred * weight).sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + # TODO: remove this in the future + # reduce the weight of shape (n, 4) to (n,) to match the + # giou_loss of shape (n,) + assert weight.shape == pred.shape + weight = weight.mean(-1) + + iou = bbox_overlaps( + pred, + target, + iou_mode=self.iou_mode, + bbox_format=self.bbox_format, + eps=self.eps) + loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight, + reduction, avg_factor) + + if self.return_iou: + return loss, iou + else: + return loss \ No newline at end of file