mirror of https://github.com/JDAI-CV/fast-reid.git
85 lines
2.2 KiB
Python
85 lines
2.2 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import copy
|
|
import itertools
|
|
import logging
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
|
|
from fastreid.utils import comm
|
|
from .evaluator import DatasetEvaluator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
|
|
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
|
with torch.no_grad():
|
|
maxk = max(topk)
|
|
batch_size = target.size(0)
|
|
|
|
_, pred = output.topk(maxk, 1, True, True)
|
|
pred = pred.t()
|
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
|
|
|
res = []
|
|
for k in topk:
|
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
|
res.append(correct_k.mul_(100.0 / batch_size))
|
|
return res
|
|
|
|
|
|
class ClasEvaluator(DatasetEvaluator):
|
|
def __init__(self, cfg, output_dir=None):
|
|
self.cfg = cfg
|
|
self._output_dir = output_dir
|
|
|
|
self._cpu_device = torch.device('cpu')
|
|
|
|
self._predictions = []
|
|
|
|
def reset(self):
|
|
self._predictions = []
|
|
|
|
def process(self, inputs, outputs):
|
|
predictions = {
|
|
"logits": outputs.to(self._cpu_device, torch.float32),
|
|
"labels": inputs["targets"],
|
|
}
|
|
self._predictions.append(predictions)
|
|
|
|
def evaluate(self):
|
|
if comm.get_world_size() > 1:
|
|
comm.synchronize()
|
|
predictions = comm.gather(self._predictions, dst=0)
|
|
predictions = list(itertools.chain(*predictions))
|
|
|
|
if not comm.is_main_process(): return {}
|
|
|
|
else:
|
|
predictions = self._predictions
|
|
|
|
pred_logits = []
|
|
labels = []
|
|
for prediction in predictions:
|
|
pred_logits.append(prediction['logits'])
|
|
labels.append(prediction['labels'])
|
|
|
|
pred_logits = torch.cat(pred_logits, dim=0)
|
|
labels = torch.cat(labels, dim=0)
|
|
|
|
# measure accuracy and record loss
|
|
acc1, = accuracy(pred_logits, labels, topk=(1,))
|
|
|
|
self._results = OrderedDict()
|
|
self._results["Acc@1"] = acc1
|
|
|
|
self._results["metric"] = acc1
|
|
|
|
return copy.deepcopy(self._results)
|