2021-10-16 21:24:16 +08:00
|
|
|
|
# coding: utf-8
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from fastreid.utils import comm
|
|
|
|
|
from sklearn import metrics as skmetrics
|
|
|
|
|
|
2021-10-16 22:21:49 +08:00
|
|
|
|
from fastreid.modeling.losses.utils import normalize
|
|
|
|
|
from .evaluator import DatasetEvaluator
|
2021-10-16 21:24:16 +08:00
|
|
|
|
|
|
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
2021-10-16 22:21:49 +08:00
|
|
|
|
class PairEvaluator(DatasetEvaluator):
    """Evaluate pair verification from the cosine similarity of embedding pairs.

    Each call to :meth:`process` reshapes ``outputs['features']`` into
    ``2 * batch`` row embeddings. Even rows are the first member of each pair
    and odd rows the second. The similarity of every pair is buffered together
    with the matching entry of ``inputs['targets']``.

    :meth:`evaluate` reports accuracy (similarity > 0.5), average precision,
    ROC AUC, and precision / recall / F1 at each configured threshold.
    """

    # Default similarity cut-offs at which precision / recall / F1 are reported.
    _DEFAULT_THRESHOLDS = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                           0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98)

    def __init__(self, cfg, output_dir=None, thresholds=None):
        """
        Args:
            cfg: config object; only stored as ``self.cfg`` within this class.
            output_dir: optional output directory; stored but not used here.
            thresholds: optional iterable of similarity cut-offs for the
                per-threshold precision / recall / F1 report. Defaults to
                ``_DEFAULT_THRESHOLDS``.
        """
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')
        self._predictions = []
        if thresholds is None:
            thresholds = self._DEFAULT_THRESHOLDS
        self._threshold_list = list(thresholds)

    def reset(self):
        """Discard all buffered per-batch predictions."""
        self._predictions = []

    def process(self, inputs, outputs):
        """Score one batch of pairs and buffer the similarities with their labels.

        Args:
            inputs: batch dict. ``inputs['targets']`` holds one label per pair
                (presumably 1 = same identity, 0 = different; TODO confirm).
            outputs: model output dict. ``outputs['features']`` is reshaped to
                ``(2 * features.size(0), -1)``, so each leading-dim entry must
                carry both members of a pair.
        """
        features = outputs['features'].to(self._cpu_device)
        # Rows 2i and 2i+1 form pair i. ``reshape`` behaves like ``view`` on
        # contiguous tensors and also accepts non-contiguous ones.
        embedding = features.reshape(features.size(0) * 2, -1)
        # Assumes ``normalize`` L2-normalises each row, so that the row-wise
        # dot product below is the cosine similarity (TODO confirm).
        embedding = normalize(embedding, axis=-1)
        embed1 = embedding[0::2]
        embed2 = embedding[1::2]
        # Higher means more similar, despite the historical key 'distances'.
        similarities = (embed1 * embed2).sum(-1).numpy()

        self._predictions.append({
            'distances': similarities,
            'labels': inputs["targets"].to(self._cpu_device).numpy(),
        })

    def evaluate(self):
        """Aggregate buffered predictions (across ranks when distributed) into metrics.

        Returns:
            A deep copy of an ``OrderedDict`` with keys ``'Acc'``, ``'Ap'``,
            ``'Auc'``, ``'Thresholds'``, ``'Precisions'``, ``'Recalls'`` and
            ``'F1_Scores'``. Returns ``{}`` on non-main ranks, or when no
            predictions were collected.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))
            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        # np.concatenate([]) would raise an unhelpful ValueError.
        if not predictions:
            logger.warning("PairEvaluator received no predictions; skipping evaluation.")
            return {}

        all_distances = np.concatenate([p['distances'] for p in predictions])
        all_labels = np.concatenate([p['labels'] for p in predictions])

        # Compute the three overall metrics; per-threshold precision, recall and
        # F1 follow below.
        acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5)
        ap = skmetrics.average_precision_score(all_labels, all_distances)
        auc = self._roc_auc(all_labels, all_distances)  # area under the ROC curve

        precisions, recalls, f1s = [], [], []
        for thresh in self._threshold_list:
            predicted = all_distances >= thresh
            precision = skmetrics.precision_score(all_labels, predicted, zero_division=0)
            recall = skmetrics.recall_score(all_labels, predicted, zero_division=0)
            # The epsilon avoids 0/0 when precision and recall are both zero.
            f1 = 2 * precision * recall / (precision + recall + 1e-12)
            precisions.append(precision)
            recalls.append(recall)
            f1s.append(f1)

        self._results = OrderedDict()
        self._results['Acc'] = acc
        self._results['Ap'] = ap
        self._results['Auc'] = auc
        self._results['Thresholds'] = self._threshold_list
        self._results['Precisions'] = precisions
        self._results['Recalls'] = recalls
        self._results['F1_Scores'] = f1s

        return copy.deepcopy(self._results)

    @staticmethod
    def _roc_auc(labels, scores):
        """Return ROC AUC, or NaN (with a warning) when only one class is present."""
        # sklearn's roc_auc_score raises ValueError for a single-class y_true.
        if np.unique(labels).size < 2:
            logger.warning("ROC AUC is undefined when only one class is present in labels; reporting NaN.")
            return float('nan')
        return skmetrics.roc_auc_score(labels, scores)
|
|
|
|
|
|