97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
from ignite.metrics import Metric
|
|
|
|
from data.datasets.eval_reid import eval_func
|
|
from .re_ranking import re_ranking
|
|
|
|
|
|
class R1_mAP(Metric):
|
|
def __init__(self, num_query, max_rank=50, feat_norm='yes'):
|
|
super(R1_mAP, self).__init__()
|
|
self.num_query = num_query
|
|
self.max_rank = max_rank
|
|
self.feat_norm = feat_norm
|
|
|
|
def reset(self):
|
|
self.feats = []
|
|
self.pids = []
|
|
self.camids = []
|
|
|
|
def update(self, output):
|
|
feat, pid, camid = output
|
|
self.feats.append(feat)
|
|
self.pids.extend(np.asarray(pid))
|
|
self.camids.extend(np.asarray(camid))
|
|
|
|
def compute(self):
|
|
feats = torch.cat(self.feats, dim=0)
|
|
if self.feat_norm == 'yes':
|
|
print("The test feature is normalized")
|
|
feats = torch.nn.functional.normalize(feats, dim=1, p=2)
|
|
# query
|
|
qf = feats[:self.num_query]
|
|
q_pids = np.asarray(self.pids[:self.num_query])
|
|
q_camids = np.asarray(self.camids[:self.num_query])
|
|
# gallery
|
|
gf = feats[self.num_query:]
|
|
g_pids = np.asarray(self.pids[self.num_query:])
|
|
g_camids = np.asarray(self.camids[self.num_query:])
|
|
m, n = qf.shape[0], gf.shape[0]
|
|
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
|
|
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
|
distmat.addmm_(1, -2, qf, gf.t())
|
|
distmat = distmat.cpu().numpy()
|
|
cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
|
|
|
|
return cmc, mAP
|
|
|
|
|
|
class R1_mAP_reranking(Metric):
|
|
def __init__(self, num_query, max_rank=50, feat_norm='yes'):
|
|
super(R1_mAP_reranking, self).__init__()
|
|
self.num_query = num_query
|
|
self.max_rank = max_rank
|
|
self.feat_norm = feat_norm
|
|
|
|
def reset(self):
|
|
self.feats = []
|
|
self.pids = []
|
|
self.camids = []
|
|
|
|
def update(self, output):
|
|
feat, pid, camid = output
|
|
self.feats.append(feat)
|
|
self.pids.extend(np.asarray(pid))
|
|
self.camids.extend(np.asarray(camid))
|
|
|
|
def compute(self):
|
|
feats = torch.cat(self.feats, dim=0)
|
|
if self.feat_norm == 'yes':
|
|
print("The test feature is normalized")
|
|
feats = torch.nn.functional.normalize(feats, dim=1, p=2)
|
|
|
|
# query
|
|
qf = feats[:self.num_query]
|
|
q_pids = np.asarray(self.pids[:self.num_query])
|
|
q_camids = np.asarray(self.camids[:self.num_query])
|
|
# gallery
|
|
gf = feats[self.num_query:]
|
|
g_pids = np.asarray(self.pids[self.num_query:])
|
|
g_camids = np.asarray(self.camids[self.num_query:])
|
|
# m, n = qf.shape[0], gf.shape[0]
|
|
# distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
|
|
# torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
|
# distmat.addmm_(1, -2, qf, gf.t())
|
|
# distmat = distmat.cpu().numpy()
|
|
print("Enter reranking")
|
|
distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
|
|
cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
|
|
|
|
return cmc, mAP |