fix bug in recall and precision (#112)
parent 92438da12a
commit b1e91f256b
@@ -34,7 +34,7 @@ def precision(pred, target):
     confusion_matrix = calculate_confusion_matrix(pred, target)
     with torch.no_grad():
         res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(1), min=1)
+            confusion_matrix.sum(0), min=1)
         res = res.mean().item() * 100
     return res
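Note: this change assumes calculate_confusion_matrix fills entry [target, pred], so summing over dim 0 counts how often each class was predicted, which is the precision denominator. A minimal sketch of the corrected computation under that assumption (precision_from_confusion is a hypothetical helper, not part of this repo):

import torch

def precision_from_confusion(confusion_matrix):
    # confusion_matrix[target, pred]: summing over dim 0 gives the number of
    # samples predicted as each class, i.e. the precision denominator.
    with torch.no_grad():
        per_class = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(0), min=1)
    return per_class.mean().item() * 100

# Two classes: class 1 is predicted three times but only once correctly,
# so its precision is 1/3 while class 0 keeps precision 1.
cm = torch.tensor([[1., 2.], [0., 1.]])
assert abs(precision_from_confusion(cm) - (1 + 1 / 3) / 2 * 100) < 1e-6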
@@ -52,7 +52,7 @@ def recall(pred, target):
     confusion_matrix = calculate_confusion_matrix(pred, target)
     with torch.no_grad():
         res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(0), min=1)
+            confusion_matrix.sum(1), min=1)
         res = res.mean().item() * 100
     return res
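The recall fix is the mirror image: with the same [target, pred] indexing, summing over dim 1 counts the ground-truth samples per class, which is the recall denominator. A sketch under that assumption (recall_from_confusion is hypothetical):

import torch

def recall_from_confusion(confusion_matrix):
    # confusion_matrix[target, pred]: summing over dim 1 gives the number of
    # ground-truth samples per class, i.e. the recall denominator.
    with torch.no_grad():
        per_class = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(1), min=1)
    return per_class.mean().item() * 100

# Two classes: three samples of class 0 (two recovered), one sample of
# class 1 (recovered), so mean recall is (2/3 + 1) / 2.
cm = torch.tensor([[2., 1.], [0., 1.]])
assert abs(recall_from_confusion(cm) - (2 / 3 + 1) / 2 * 100) < 1e-6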
@@ -68,16 +68,19 @@ def test_dataset_evaluation():
         dict(gt_label=0),
         dict(gt_label=1),
         dict(gt_label=2),
-        dict(gt_label=1)
+        dict(gt_label=1),
+        dict(gt_label=0)
     ]
-    fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
+    fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1],
+                             [0, 0, 1]])
     eval_results = dataset.evaluate(
         fake_results, metric=['precision', 'recall', 'f1_score'])
     assert eval_results['precision'] == pytest.approx(
-        (1 + 1 + 1 / 2) / 3 * 100.0)
-    assert eval_results['recall'] == pytest.approx((1 + 1 / 2 + 1) / 3 * 100.0)
+        (1 + 1 + 1 / 3) / 3 * 100.0)
+    assert eval_results['recall'] == pytest.approx(
+        (1 / 2 + 1 / 2 + 1) / 3 * 100.0)
     assert eval_results['f1_score'] == pytest.approx(
-        (1 + 2 / 3 + 2 / 3) / 3 * 100.0)
+        (2 / 3 + 2 / 3 + 1 / 2) / 3 * 100.0)


 @patch.multiple(BaseDataset, __abstractmethods__=set())
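As a sanity check on the updated expected values (a re-derivation, again assuming rows are ground-truth labels and columns are predictions): the five fake samples have gt labels [0, 1, 2, 1, 0] and argmax predictions [0, 1, 2, 2, 2].

import numpy as np

# Confusion matrix for gt = [0, 1, 2, 1, 0], pred = [0, 1, 2, 2, 2].
cm = np.array([[1, 0, 1],   # gt 0: once predicted 0, once predicted 2
               [0, 1, 1],   # gt 1: once predicted 1, once predicted 2
               [0, 0, 1]])  # gt 2: once predicted 2
precision = cm.diagonal() / cm.sum(axis=0)          # [1, 1, 1/3]
recall = cm.diagonal() / cm.sum(axis=1)             # [1/2, 1/2, 1]
f1 = 2 * precision * recall / (precision + recall)  # [2/3, 2/3, 1/2]
print(precision.mean() * 100, recall.mean() * 100, f1.mean() * 100)
# -> approximately 77.78, 66.67, 61.11, matching the asserted means above.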