mirror of https://github.com/JDAI-CV/fast-reid.git
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: xyliao1993@qq.com
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class BaseEvaluator(object):
    """Template for evaluating a feature-extraction model on query/gallery data.

    ``evaluate`` drives the whole procedure: it extracts features for every
    batch of both loaders, builds a query-by-gallery matrix of squared
    Euclidean distances and hands that matrix to ``eval_func`` for scoring.

    Subclasses must implement the three hooks:

    * ``_parse_data(inputs)`` -> ``(inputs, pids, camids)`` for one batch.
    * ``_forward(inputs)`` -> a 2-D feature tensor with one row per sample.
    * ``eval_func(distmat, q_pids, g_pids, q_camids, g_camids)`` ->
      ``(cmc, mAP)`` where ``cmc`` is indexable by ``rank - 1``.
    """

    def __init__(self, model):
        """Store the model whose features will be evaluated.

        Args:
            model: object exposing ``eval()``; how it is invoked is left to
                the subclass' ``_forward`` implementation.
        """
        self.model = model

    def evaluate(self, queryloader, galleryloader, ranks=(1, 5, 10, 20)):
        """Run the evaluation, print a report and return the rank-1 score.

        Args:
            queryloader: iterable of batches for the query set; each batch is
                passed to ``_parse_data``.
            galleryloader: iterable of batches for the gallery set.
            ranks: ranks (1-based) whose CMC value is printed. Each rank must
                not exceed ``len(cmc)`` as returned by ``eval_func``.

        Returns:
            ``cmc[0]``, i.e. the first entry of the curve returned by
            ``eval_func``.
        """
        self.model.eval()

        # Evaluation never needs gradients. Disabling autograd also guarantees
        # that the distance matrix can be modified in place and converted to
        # NumPy even if ``_forward`` returns tensors attached to a graph.
        with torch.no_grad():
            qf, q_pids, q_camids = self._extract_features(queryloader)
            print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

            gf, g_pids, g_camids = self._extract_features(galleryloader)
            print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

            print("Computing distance matrix")
            distmat = self._squared_euclidean_distmat(qf, gf)

        # ``Tensor.numpy()`` only works on CPU tensors; ``cpu()`` is a no-op
        # when the features already live on the CPU.
        distmat = distmat.cpu().numpy()

        print("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in ranks:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")

        return cmc[0]

    def _extract_features(self, loader):
        """Collect features, person ids and camera ids over a whole loader.

        Args:
            loader: iterable of batches understood by ``_parse_data``.

        Returns:
            Tuple ``(features, pids, camids)``: the per-batch feature tensors
            concatenated along dimension 0, and the ids of all batches as
            NumPy arrays in iteration order.
        """
        features, pids, camids = [], [], []
        for batch in loader:
            inputs, batch_pids, batch_camids = self._parse_data(batch)
            features.append(self._forward(inputs))
            pids.extend(batch_pids)
            camids.extend(batch_camids)
        return torch.cat(features, 0), np.asarray(pids), np.asarray(camids)

    @staticmethod
    def _squared_euclidean_distmat(qf, gf):
        """Return the ``m x n`` matrix of squared Euclidean distances.

        Uses the expansion ``|q - g|^2 = |q|^2 + |g|^2 - 2 * q.g`` so the
        cross term is a single matrix product.

        Args:
            qf: 2-D tensor with ``m`` query feature rows.
            gf: 2-D tensor with ``n`` gallery feature rows.
        """
        m, n = qf.size(0), gf.size(0)
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword form of addmm_: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated and rejected by current PyTorch releases.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        return distmat

    def _parse_data(self, inputs):
        """Split one loader batch into ``(inputs, pids, camids)``.

        Must be provided by subclasses.
        """
        raise NotImplementedError

    def _forward(self, inputs):
        """Return the feature tensor for one parsed batch.

        Must be provided by subclasses.
        """
        raise NotImplementedError

    def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids):
        """Score a distance matrix and return ``(cmc, mAP)``.

        Must be provided by subclasses.

        Args:
            distmat: NumPy array of query-by-gallery distances.
            q_pids: NumPy array of query person ids.
            g_pids: NumPy array of gallery person ids.
            q_camids: NumPy array of query camera ids.
            g_camids: NumPy array of gallery camera ids.
        """
        raise NotImplementedError