fix bug in recall and precision (#112)

pull/115/head
LXXXXR 2020-12-09 16:27:42 +08:00 committed by GitHub
parent 92438da12a
commit b1e91f256b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions

View File

@ -34,7 +34,7 @@ def precision(pred, target):
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(1), min=1)
confusion_matrix.sum(0), min=1)
res = res.mean().item() * 100
return res
@ -52,7 +52,7 @@ def recall(pred, target):
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(0), min=1)
confusion_matrix.sum(1), min=1)
res = res.mean().item() * 100
return res

View File

@ -68,16 +68,19 @@ def test_dataset_evaluation():
dict(gt_label=0),
dict(gt_label=1),
dict(gt_label=2),
dict(gt_label=1)
dict(gt_label=1),
dict(gt_label=0)
]
fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1],
[0, 0, 1]])
eval_results = dataset.evaluate(
fake_results, metric=['precision', 'recall', 'f1_score'])
assert eval_results['precision'] == pytest.approx(
(1 + 1 + 1 / 2) / 3 * 100.0)
assert eval_results['recall'] == pytest.approx((1 + 1 / 2 + 1) / 3 * 100.0)
(1 + 1 + 1 / 3) / 3 * 100.0)
assert eval_results['recall'] == pytest.approx(
(1 / 2 + 1 / 2 + 1) / 3 * 100.0)
assert eval_results['f1_score'] == pytest.approx(
(1 + 2 / 3 + 2 / 3) / 3 * 100.0)
(2 / 3 + 2 / 3 + 1 / 2) / 3 * 100.0)
@patch.multiple(BaseDataset, __abstractmethods__=set())