fast-reid/fastreid/evaluation/reid_evaluation.py

61 lines
1.8 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
from collections import OrderedDict
2020-02-27 12:16:57 +08:00
import numpy as np
import torch
import torch.nn.functional as F
2020-02-10 07:38:56 +08:00
from .evaluator import DatasetEvaluator
from .rank import evaluate_rank
class ReidEvaluator(DatasetEvaluator):
    """Accumulate features and ids batch by batch, then score query vs. gallery.

    Samples are kept in the order they are passed to :meth:`process`. When
    :meth:`evaluate` runs, the first ``num_query`` accumulated samples form
    the query set and every remaining sample forms the gallery set.
    """

    def __init__(self, cfg, num_query, output_dir=None):
        """Record the query count and output directory; start with empty buffers.

        Args:
            cfg: not read by this class.
            num_query: number of leading accumulated samples treated as queries.
            output_dir: stored on the instance; not otherwise used by this class.
        """
        self._output_dir = output_dir
        self._num_query = num_query

        self.features, self.pids, self.camids = [], [], []

    def reset(self):
        """Discard everything accumulated so far."""
        self.features, self.pids, self.camids = [], [], []

    def process(self, outputs):
        """Append one batch to the accumulation buffers.

        Args:
            outputs: indexable whose items 0, 1 and 2 are tensors holding the
                features, person ids and camera ids of the batch.
        """
        batch_features = outputs[0].cpu()
        self.features.append(batch_features)

        # Ids are flattened into plain lists of numpy scalars, one per sample.
        batch_pids = outputs[1].cpu().numpy()
        self.pids.extend(batch_pids)

        batch_camids = outputs[2].cpu().numpy()
        self.camids.extend(batch_camids)

    def evaluate(self):
        """Compute ranking metrics over everything accumulated so far.

        Returns:
            A deep copy of an ``OrderedDict`` with the keys ``Rank-1``,
            ``Rank-5``, ``Rank-10``, ``mAP`` and ``mINP``, in that order.
        """
        split = self._num_query

        # Stack every batch and normalize each row, so that the matrix product
        # below yields cosine similarities.
        all_features = F.normalize(torch.cat(self.features, dim=0), dim=1)

        # Query part: the first `split` samples.
        q_feat = all_features[:split]
        q_pids = np.asarray(self.pids[:split])
        q_camids = np.asarray(self.camids[:split])

        # Gallery part: everything after the queries.
        g_feat = all_features[split:]
        g_pids = np.asarray(self.pids[split:])
        g_camids = np.asarray(self.camids[split:])

        self._results = OrderedDict()

        similarity = torch.mm(q_feat, g_feat.t()).numpy()
        # evaluate_rank is handed a distance matrix, hence 1 - similarity.
        distance = 1 - similarity
        cmc, mAP, mINP = evaluate_rank(distance, q_pids, g_pids, q_camids, g_camids)

        # cmc is indexed from zero, so Rank-k lives at position k - 1.
        for rank in (1, 5, 10):
            key = 'Rank-{}'.format(rank)
            self._results[key] = cmc[rank - 1]
        self._results['mAP'] = mAP
        self._results['mINP'] = mINP

        # Hand back a copy so callers cannot mutate the cached results.
        return copy.deepcopy(self._results)