fast-reid/utils/validation_metrics.py

26 lines
620 B
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
def accuracy(score, target, topk=(1,)):
maxk = max(topk)
batch_size = target.size(0)
_, pred = score.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
ret = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / batch_size))
return ret