`bbox_iou()` optimizations (#10296)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/10298/head
parent
bfa1f23045
commit
31c1f11186
|
@ -234,12 +234,12 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|||
else: # x1, y1, x2, y2 = box1
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
|
||||
w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
|
||||
|
||||
# Intersection area
|
||||
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
||||
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
|
||||
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
|
||||
|
||||
# Union Area
|
||||
union = w1 * h1 + w2 * h2 - inter + eps
|
||||
|
@ -247,13 +247,13 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|||
# IoU
|
||||
iou = inter / union
|
||||
if CIoU or DIoU or GIoU:
|
||||
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
|
||||
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
||||
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
|
||||
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
|
||||
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
||||
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
|
||||
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
|
||||
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
||||
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - iou + (1 + eps))
|
||||
return iou - (rho2 / c2 + v * alpha) # CIoU
|
||||
|
|
Loading…
Reference in New Issue