fix bug in eval_metrics (#122)

pull/127/head
LXXXXR authored 2020-12-23 16:20:47 +08:00, committed by GitHub
parent 736eec1fd2
commit 4203b94643
2 changed files with 7 additions and 5 deletions

@@ -17,7 +17,8 @@ def calculate_confusion_matrix(pred, target):
     assert len(pred_label) == len(target_label)
     confusion_matrix = torch.zeros(num_classes, num_classes)
     with torch.no_grad():
-        confusion_matrix[target_label.long(), pred_label.long()] += 1
+        for t, p in zip(target_label, pred_label):
+            confusion_matrix[t.long(), p.long()] += 1
     return confusion_matrix
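
For context (not part of the commit): the removed line undercounts because advanced
indexing with += 1 does not accumulate over duplicate (target, pred) pairs; each
cell is written once no matter how many samples fall into it. A minimal sketch of
the behavior, plus a vectorized alternative via Tensor.index_put_ with
accumulate=True (the commit opts for the explicit loop instead):

import torch

target = torch.tensor([0, 0, 1])  # two samples land in the same (0, 0) cell
pred = torch.tensor([0, 0, 1])

cm = torch.zeros(2, 2)
cm[target, pred] += 1  # duplicate indices collapse: cell (0, 0) gets 1, not 2
print(cm)              # tensor([[1., 0.], [0., 1.]])

# index_put_ with accumulate=True does sum over duplicate indices
cm2 = torch.zeros(2, 2)
cm2.index_put_((target, pred), torch.ones(len(target)), accumulate=True)
print(cm2)             # tensor([[2., 0.], [0., 1.]])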

@@ -65,22 +65,23 @@ def test_datasets_override_default(dataset_name):
 def test_dataset_evaluation():
     dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
     dataset.data_infos = [
         dict(gt_label=0),
+        dict(gt_label=0),
         dict(gt_label=1),
         dict(gt_label=2),
         dict(gt_label=1),
         dict(gt_label=0)
     ]
-    fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1],
-                             [0, 0, 1]])
+    fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
+                             [0, 0, 1], [0, 0, 1]])
     eval_results = dataset.evaluate(
         fake_results, metric=['precision', 'recall', 'f1_score'])
     assert eval_results['precision'] == pytest.approx(
         (1 + 1 + 1 / 3) / 3 * 100.0)
     assert eval_results['recall'] == pytest.approx(
-        (1 / 2 + 1 / 2 + 1) / 3 * 100.0)
+        (2 / 3 + 1 / 2 + 1) / 3 * 100.0)
     assert eval_results['f1_score'] == pytest.approx(
-        (2 / 3 + 2 / 3 + 1 / 2) / 3 * 100.0)
+        (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0)


 @patch.multiple(BaseDataset, __abstractmethods__=set())
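
As a sanity check on the updated expectations (a standalone sketch, not part of
the test file; the gt/scores/pred names are local to this snippet): the new fake
results predict [0, 0, 1, 2, 2, 2] against targets [0, 0, 1, 2, 1, 0], and the
per-class metrics macro-average to the values asserted above.

import numpy as np

gt = np.array([0, 0, 1, 2, 1, 0])                     # new gt_label column
scores = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0],
                   [0, 0, 1], [0, 0, 1], [0, 0, 1]])  # new fake_results
pred = scores.argmax(axis=1)                          # [0, 0, 1, 2, 2, 2]

for c in range(3):
    tp = np.sum((pred == c) & (gt == c))
    precision = tp / np.sum(pred == c)  # class 0: 1, class 1: 1, class 2: 1/3
    recall = tp / np.sum(gt == c)       # class 0: 2/3, class 1: 1/2, class 2: 1
    f1 = 2 * precision * recall / (precision + recall)  # 4/5, 2/3, 1/2
    print(c, precision, recall, f1)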