fast-reid/bases/base_evaluator.py

76 lines
2.3 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import torch
class BaseEvaluator(object):
    """Template for a query-vs-gallery retrieval evaluation.

    ``evaluate`` drives the whole procedure: it extracts one feature row per
    sample from both loaders, builds the pairwise squared Euclidean distance
    matrix and hands it to ``eval_func``. Subclasses supply the three hooks
    ``_parse_data``, ``_forward`` and ``eval_func``.
    """

    def __init__(self, model):
        """Store the model that ``evaluate`` switches to eval mode."""
        self.model = model

    def evaluate(self, queryloader, galleryloader, ranks=(1, 5, 10, 20)):
        """Evaluate the model and print mAP and the CMC curve.

        Args:
            queryloader: iterable of batches for the query set; each batch is
                passed unchanged to ``_parse_data``.
            galleryloader: iterable of batches for the gallery set.
            ranks: 1-based ranks of the CMC curve to print. Every rank ``r``
                must satisfy ``r - 1 < len(cmc)`` for the ``cmc`` returned by
                ``eval_func``.

        Returns:
            ``cmc[0]``, i.e. the first entry of the CMC curve (rank-1).
        """
        self.model.eval()

        qf, q_pids, q_camids = self._extract_features(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = self._extract_features(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

        print("Computing distance matrix")
        # .cpu() is a no-op for CPU tensors and makes .numpy() valid for
        # features that were left on an accelerator by ``_forward``.
        distmat = self._pairwise_sq_dist(qf, gf).cpu().numpy()

        print("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in ranks:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")

        return cmc[0]

    def _extract_features(self, loader):
        """Run every batch of ``loader`` through the model hooks.

        Returns:
            A ``(features, pids, camids)`` tuple: the per-batch features
            concatenated along dim 0, and the per-sample ``pids`` / ``camids``
            collected into NumPy arrays.
        """
        features, pids, camids = [], [], []
        # No autograd graph is needed for evaluation; this also guarantees the
        # resulting tensors can be converted with .numpy() later on.
        with torch.no_grad():
            for batch in loader:
                inputs, batch_pids, batch_camids = self._parse_data(batch)
                features.append(self._forward(inputs))
                pids.extend(batch_pids)
                camids.extend(batch_camids)
        return torch.cat(features, 0), np.asarray(pids), np.asarray(camids)

    @staticmethod
    def _pairwise_sq_dist(qf, gf):
        """Return the (m, n) matrix of squared Euclidean distances.

        Uses ``|q|^2 + |g|^2 - 2 * q.g`` for every row ``q`` of ``qf`` (m rows)
        and every row ``g`` of ``gf`` (n rows).
        """
        m, n = qf.size(0), gf.size(0)
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword form of addmm_: distmat = 1 * distmat + (-2) * (qf @ gf.T).
        # The positional (beta, alpha, mat1, mat2) overload is no longer
        # accepted by current PyTorch releases.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        return distmat

    def _parse_data(self, inputs):
        """Split one loader batch into ``(inputs, pids, camids)``.

        Must be implemented by subclasses. ``pids`` and ``camids`` must be
        iterables with one entry per sample in the batch.
        """
        raise NotImplementedError

    def _forward(self, inputs):
        """Compute the feature tensor for one parsed batch.

        Must be implemented by subclasses. The result must be a 2-D tensor
        (one row per sample) so batches can be concatenated along dim 0.
        """
        raise NotImplementedError

    def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids):
        """Compute ``(cmc, mAP)`` from the query-by-gallery distance matrix.

        Must be implemented by subclasses. ``distmat`` is a NumPy array of
        shape (num_query, num_gallery); the id arguments are NumPy arrays.
        """
        raise NotImplementedError