diff --git a/mmcls/models/losses/eval_metrics.py b/mmcls/models/losses/eval_metrics.py
index 52222aec6..448cb8b6a 100644
--- a/mmcls/models/losses/eval_metrics.py
+++ b/mmcls/models/losses/eval_metrics.py
@@ -34,7 +34,7 @@ def precision(pred, target):
     confusion_matrix = calculate_confusion_matrix(pred, target)
     with torch.no_grad():
         res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(1), min=1)
+            confusion_matrix.sum(0), min=1)
         res = res.mean().item() * 100
     return res
 
@@ -52,7 +52,7 @@ def recall(pred, target):
     confusion_matrix = calculate_confusion_matrix(pred, target)
     with torch.no_grad():
         res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(0), min=1)
+            confusion_matrix.sum(1), min=1)
         res = res.mean().item() * 100
     return res
 
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index bf2666300..f5e28f379 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -68,16 +68,19 @@ def test_dataset_evaluation():
         dict(gt_label=0),
         dict(gt_label=1),
         dict(gt_label=2),
-        dict(gt_label=1)
+        dict(gt_label=1),
+        dict(gt_label=0)
     ]
-    fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
+    fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1],
+                             [0, 0, 1]])
     eval_results = dataset.evaluate(
         fake_results, metric=['precision', 'recall', 'f1_score'])
     assert eval_results['precision'] == pytest.approx(
-        (1 + 1 + 1 / 2) / 3 * 100.0)
-    assert eval_results['recall'] == pytest.approx((1 + 1 / 2 + 1) / 3 * 100.0)
+        (1 + 1 + 1 / 3) / 3 * 100.0)
+    assert eval_results['recall'] == pytest.approx(
+        (1 / 2 + 1 / 2 + 1) / 3 * 100.0)
     assert eval_results['f1_score'] == pytest.approx(
-        (1 + 2 / 3 + 2 / 3) / 3 * 100.0)
+        (2 / 3 + 2 / 3 + 1 / 2) / 3 * 100.0)
 
 
 @patch.multiple(BaseDataset, __abstractmethods__=set())
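
Sanity check for the updated test expectations (not part of the patch): a minimal standalone sketch that recomputes the asserted metric values, assuming calculate_confusion_matrix fills matrix[target, pred] (rows indexed by ground-truth label, columns by predicted label). Under that convention, precision normalizes the diagonal by column sums (sum(0)) and recall by row sums (sum(1)), which is exactly the swap this patch makes.

# Sanity-check sketch; assumes confusion_matrix[target, pred] indexing
# (rows = ground truth, columns = prediction), so precision divides by
# column sums (axis 0) and recall by row sums (axis 1).
import numpy as np

gt = np.array([0, 1, 2, 1, 0])                        # gt_label of the five fake samples
pred = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1],
                 [0, 0, 1], [0, 0, 1]]).argmax(1)     # predicted class per sample

cm = np.zeros((3, 3))
for t, p in zip(gt, pred):
    cm[t, p] += 1                                     # rows: target, cols: prediction

precision = cm.diagonal() / np.clip(cm.sum(0), 1, None)   # per predicted class
recall = cm.diagonal() / np.clip(cm.sum(1), 1, None)      # per true class
f1 = 2 * precision * recall / (precision + recall)

print(precision.mean() * 100)   # (1 + 1 + 1/3) / 3 * 100      ~= 77.78
print(recall.mean() * 100)      # (1/2 + 1/2 + 1) / 3 * 100     ~= 66.67
print(f1.mean() * 100)          # (2/3 + 2/3 + 1/2) / 3 * 100   ~= 61.11

With the five fake samples (gt = [0, 1, 2, 1, 0], argmax predictions = [0, 1, 2, 2, 2]) this reproduces the three values asserted in test_dataset.py.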