mirror of https://github.com/open-mmlab/mmyolo.git
add losses (#9)
parent
b3d405aa4c
commit
bcc6695235
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .iou_loss import IoULoss, bbox_overlaps
|
||||
|
||||
# Public names exported by a star-import of this package.
__all__ = ['IoULoss', 'bbox_overlaps']
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmdet.models.losses.utils import weight_reduce_loss
|
||||
from mmdet.structures.bbox import HorizontalBoxes
|
||||
from mmyolo.registry import MODELS
|
||||
|
||||
|
||||
# TODO: unify all code
|
||||
def bbox_overlaps(pred: torch.Tensor,
                  target: torch.Tensor,
                  iou_mode: str = 'ciou',
                  bbox_format: str = 'xywh',
                  is_aligned: bool = False,
                  eps: float = 1e-7) -> torch.Tensor:
    r"""Calculate the element-wise CIoU between two sets of bboxes.

    Implementation of the paper `Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation <https://arxiv.org/abs/2005.03572>`_.

    In the CIoU implementation of YOLOv5 and mmdetection, there is a slight
    difference in the way the alpha parameter is computed.

    mmdet version:
        alpha = (ious > 0.5).float() * v / (1 - ious + v)
    YOLOv5 version (the one used here):
        alpha = v / (v - ious + (1 + eps))

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
            or (cx, cy, w, h), shape (..., 4).
        target (Tensor): Corresponding gt bboxes, same format as ``pred``,
            shape (..., 4). Must be broadcastable against ``pred``.
        iou_mode (str): Options are "ciou".
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        is_aligned (bool): Currently unused; ``pred`` and ``target`` are
            always compared element-wise. Defaults to False.
        eps (float): Small constant added to denominators (union area,
            enclosing diagonal, box heights, alpha) to avoid division by
            zero. Defaults to 1e-7.

    Returns:
        Tensor: CIoU values clamped to [-1, 1]. The box axis is dropped,
        so (n, 4) inputs give shape (n,).
    """
    assert iou_mode in ('ciou', )
    assert bbox_format in ('xyxy', 'xywh')
    if bbox_format == 'xywh':
        # NOTE(review): support for extra leading dims in 'xywh' mode relies
        # on the helper converting along the last axis -- confirm.
        pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
        target = HorizontalBoxes.cxcywh_to_xyxy(target)

    # Index with ``...`` so any number of leading (batch) dims is accepted.
    b1_x1, b1_y1 = pred[..., 0], pred[..., 1]
    b1_x2, b1_y2 = pred[..., 2], pred[..., 3]
    b2_x1, b2_y1 = target[..., 0], target[..., 1]
    b2_x2, b2_y2 = target[..., 2], target[..., 3]

    # overlap
    lt = torch.max(pred[..., :2], target[..., :2])
    rb = torch.min(pred[..., 2:], target[..., 2:])
    wh = (rb - lt).clamp(min=0)
    overlap = wh[..., 0] * wh[..., 1]

    # union
    ap = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    ag = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    union = ap + ag - overlap + eps

    # IoU
    ious = overlap / union

    # smallest box enclosing both pred and target
    enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
    enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    cw = enclose_wh[..., 0]
    ch = enclose_wh[..., 1]

    # squared diagonal of the enclosing box
    c2 = cw**2 + ch**2 + eps

    # eps is added to the heights only, since they are the divisors of the
    # aspect-ratio term below
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # squared distance between the two box centres
    left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
    right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
    rho2 = left + right

    # aspect-ratio consistency term
    factor = 4 / math.pi**2
    v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

    # alpha is a weighting coefficient and is excluded from autograd
    with torch.no_grad():
        alpha = v / (v - ious + (1 + eps))

    # CIoU
    cious = ious - (rho2 / c2 + alpha * v)
    return cious.clamp(min=-1.0, max=1.0)
|
||||
|
||||
|
||||
@MODELS.register_module()
class IoULoss(nn.Module):
    """IoU loss.

    Computes ``1 - IoU`` between a set of predicted bboxes and target bboxes,
    where the IoU values come from ``bbox_overlaps``.

    Args:
        iou_mode (str): Options are "ciou".
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        eps (float): Small constant forwarded to ``bbox_overlaps``.
            Defaults to 1e-7.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean".
        loss_weight (float): Weight of loss. Defaults to 1.0.
        return_iou (bool): If True, return loss and iou. Defaults to True.
    """

    def __init__(self,
                 iou_mode: str = 'ciou',
                 bbox_format: str = 'xywh',
                 eps: float = 1e-7,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 return_iou: bool = True):
        super().__init__()
        assert bbox_format in ('xywh', 'xyxy')
        assert iou_mode in ('ciou', )
        self.iou_mode = iou_mode
        self.bbox_format = bbox_format
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_iou = return_iou

    def _compute_iou(self, pred: torch.Tensor,
                     target: torch.Tensor) -> torch.Tensor:
        """Evaluate ``bbox_overlaps`` with this module's configuration."""
        return bbox_overlaps(
            pred,
            target,
            iou_mode=self.iou_mode,
            bbox_format=self.bbox_format,
            eps=self.eps)

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        avg_factor: Optional[float] = None,
        reduction_override: Optional[str] = None
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
                or (x, y, w, h),shape (n, 4).
            target (Tensor): Corresponding gt bboxes, shape (n, 4).
            weight (Tensor, optional): Element-wise weights, either of shape
                (n,) or of the same shape as ``pred``.
            avg_factor (float, optional): Average factor when computing the
                mean of losses.
            reduction_override (str, optional): One of "none", "mean" or
                "sum"; replaces ``self.reduction`` for this call when given.
                Defaults to None.

        Returns:
            Tensor or tuple(Tensor, Tensor): ``(loss, iou)`` when
            ``self.return_iou`` is True, otherwise only ``loss``.
        """
        if weight is not None and not torch.any(weight > 0):
            # No positive weight: short-circuit, but still derive the loss
            # from ``pred`` rather than returning a detached constant.
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            loss = (pred * weight).sum()  # 0 when all weights are zero
            if self.return_iou:
                # Honour the (loss, iou) contract so callers that unpack the
                # result do not break on this early-exit path.
                return loss, self._compute_iou(pred, target)
            return loss
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)

        iou = self._compute_iou(pred, target)
        loss = self.loss_weight * weight_reduce_loss(1.0 - iou, weight,
                                                     reduction, avg_factor)

        if self.return_iou:
            return loss, iou
        else:
            return loss
|
Loading…
Reference in New Issue