fix: pair_score_evaluator.py

pull/608/head
zuchen.wang 2021-11-10 17:01:59 +08:00
parent 93edc0cf53
commit 3be6d2c439
1 changed files with 12 additions and 2 deletions

View File

@ -27,8 +27,18 @@ class PairScoreEvaluator(DatasetEvaluator):
self._threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]
def process(self, inputs, outputs):
scores = outputs['cls_outputs']
if scores.dim() > 1:
# 全连接层输出为2类
if scores.shape[1] > 1:
scores = torch.softmax(scores, dim=1)
scores = scores[:, 1]
else: # 全连接层输出为1类
scores = torch.sigmoid(scores)
scores = torch.squeeze(scores, 1)
prediction = {
'score': outputs['cls_outputs'][:, 1].to(self._cpu_device).numpy(),
'scores': scores.to(self._cpu_device).numpy(),
'labels': inputs["targets"].to(self._cpu_device).numpy()
}
self._predictions.append(prediction)
@ -47,7 +57,7 @@ class PairScoreEvaluator(DatasetEvaluator):
all_scores = []
all_labels = []
for prediction in predictions:
all_scores.append(prediction['score'])
all_scores.append(prediction['scores'])
all_labels.append(prediction['labels'])
all_scores = np.concatenate(all_scores)