# Mirror of https://github.com/alibaba/EasyCV.git (Python, 52 lines, 1.7 KiB).
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ..registry import LOSSES
|
|
|
|
|
|
@LOSSES.register_module
class IOUloss(nn.Module):
    """IoU-based regression loss for axis-aligned boxes in (cx, cy, w, h) form.

    Args:
        reduction (str): ``'mean'`` averages the per-box losses and ``'sum'``
            sums them. Any other value (default ``'none'``) returns the
            per-box loss tensor unreduced.
        loss_type (str): ``'iou'`` computes ``1 - IoU ** 2`` (note that the
            IoU is squared). ``'giou'`` computes ``1 - GIoU``, with GIoU
            clamped to ``[-1, 1]``.
    """

    def __init__(self, reduction='none', loss_type='iou'):
        super().__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Compute the loss between predicted and target boxes.

        Args:
            pred (Tensor): predicted boxes, reshaped to ``(-1, 4)`` and read
                as (cx, cy, w, h) per row.
            target (Tensor): target boxes with the same leading dimension as
                ``pred``. Cast to ``pred.dtype`` if the dtypes differ, then
                reshaped to ``(-1, 4)``.

        Returns:
            Tensor: the per-box loss of shape ``(N,)`` when ``reduction`` is
            neither ``'mean'`` nor ``'sum'``, otherwise a scalar.

        Raises:
            ValueError: if ``self.loss_type`` is not ``'iou'`` or ``'giou'``.
        """
        assert pred.shape[0] == target.shape[0]

        if target.dtype != pred.dtype:
            target = target.to(pred.dtype)

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Convert centre/size to top-left and bottom-right corners once.
        # They are reused by both the intersection and the GIoU enclosing box.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Corners of the intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # 1 when the boxes overlap on both axes and 0 otherwise. This zeroes
        # the intersection for disjoint boxes, where br - tl may be negative.
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        # The small epsilon guards against division by zero for degenerate boxes.
        iou = area_i / (area_p + area_g - area_i + 1e-16)

        if self.loss_type == 'iou':
            # Squared IoU, kept exactly as originally implemented.
            loss = 1 - iou**2
        elif self.loss_type == 'giou':
            # Smallest box enclosing both pred and target.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `loss`.
            raise ValueError(
                "Unsupported loss_type {!r}; expected 'iou' or 'giou'".format(
                    self.loss_type))

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss
|