# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion

import torch
import torchvision

__all__ = ['bboxes_iou', 'postprocess']


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise IoU matrix between two sets of boxes.

    Boxes are given as (x1, y1, x2, y2) corners when ``xyxy=True`` or as
    (cx, cy, w, h) centers and sizes when ``xyxy=False``.
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    if xyxy:
        # Top-left and bottom-right corners of the pairwise intersections.
        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        tl = torch.max(
            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
        )
        br = torch.min(
            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
        )

        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)
    # An intersection is valid only where tl < br in both dimensions.
    en = (tl < br).type(tl.type()).prod(dim=2)
    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
    return area_i / (area_a[:, None] + area_b - area_i)


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Filter raw detection outputs by confidence and apply NMS per image.

    ``prediction`` has shape (batch, num_boxes, 5 + num_classes), each box
    encoded as (cx, cy, w, h, obj_conf, cls_conf_0, ..., cls_conf_{C-1}).
    Returns a list with one tensor (or None) per image, each row ordered as
    (x1, y1, x2, y2, obj_conf, class_conf, class_pred).
    """
    # Convert boxes from (cx, cy, w, h) to corner format (x1, y1, x2, y2) in place.
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Keep boxes whose objectness * class confidence clears the threshold.
        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
                     conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        # Class-aware NMS is available from torchvision 0.8.0 onwards.
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                detections[:, 6], nms_thre)
        else:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                nms_thre)

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output
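

# A minimal sanity-check sketch, not part of the upstream EasyCV module: the
# shapes, box values, and thresholds below are illustrative assumptions chosen
# only to exercise bboxes_iou and postprocess end to end.
if __name__ == '__main__':
    boxes_a = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
    boxes_b = torch.tensor([[0., 0., 10., 10.]])
    # Pairwise IoU matrix of shape (2, 1).
    print(bboxes_iou(boxes_a, boxes_b))

    # Fake raw head output: (batch, num_boxes, 5 + num_classes) in the
    # (cx, cy, w, h, obj_conf, per-class conf) layout that postprocess expects.
    num_classes = 3
    pred = torch.rand(1, 100, 5 + num_classes)
    out = postprocess(pred, num_classes, conf_thre=0.01, nms_thre=0.45)
    print(None if out[0] is None else out[0].shape)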