2019-01-10 18:39:31 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
2019-07-31 21:11:54 +08:00
|
|
|
@author: liaoxingyu
|
2019-01-10 18:39:31 +08:00
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
import logging
|
|
|
|
|
|
|
|
import torch
|
2019-07-31 21:11:54 +08:00
|
|
|
from data.datasets.eval_reid import evaluate
|
|
|
|
from fastai.vision import *
|
2019-01-10 18:39:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(
|
|
|
|
cfg,
|
|
|
|
model,
|
2019-07-31 21:11:54 +08:00
|
|
|
data_bunch,
|
|
|
|
test_labels,
|
2019-01-10 18:39:31 +08:00
|
|
|
num_query
|
|
|
|
):
|
|
|
|
logger = logging.getLogger("reid_baseline.inference")
|
|
|
|
logger.info("Start inferencing")
|
|
|
|
|
2019-07-31 21:11:54 +08:00
|
|
|
pids = []
|
|
|
|
camids = []
|
|
|
|
for p, c in test_labels:
|
|
|
|
pids.append(p)
|
|
|
|
camids.append(c)
|
|
|
|
q_pids = np.asarray(pids[:num_query])
|
|
|
|
g_pids = np.asarray(pids[num_query:])
|
|
|
|
q_camids = np.asarray(camids[:num_query])
|
|
|
|
g_camids = np.asarray(camids[num_query:])
|
|
|
|
feats = []
|
|
|
|
model.eval()
|
|
|
|
for imgs, _ in data_bunch.test_dl:
|
|
|
|
with torch.no_grad():
|
|
|
|
feat = model(imgs)
|
|
|
|
feats.append(feat)
|
|
|
|
feats = torch.cat(feats, dim=0)
|
|
|
|
|
|
|
|
qf = feats[:num_query]
|
|
|
|
gf = feats[num_query:]
|
|
|
|
m, n = qf.shape[0], gf.shape[0]
|
|
|
|
distmat = torch.pow(qf,2).sum(dim=1,keepdim=True).expand(m,n) + \
|
|
|
|
torch.pow(gf,2).sum(dim=1,keepdim=True).expand(n,m).t()
|
|
|
|
distmat.addmm_(1, -2, qf, gf.t())
|
|
|
|
distmat = to_np(distmat)
|
|
|
|
cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
|
|
|
|
logger.info('Test Results')
|
2019-01-10 18:39:31 +08:00
|
|
|
logger.info("mAP: {:.1%}".format(mAP))
|
|
|
|
for r in [1, 5, 10]:
|
|
|
|
logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
|