refactor evaluator

pull/608/head
zuchen.wang 2021-10-27 16:03:43 +08:00
parent 6687df06e0
commit 3880b23ab4
4 changed files with 63 additions and 11 deletions

View File

@ -22,11 +22,11 @@ class PairEvaluator(DatasetEvaluator):
self._output_dir = output_dir
self._cpu_device = torch.device('cpu')
self._predictions = []
self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
embedding = outputs['features'].to(self._cpu_device)
embedding = embedding.view(embedding.size(0) * 2, -1)
@ -41,31 +41,31 @@ class PairEvaluator(DatasetEvaluator):
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)
def evaluate(self):
if comm.get_world_size() > 1:
comm.synchronize()
predictions = comm.gather(self._predictions, dst=0)
predictions = list(itertools.chain(*predictions))
if not comm.is_main_process():
if not comm.is_main_process():
return {}
else:
predictions = self._predictions
all_distances = []
all_labels = []
for prediction in predictions:
all_distances.append(prediction['distances'])
all_labels.append(prediction['labels'])
all_distances = np.concatenate(all_distances)
all_labels = np.concatenate(all_labels)
# 计算这3个总体值还有给定阈值下的precision, recall, f1
acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5)
ap = skmetrics.average_precision_score(all_labels, all_distances)
auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc
auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc
precisions = []
recalls = []
@ -74,11 +74,11 @@ class PairEvaluator(DatasetEvaluator):
precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
f1 = 2 * precision * recall / (precision + recall + 1e-12)
precisions.append(precision)
recalls.append(recall)
f1s.append(f1)
self._results = OrderedDict()
self._results['Acc'] = acc
self._results['Ap'] = ap

View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/27 15:56:52
# @Author : zuchen.wang@vipshop.com
# @File : pair_score_evaluator.py
# coding: utf-8
import logging
import torch
from fastreid.modeling.losses.utils import normalize
from .pair_evaluator import PairEvaluator
logger = logging.getLogger(__name__)
class PairScoreEvaluator(PairEvaluator):
def process(self, inputs, outputs):
prediction = {
'distances': outputs['cls_outputs'].to(self._cpu_device).numpy(),
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)

View File

@ -0,0 +1,28 @@
# coding: utf-8
import logging
import torch
from fastreid.modeling.losses.utils import normalize
from .pair_evaluator import PairEvaluator
logger = logging.getLogger(__name__)
class PairDistanceEvaluator(PairEvaluator):
def process(self, inputs, outputs):
embedding = outputs['features'].to(self._cpu_device)
embedding = embedding.view(embedding.size(0) * 2, -1)
embedding = normalize(embedding, axis=-1)
embed1 = embedding[0:len(embedding):2, :]
embed2 = embedding[1:len(embedding):2, :]
distances = torch.mul(embed1, embed2).sum(-1).numpy()
# print(distances)
prediction = {
'distances': distances,
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)

View File

@ -13,7 +13,7 @@ from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.utils import comm
from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation.pair_evaluator import PairEvaluator
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
from projects.FastShoe.fastshoe.data import PairDataset
@ -85,4 +85,4 @@ class PairTrainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
data_loader = cls.build_test_loader(cfg, dataset_name)
return data_loader, PairEvaluator(cfg, output_dir)
return data_loader, PairScoreEvaluator(cfg, output_dir)