From 3be6d2c4398f9268c857a1efde93cbb2d2b02fa6 Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Wed, 10 Nov 2021 17:01:59 +0800 Subject: [PATCH] fix: pair_score_evaluator.py --- fastreid/evaluation/pair_score_evaluator.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fastreid/evaluation/pair_score_evaluator.py b/fastreid/evaluation/pair_score_evaluator.py index 08c29f8..0793d16 100644 --- a/fastreid/evaluation/pair_score_evaluator.py +++ b/fastreid/evaluation/pair_score_evaluator.py @@ -27,8 +27,18 @@ class PairScoreEvaluator(DatasetEvaluator): self._threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)] def process(self, inputs, outputs): + scores = outputs['cls_outputs'] + if scores.dim() > 1: + # 全连接层输出为2类 + if scores.shape[1] > 1: + scores = torch.softmax(scores, dim=1) + scores = scores[:, 1] + else: # 全连接层输出为1类 + scores = torch.sigmoid(scores) + scores = torch.squeeze(scores, 1) + prediction = { - 'score': outputs['cls_outputs'][:, 1].to(self._cpu_device).numpy(), + 'scores': scores.to(self._cpu_device).numpy(), 'labels': inputs["targets"].to(self._cpu_device).numpy() } self._predictions.append(prediction) @@ -47,7 +57,7 @@ class PairScoreEvaluator(DatasetEvaluator): all_scores = [] all_labels = [] for prediction in predictions: - all_scores.append(prediction['score']) + all_scores.append(prediction['scores']) all_labels.append(prediction['labels']) all_scores = np.concatenate(all_scores)