# encoding: utf-8
"""
@author:  liaoxingyu
@contact: xyliao1993@qq.com
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import torch


class BaseEvaluator(object):
    """Base class for query/gallery retrieval evaluation of a model.

    ``evaluate`` extracts features for a query set and a gallery set, builds
    the squared Euclidean distance matrix between them, and delegates the
    metric computation (CMC curve and mAP) to ``eval_func``.

    Subclasses must implement:

    * ``_parse_data(inputs)`` -> ``(model_inputs, pids, camids)`` for one
      batch yielded by a loader; ``pids`` and ``camids`` must be iterable.
    * ``_forward(model_inputs)`` -> a 2-D feature tensor with one row per
      sample.
    * ``eval_func(distmat, q_pids, g_pids, q_camids, g_camids)`` ->
      ``(cmc, mAP)`` where ``cmc[k - 1]`` is the rank-k score.
    """

    def __init__(self, model):
        self.model = model

    def evaluate(self, queryloader, galleryloader, ranks=(1, 5, 10, 20)):
        """Evaluate ``self.model`` on a query/gallery split and print results.

        Args:
            queryloader: iterable of query batches accepted by ``_parse_data``.
            galleryloader: iterable of gallery batches accepted by
                ``_parse_data``.
            ranks: CMC ranks to print; each rank ``r`` reads ``cmc[r - 1]``,
                so every rank must be <= ``len(cmc)``.

        Returns:
            ``cmc[0]``, the rank-1 score returned by ``eval_func``.

        Note:
            The model is switched to eval mode and is left in that mode.
        """
        self.model.eval()

        qf, q_pids, q_camids = self._extract_features(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix"
              .format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = self._extract_features(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix"
              .format(gf.size(0), gf.size(1)))

        print("Computing distance matrix")
        distmat = self._squared_euclidean_distmat(qf, gf)

        print("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in ranks:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")
        return cmc[0]

    def _extract_features(self, loader):
        """Run the model over ``loader``; return (features, pids, camids).

        ``features`` is the row-wise concatenation of every batch's
        ``_forward`` output; ``pids``/``camids`` are NumPy arrays.
        """
        features, pids, camids = [], [], []
        # Evaluation needs no autograd graph. Without no_grad, features that
        # require grad would make the later ``.numpy()`` call fail and would
        # keep the whole graph alive in memory.
        with torch.no_grad():
            for batch in loader:
                inputs, batch_pids, batch_camids = self._parse_data(batch)
                features.append(self._forward(inputs))
                pids.extend(batch_pids)
                camids.extend(batch_camids)
        return torch.cat(features, 0), np.asarray(pids), np.asarray(camids)

    @staticmethod
    def _squared_euclidean_distmat(qf, gf):
        """Return the (m, n) NumPy matrix of squared L2 distances between
        the rows of ``qf`` (m rows) and ``gf`` (n rows)."""
        m, n = qf.size(0), gf.size(0)
        # ||q||^2 + ||g||^2 - 2 * q.g, using a single matrix product.
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2)
        # overload is deprecated.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        # .cpu() so this also works when the features live on a GPU.
        return distmat.cpu().numpy()

    def _parse_data(self, inputs):
        """Split one loader batch into ``(model_inputs, pids, camids)``."""
        raise NotImplementedError

    def _forward(self, inputs):
        """Return a 2-D feature tensor (one row per sample) for ``inputs``."""
        raise NotImplementedError

    def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids):
        """Compute ``(cmc, mAP)`` from the query-by-gallery ``distmat``."""
        raise NotImplementedError