mirror of https://github.com/FoundationVision/GLEE
547 lines
24 KiB
Python
547 lines
24 KiB
Python
__author__ = 'tsungyi'
|
|
|
|
import numpy as np
|
|
import datetime
|
|
import time
|
|
from collections import defaultdict
|
|
from pycocotools import mask as maskUtils
|
|
from pycocotools.mask import decode
|
|
import copy
|
|
|
|
import torch
|
|
from torch._C import InterfaceType
|
|
from torchvision.ops.boxes import box_area
|
|
|
|
def compute_bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor):
|
|
# both boxes: xyxy
|
|
area1 = box_area(boxes1)
|
|
area2 = box_area(boxes2)
|
|
|
|
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
|
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
|
|
|
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
|
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
|
|
|
union = area1[:, None] + area2 - inter
|
|
|
|
iou = (inter+1e-6) / (union+1e-6)
|
|
return iou, inter, union
|
|
|
|
def compute_mask_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6):
|
|
outputs = outputs.int()
|
|
intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0
|
|
union = (outputs | labels).float().sum((1, 2)) # Will be zero if both are 0
|
|
iou = (intersection + EPS) / (union + EPS) # EPS is used to avoid division by zero
|
|
return iou, intersection, union
|
|
|
|
|
|
class RefCOCOeval:
|
|
# Interface for evaluating detection on the Microsoft COCO dataset.
|
|
#
|
|
# The usage for CocoEval is as follows:
|
|
# cocoGt=..., cocoDt=... # load dataset and results
|
|
# E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
|
|
# E.params.recThrs = ...; # set parameters as desired
|
|
# E.evaluate(); # run per image evaluation
|
|
# E.accumulate(); # accumulate per image results
|
|
# E.summarize(); # display summary metrics of results
|
|
# For example usage see evalDemo.m and http://mscoco.org/.
|
|
#
|
|
# The evaluation parameters are as follows (defaults in brackets):
|
|
# imgIds - [all] N img ids to use for evaluation
|
|
# catIds - [all] K cat ids to use for evaluation
|
|
# iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
|
|
# recThrs - [0:.01:1] R=101 recall thresholds for evaluation
|
|
# areaRng - [...] A=4 object area ranges for evaluation
|
|
# maxDets - [1 10 100] M=3 thresholds on max detections per image
|
|
# iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
|
|
# iouType replaced the now DEPRECATED useSegm parameter.
|
|
# useCats - [1] if true use category labels for evaluation
|
|
# Note: if useCats=0 category labels are ignored as in proposal scoring.
|
|
# Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
|
|
#
|
|
# evaluate(): evaluates detections on every image and every category and
|
|
# concats the results into the "evalImgs" with fields:
|
|
# dtIds - [1xD] id for each of the D detections (dt)
|
|
# gtIds - [1xG] id for each of the G ground truths (gt)
|
|
# dtMatches - [TxD] matching gt id at each IoU or 0
|
|
# gtMatches - [TxG] matching dt id at each IoU or 0
|
|
# dtScores - [1xD] confidence of each dt
|
|
# gtIgnore - [1xG] ignore flag for each gt
|
|
# dtIgnore - [TxD] ignore flag for each dt at each IoU
|
|
#
|
|
# accumulate(): accumulates the per-image, per-category evaluation
|
|
# results in "evalImgs" into the dictionary "eval" with fields:
|
|
# params - parameters used for evaluation
|
|
# date - date evaluation was performed
|
|
# counts - [T,R,K,A,M] parameter dimensions (see above)
|
|
# precision - [TxRxKxAxM] precision for every evaluation setting
|
|
# recall - [TxKxAxM] max recall for every evaluation setting
|
|
# Note: precision and recall==-1 for settings with no gt objects.
|
|
#
|
|
# See also coco, mask, pycocoDemo, pycocoEvalDemo
|
|
#
|
|
# Microsoft COCO Toolbox. version 2.0
|
|
# Data, paper, and tutorials available at: http://mscoco.org/
|
|
# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
|
|
# Licensed under the Simplified BSD License [see coco/license.txt]
|
|
def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
|
|
'''
|
|
Initialize CocoEval using coco APIs for gt and dt
|
|
:param cocoGt: coco object with ground truth annotations
|
|
:param cocoDt: coco object with detection results
|
|
:return: None
|
|
'''
|
|
if not iouType:
|
|
print('iouType not specified. use default iouType segm')
|
|
self.cocoGt = cocoGt # ground truth COCO API
|
|
self.cocoDt = cocoDt # detections COCO API
|
|
self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
|
|
self.eval = {} # accumulated evaluation results
|
|
self._gts = defaultdict(list) # gt for evaluation
|
|
self._dts = defaultdict(list) # dt for evaluation
|
|
self.params = Params(iouType=iouType) # parameters
|
|
self._paramsEval = {} # parameters for evaluation
|
|
self.stats = [] # result summarization
|
|
self.ious = {} # ious between all gts and dts
|
|
# for computing overall iou
|
|
self.total_intersection_area = 0
|
|
self.total_union_area = 0
|
|
self.iou_list = []
|
|
if not cocoGt is None:
|
|
self.params.imgIds = sorted(cocoGt.getImgIds())
|
|
self.params.catIds = sorted(cocoGt.getCatIds())
|
|
|
|
|
|
def _prepare(self):
|
|
'''
|
|
Prepare ._gts and ._dts for evaluation based on params
|
|
:return: None
|
|
'''
|
|
def _toMask(anns, coco):
|
|
# modify ann['segmentation'] by reference
|
|
for ann in anns:
|
|
rle = coco.annToRLE(ann)
|
|
ann['segmentation'] = rle
|
|
p = self.params
|
|
if p.useCats:
|
|
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
|
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
|
|
else:
|
|
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
|
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
|
|
|
# convert ground truth to mask if iouType == 'segm'
|
|
if p.iouType == 'segm':
|
|
_toMask(gts, self.cocoGt)
|
|
_toMask(dts, self.cocoDt)
|
|
# set ignore flag
|
|
for gt in gts:
|
|
gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
|
|
gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
|
|
if p.iouType == 'keypoints':
|
|
gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
|
|
self._gts = defaultdict(list) # gt for evaluation
|
|
self._dts = defaultdict(list) # dt for evaluation
|
|
for gt in gts:
|
|
self._gts[gt['image_id'], gt['category_id']].append(gt)
|
|
for dt in dts:
|
|
self._dts[dt['image_id'], dt['category_id']].append(dt)
|
|
self.evalImgs = defaultdict(list) # per-image per-category evaluation results
|
|
self.eval = {} # accumulated evaluation results
|
|
|
|
def evaluate(self):
|
|
'''
|
|
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
|
|
:return: None
|
|
'''
|
|
tic = time.time()
|
|
print('Running per image evaluation...')
|
|
p = self.params
|
|
# add backward compatibility if useSegm is specified in params
|
|
if not p.useSegm is None:
|
|
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
|
|
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
|
|
print('Evaluate annotation type *{}*'.format(p.iouType))
|
|
p.imgIds = list(np.unique(p.imgIds))
|
|
if p.useCats:
|
|
p.catIds = list(np.unique(p.catIds))
|
|
p.maxDets = sorted(p.maxDets)
|
|
self.params=p
|
|
|
|
self._prepare()
|
|
# loop through images, area range, max detection number
|
|
catIds = p.catIds if p.useCats else [-1]
|
|
|
|
if p.iouType == 'segm' or p.iouType == 'bbox':
|
|
computeIoU = self.computeIoU
|
|
elif p.iouType == 'keypoints':
|
|
computeIoU = self.computeOks
|
|
self.ious = {(imgId, catId): computeIoU(imgId, catId) \
|
|
for imgId in p.imgIds
|
|
for catId in catIds}
|
|
# evaluateImg = self.evaluateImg
|
|
# maxDet = p.maxDets[-1]
|
|
# self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
|
|
# for catId in catIds
|
|
# for areaRng in p.areaRng
|
|
# for imgId in p.imgIds
|
|
# ]
|
|
# self._paramsEval = copy.deepcopy(self.params)
|
|
toc = time.time()
|
|
print('DONE (t={:0.2f}s).'.format(toc-tic))
|
|
|
|
def computeIoU(self, imgId, catId):
|
|
p = self.params
|
|
if p.useCats:
|
|
gt = self._gts[imgId,catId]
|
|
dt = self._dts[imgId,catId]
|
|
else:
|
|
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
|
|
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
|
|
if len(gt) == 0 and len(dt) ==0:
|
|
return []
|
|
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
|
dt = [dt[i] for i in inds]
|
|
if len(dt) > p.maxDets[-1]:
|
|
dt=dt[0:p.maxDets[-1]]
|
|
|
|
if p.iouType == 'segm':
|
|
g = [g['segmentation'] for g in gt]
|
|
d = [d['segmentation'] for d in dt]
|
|
elif p.iouType == 'bbox':
|
|
g = [g['bbox'] for g in gt]
|
|
d = [d['bbox'] for d in dt]
|
|
else:
|
|
raise Exception('unknown iouType for iou computation')
|
|
|
|
# compute iou between each dt and gt region
|
|
iscrowd = [int(o['iscrowd']) for o in gt]
|
|
ious = maskUtils.iou(d,g,iscrowd)
|
|
|
|
# for computing overall iou
|
|
# there is only one bbox and segm
|
|
if p.iouType == 'bbox':
|
|
g, d = g[0], d[0]
|
|
g_bbox = [g[0], g[1], g[2] + g[0], g[3] + g[1]] # x1y1wh -> x1y1x2y2
|
|
d_bbox = [d[0], d[1], d[2] + d[0], d[3] + d[1]] # x1y1wh -> x1y1x2y2
|
|
g_bbox = torch.tensor(g_bbox).unsqueeze(0)
|
|
d_bbox = torch.tensor(d_bbox).unsqueeze(0)
|
|
iou, intersection, union = compute_bbox_iou(d_bbox, g_bbox)
|
|
elif p.iouType == 'segm':
|
|
g_segm = decode(g[0])
|
|
d_segm = decode(d[0])
|
|
g_segm = torch.tensor(g_segm).unsqueeze(0)
|
|
d_segm = torch.tensor(d_segm).unsqueeze(0)
|
|
iou, intersection, union = compute_mask_iou(d_segm, g_segm)
|
|
else:
|
|
raise Exception('unknown iouType for iou computation')
|
|
iou, intersection, union = iou.item(), intersection.item(), union.item()
|
|
self.total_intersection_area += intersection
|
|
self.total_union_area += union
|
|
self.iou_list.append(iou)
|
|
return ious
|
|
|
|
|
|
def evaluateImg(self, imgId, catId, aRng, maxDet):
|
|
'''
|
|
perform evaluation for single category and image
|
|
:return: dict (single image results)
|
|
'''
|
|
p = self.params
|
|
if p.useCats:
|
|
gt = self._gts[imgId,catId]
|
|
dt = self._dts[imgId,catId]
|
|
else:
|
|
gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
|
|
dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
|
|
if len(gt) == 0 and len(dt) ==0:
|
|
return None
|
|
|
|
for g in gt:
|
|
if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]):
|
|
g['_ignore'] = 1
|
|
else:
|
|
g['_ignore'] = 0
|
|
|
|
# sort dt highest score first, sort gt ignore last
|
|
gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
|
|
gt = [gt[i] for i in gtind]
|
|
dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
|
|
dt = [dt[i] for i in dtind[0:maxDet]]
|
|
iscrowd = [int(o['iscrowd']) for o in gt]
|
|
# load computed ious
|
|
ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
|
|
|
|
T = len(p.iouThrs)
|
|
G = len(gt)
|
|
D = len(dt)
|
|
gtm = np.zeros((T,G))
|
|
dtm = np.zeros((T,D))
|
|
gtIg = np.array([g['_ignore'] for g in gt])
|
|
dtIg = np.zeros((T,D))
|
|
if not len(ious)==0:
|
|
for tind, t in enumerate(p.iouThrs):
|
|
for dind, d in enumerate(dt):
|
|
# information about best match so far (m=-1 -> unmatched)
|
|
iou = min([t,1-1e-10])
|
|
m = -1
|
|
for gind, g in enumerate(gt):
|
|
# if this gt already matched, and not a crowd, continue
|
|
if gtm[tind,gind]>0 and not iscrowd[gind]:
|
|
continue
|
|
# if dt matched to reg gt, and on ignore gt, stop
|
|
if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
|
|
break
|
|
# continue to next gt unless better match made
|
|
if ious[dind,gind] < iou:
|
|
continue
|
|
# if match successful and best so far, store appropriately
|
|
iou=ious[dind,gind]
|
|
m=gind
|
|
# if match made store id of match for both dt and gt
|
|
if m ==-1:
|
|
continue
|
|
dtIg[tind,dind] = gtIg[m]
|
|
dtm[tind,dind] = gt[m]['id']
|
|
gtm[tind,m] = d['id']
|
|
# set unmatched detections outside of area range to ignore
|
|
a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
|
|
dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
|
|
# store results for given image and category
|
|
return {
|
|
'image_id': imgId,
|
|
'category_id': catId,
|
|
'aRng': aRng,
|
|
'maxDet': maxDet,
|
|
'dtIds': [d['id'] for d in dt],
|
|
'gtIds': [g['id'] for g in gt],
|
|
'dtMatches': dtm,
|
|
'gtMatches': gtm,
|
|
'dtScores': [d['score'] for d in dt],
|
|
'gtIgnore': gtIg,
|
|
'dtIgnore': dtIg,
|
|
}
|
|
|
|
def accumulate(self, p = None):
|
|
'''
|
|
Accumulate per image evaluation results and store the result in self.eval
|
|
:param p: input params for evaluation
|
|
:return: None
|
|
'''
|
|
print('Accumulating evaluation results...')
|
|
tic = time.time()
|
|
if not self.evalImgs:
|
|
print('Please run evaluate() first')
|
|
# allows input customized parameters
|
|
if p is None:
|
|
p = self.params
|
|
p.catIds = p.catIds if p.useCats == 1 else [-1]
|
|
T = len(p.iouThrs)
|
|
R = len(p.recThrs)
|
|
K = len(p.catIds) if p.useCats else 1
|
|
A = len(p.areaRng)
|
|
M = len(p.maxDets)
|
|
precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
|
|
recall = -np.ones((T,K,A,M))
|
|
scores = -np.ones((T,R,K,A,M))
|
|
|
|
# create dictionary for future indexing
|
|
_pe = self._paramsEval
|
|
catIds = _pe.catIds if _pe.useCats else [-1]
|
|
setK = set(catIds)
|
|
setA = set(map(tuple, _pe.areaRng))
|
|
setM = set(_pe.maxDets)
|
|
setI = set(_pe.imgIds)
|
|
# get inds to evaluate
|
|
k_list = [n for n, k in enumerate(p.catIds) if k in setK]
|
|
m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
|
|
a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
|
|
i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
|
|
I0 = len(_pe.imgIds)
|
|
A0 = len(_pe.areaRng)
|
|
# retrieve E at each category, area range, and max number of detections
|
|
for k, k0 in enumerate(k_list):
|
|
Nk = k0*A0*I0
|
|
for a, a0 in enumerate(a_list):
|
|
Na = a0*I0
|
|
for m, maxDet in enumerate(m_list):
|
|
E = [self.evalImgs[Nk + Na + i] for i in i_list]
|
|
E = [e for e in E if not e is None]
|
|
if len(E) == 0:
|
|
continue
|
|
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
|
|
|
|
# different sorting method generates slightly different results.
|
|
# mergesort is used to be consistent as Matlab implementation.
|
|
inds = np.argsort(-dtScores, kind='mergesort')
|
|
dtScoresSorted = dtScores[inds]
|
|
|
|
dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
|
dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
|
|
gtIg = np.concatenate([e['gtIgnore'] for e in E])
|
|
npig = np.count_nonzero(gtIg==0 )
|
|
if npig == 0:
|
|
continue
|
|
tps = np.logical_and( dtm, np.logical_not(dtIg) )
|
|
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
|
|
|
|
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
|
|
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
|
|
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
|
|
tp = np.array(tp)
|
|
fp = np.array(fp)
|
|
nd = len(tp)
|
|
rc = tp / npig
|
|
pr = tp / (fp+tp+np.spacing(1))
|
|
q = np.zeros((R,))
|
|
ss = np.zeros((R,))
|
|
|
|
if nd:
|
|
recall[t,k,a,m] = rc[-1]
|
|
else:
|
|
recall[t,k,a,m] = 0
|
|
|
|
# numpy is slow without cython optimization for accessing elements
|
|
# use python array gets significant speed improvement
|
|
pr = pr.tolist(); q = q.tolist()
|
|
|
|
for i in range(nd-1, 0, -1):
|
|
if pr[i] > pr[i-1]:
|
|
pr[i-1] = pr[i]
|
|
|
|
inds = np.searchsorted(rc, p.recThrs, side='left')
|
|
try:
|
|
for ri, pi in enumerate(inds):
|
|
q[ri] = pr[pi]
|
|
ss[ri] = dtScoresSorted[pi]
|
|
except:
|
|
pass
|
|
precision[t,:,k,a,m] = np.array(q)
|
|
scores[t,:,k,a,m] = np.array(ss)
|
|
self.eval = {
|
|
'params': p,
|
|
'counts': [T, R, K, A, M],
|
|
'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
|
|
'precision': precision,
|
|
'recall': recall,
|
|
'scores': scores,
|
|
}
|
|
toc = time.time()
|
|
print('DONE (t={:0.2f}s).'.format( toc-tic))
|
|
|
|
def summarize(self):
|
|
'''
|
|
Compute and display summary metrics for evaluation results.
|
|
Note this functin can *only* be applied on the default parameter setting
|
|
'''
|
|
def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
|
|
p = self.params
|
|
iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
|
|
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
|
|
typeStr = '(AP)' if ap==1 else '(AR)'
|
|
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
|
|
if iouThr is None else '{:0.2f}'.format(iouThr)
|
|
|
|
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
|
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
|
if ap == 1:
|
|
# dimension of precision: [TxRxKxAxM]
|
|
s = self.eval['precision']
|
|
# IoU
|
|
if iouThr is not None:
|
|
t = np.where(iouThr == p.iouThrs)[0]
|
|
s = s[t]
|
|
s = s[:,:,:,aind,mind]
|
|
else:
|
|
# dimension of recall: [TxKxAxM]
|
|
s = self.eval['recall']
|
|
if iouThr is not None:
|
|
t = np.where(iouThr == p.iouThrs)[0]
|
|
s = s[t]
|
|
s = s[:,:,aind,mind]
|
|
if len(s[s>-1])==0:
|
|
mean_s = -1
|
|
else:
|
|
mean_s = np.mean(s[s>-1])
|
|
print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
|
|
return mean_s
|
|
def _summarizeDets():
|
|
stats = np.zeros((12,))
|
|
stats[0] = _summarize(1)
|
|
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
|
|
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
|
|
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
|
|
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
|
|
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
|
|
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
|
|
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
|
|
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
|
|
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
|
|
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
|
|
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
|
|
return stats
|
|
def _summarizeKps():
|
|
stats = np.zeros((10,))
|
|
stats[0] = _summarize(1, maxDets=20)
|
|
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
|
|
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
|
|
stats[3] = _summarize(1, maxDets=20, areaRng='medium')
|
|
stats[4] = _summarize(1, maxDets=20, areaRng='large')
|
|
stats[5] = _summarize(0, maxDets=20)
|
|
stats[6] = _summarize(0, maxDets=20, iouThr=.5)
|
|
stats[7] = _summarize(0, maxDets=20, iouThr=.75)
|
|
stats[8] = _summarize(0, maxDets=20, areaRng='medium')
|
|
stats[9] = _summarize(0, maxDets=20, areaRng='large')
|
|
return stats
|
|
if not self.eval:
|
|
raise Exception('Please run accumulate() first')
|
|
iouType = self.params.iouType
|
|
if iouType == 'segm' or iouType == 'bbox':
|
|
summarize = _summarizeDets
|
|
elif iouType == 'keypoints':
|
|
summarize = _summarizeKps
|
|
self.stats = summarize()
|
|
|
|
def __str__(self):
|
|
self.summarize()
|
|
|
|
class Params:
|
|
'''
|
|
Params for coco evaluation api
|
|
'''
|
|
def setDetParams(self):
|
|
self.imgIds = []
|
|
self.catIds = []
|
|
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
|
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
|
|
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
|
|
self.maxDets = [1, 10, 100]
|
|
self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
|
|
self.areaRngLbl = ['all', 'small', 'medium', 'large']
|
|
self.useCats = 1
|
|
|
|
def setKpParams(self):
|
|
self.imgIds = []
|
|
self.catIds = []
|
|
# np.arange causes trouble. the data point on arange is slightly larger than the true value
|
|
self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
|
|
self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
|
|
self.maxDets = [20]
|
|
self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
|
|
self.areaRngLbl = ['all', 'medium', 'large']
|
|
self.useCats = 1
|
|
self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
|
|
|
|
def __init__(self, iouType='segm'):
|
|
if iouType == 'segm' or iouType == 'bbox':
|
|
self.setDetParams()
|
|
elif iouType == 'keypoints':
|
|
self.setKpParams()
|
|
else:
|
|
raise Exception('iouType not supported')
|
|
self.iouType = iouType
|
|
# useSegm is deprecated
|
|
self.useSegm = None
|