mirror of https://github.com/alibaba/EasyCV.git
149 lines
3.9 KiB
Python
149 lines
3.9 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
# Adapt from https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/nms.py
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
def oks_iou(g, d, a_g, a_d, sigmas=None, vis_thr=None):
|
||
|
"""Calculate oks ious.
|
||
|
|
||
|
Args:
|
||
|
g: Ground truth keypoints.
|
||
|
d: Detected keypoints.
|
||
|
a_g: Area of the ground truth object.
|
||
|
a_d: Area of the detected object.
|
||
|
sigmas: standard deviation of keypoint labelling.
|
||
|
vis_thr: threshold of the keypoint visibility.
|
||
|
|
||
|
Returns:
|
||
|
list: The oks ious.
|
||
|
"""
|
||
|
if sigmas is None:
|
||
|
sigmas = np.array([
|
||
|
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07,
|
||
|
.87, .87, .89, .89
|
||
|
]) / 10.0
|
||
|
vars = (sigmas * 2)**2
|
||
|
xg = g[0::3]
|
||
|
yg = g[1::3]
|
||
|
vg = g[2::3]
|
||
|
ious = np.zeros(len(d), dtype=np.float32)
|
||
|
for n_d in range(0, len(d)):
|
||
|
xd = d[n_d, 0::3]
|
||
|
yd = d[n_d, 1::3]
|
||
|
vd = d[n_d, 2::3]
|
||
|
dx = xd - xg
|
||
|
dy = yd - yg
|
||
|
e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
|
||
|
if vis_thr is not None:
|
||
|
ind = list(vg > vis_thr) and list(vd > vis_thr)
|
||
|
e = e[ind]
|
||
|
ious[n_d] = np.sum(np.exp(-e)) / len(e) if len(e) != 0 else 0.0
|
||
|
return ious
|
||
|
|
||
|
|
||
|
def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None):
|
||
|
"""OKS NMS implementations.
|
||
|
|
||
|
Args:
|
||
|
kpts_db: keypoints.
|
||
|
thr: Retain overlap < thr.
|
||
|
sigmas: standard deviation of keypoint labelling.
|
||
|
vis_thr: threshold of the keypoint visibility.
|
||
|
|
||
|
Returns:
|
||
|
np.ndarray: indexes to keep.
|
||
|
"""
|
||
|
if len(kpts_db) == 0:
|
||
|
return []
|
||
|
|
||
|
scores = np.array([k['score'] for k in kpts_db])
|
||
|
kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
|
||
|
areas = np.array([k['area'] for k in kpts_db])
|
||
|
|
||
|
order = scores.argsort()[::-1]
|
||
|
|
||
|
keep = []
|
||
|
while len(order) > 0:
|
||
|
i = order[0]
|
||
|
keep.append(i)
|
||
|
|
||
|
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
|
||
|
sigmas, vis_thr)
|
||
|
|
||
|
inds = np.where(oks_ovr <= thr)[0]
|
||
|
order = order[inds + 1]
|
||
|
|
||
|
keep = np.array(keep)
|
||
|
|
||
|
return keep
|
||
|
|
||
|
|
||
|
def _rescore(overlap, scores, thr, type='gaussian'):
|
||
|
"""Rescoring mechanism gaussian or linear.
|
||
|
|
||
|
Args:
|
||
|
overlap: calculated ious
|
||
|
scores: target scores.
|
||
|
thr: retain oks overlap < thr.
|
||
|
type: 'gaussian' or 'linear'
|
||
|
|
||
|
Returns:
|
||
|
np.ndarray: indexes to keep
|
||
|
"""
|
||
|
assert len(overlap) == len(scores)
|
||
|
assert type in ['gaussian', 'linear']
|
||
|
|
||
|
if type == 'linear':
|
||
|
inds = np.where(overlap >= thr)[0]
|
||
|
scores[inds] = scores[inds] * (1 - overlap[inds])
|
||
|
else:
|
||
|
scores = scores * np.exp(-overlap**2 / thr)
|
||
|
|
||
|
return scores
|
||
|
|
||
|
|
||
|
def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None):
|
||
|
"""Soft OKS NMS implementations.
|
||
|
|
||
|
Args:
|
||
|
kpts_db
|
||
|
thr: retain oks overlap < thr.
|
||
|
max_dets: max number of detections to keep.
|
||
|
sigmas: Keypoint labelling uncertainty.
|
||
|
|
||
|
Returns:
|
||
|
np.ndarray: indexes to keep.
|
||
|
"""
|
||
|
if len(kpts_db) == 0:
|
||
|
return []
|
||
|
|
||
|
scores = np.array([k['score'] for k in kpts_db])
|
||
|
kpts = np.array([k['keypoints'].flatten() for k in kpts_db])
|
||
|
areas = np.array([k['area'] for k in kpts_db])
|
||
|
|
||
|
order = scores.argsort()[::-1]
|
||
|
scores = scores[order]
|
||
|
|
||
|
keep = np.zeros(max_dets, dtype=np.intp)
|
||
|
keep_cnt = 0
|
||
|
while len(order) > 0 and keep_cnt < max_dets:
|
||
|
i = order[0]
|
||
|
|
||
|
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
|
||
|
sigmas, vis_thr)
|
||
|
|
||
|
order = order[1:]
|
||
|
scores = _rescore(oks_ovr, scores[1:], thr)
|
||
|
|
||
|
tmp = scores.argsort()[::-1]
|
||
|
order = order[tmp]
|
||
|
scores = scores[tmp]
|
||
|
|
||
|
keep[keep_cnt] = i
|
||
|
keep_cnt += 1
|
||
|
|
||
|
keep = keep[:keep_cnt]
|
||
|
|
||
|
return keep
|