fast-reid/fastreid/evaluation/clas_evaluator.py


# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import copy
import itertools
import logging
from collections import OrderedDict

import torch

from fastreid.utils import comm

from .evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the top-maxk classes per sample; transpose to (maxk, batch_size)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # correct[i, j] is True if the (i+1)-th ranked prediction for sample j is its label
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample counts as correct@k if its label appears anywhere in the top k
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
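
# Minimal usage sketch for accuracy() (illustrative values only, not from fastreid):
# three samples over four classes, two of which are predicted correctly at top-1.
#
#   logits = torch.tensor([[0.1, 0.9, 0.0, 0.0],
#                          [0.8, 0.1, 0.1, 0.0],
#                          [0.2, 0.2, 0.5, 0.1]])
#   target = torch.tensor([1, 0, 3])
#   acc1, = accuracy(logits, target, topk=(1,))  # -> tensor([66.6667])
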
class ClasEvaluator(DatasetEvaluator):
    def __init__(self, cfg, output_dir=None):
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')

        self._predictions = []

    def reset(self):
        self._predictions = []
def process(self, inputs, outputs):
predictions = {
"logits": outputs.to(self._cpu_device, torch.float32),
"labels": inputs["targets"],
}
self._predictions.append(predictions)
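        # One stored entry per batch; for a batch of B samples over C classes this is
        # roughly {"logits": FloatTensor[B, C], "labels": LongTensor[B]} (shapes depend
        # on the model head feeding this evaluator, so treat them as a sketch).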
    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            # Gather the per-process prediction lists on rank 0 and flatten them
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            # Non-main processes return early with empty results
            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions
        pred_logits = []
        labels = []
        for prediction in predictions:
            pred_logits.append(prediction['logits'])
            labels.append(prediction['labels'])
        pred_logits = torch.cat(pred_logits, dim=0)
        labels = torch.cat(labels, dim=0)

        # measure top-1 accuracy over the full evaluation set
        acc1, = accuracy(pred_logits, labels, topk=(1,))

        self._results = OrderedDict()
        self._results["Acc@1"] = acc1
        self._results["metric"] = acc1

        return copy.deepcopy(self._results)
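
if __name__ == "__main__":
    # Smoke-test sketch, not part of fastreid's public API: `cfg=None` only works
    # because ClasEvaluator never dereferences cfg in this file; a real run would
    # pass a fastreid CfgNode and genuine model outputs. Because of the relative
    # import above, run it as a module from the repo root, e.g.
    # `python -m fastreid.evaluation.clas_evaluator`.
    evaluator = ClasEvaluator(cfg=None)
    evaluator.reset()
    for _ in range(2):  # pretend we ran two evaluation batches
        logits = torch.randn(8, 10)           # fake model outputs: 8 samples, 10 classes
        targets = torch.randint(0, 10, (8,))  # fake ground-truth labels
        evaluator.process({"targets": targets}, logits)
    print(evaluator.evaluate())               # OrderedDict([('Acc@1', ...), ('metric', ...)])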