fix bug in eval_metrics (#122)
parent
736eec1fd2
commit
4203b94643
|
@ -17,7 +17,8 @@ def calculate_confusion_matrix(pred, target):
|
||||||
assert len(pred_label) == len(target_label)
|
assert len(pred_label) == len(target_label)
|
||||||
confusion_matrix = torch.zeros(num_classes, num_classes)
|
confusion_matrix = torch.zeros(num_classes, num_classes)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
confusion_matrix[target_label.long(), pred_label.long()] += 1
|
for t, p in zip(target_label, pred_label):
|
||||||
|
confusion_matrix[t.long(), p.long()] += 1
|
||||||
return confusion_matrix
|
return confusion_matrix
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -65,22 +65,23 @@ def test_datasets_override_default(dataset_name):
|
||||||
def test_dataset_evaluation():
|
def test_dataset_evaluation():
|
||||||
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
||||||
dataset.data_infos = [
|
dataset.data_infos = [
|
||||||
|
dict(gt_label=0),
|
||||||
dict(gt_label=0),
|
dict(gt_label=0),
|
||||||
dict(gt_label=1),
|
dict(gt_label=1),
|
||||||
dict(gt_label=2),
|
dict(gt_label=2),
|
||||||
dict(gt_label=1),
|
dict(gt_label=1),
|
||||||
dict(gt_label=0)
|
dict(gt_label=0)
|
||||||
]
|
]
|
||||||
fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1],
|
fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
|
||||||
[0, 0, 1]])
|
[0, 0, 1], [0, 0, 1]])
|
||||||
eval_results = dataset.evaluate(
|
eval_results = dataset.evaluate(
|
||||||
fake_results, metric=['precision', 'recall', 'f1_score'])
|
fake_results, metric=['precision', 'recall', 'f1_score'])
|
||||||
assert eval_results['precision'] == pytest.approx(
|
assert eval_results['precision'] == pytest.approx(
|
||||||
(1 + 1 + 1 / 3) / 3 * 100.0)
|
(1 + 1 + 1 / 3) / 3 * 100.0)
|
||||||
assert eval_results['recall'] == pytest.approx(
|
assert eval_results['recall'] == pytest.approx(
|
||||||
(1 / 2 + 1 / 2 + 1) / 3 * 100.0)
|
(2 / 3 + 1 / 2 + 1) / 3 * 100.0)
|
||||||
assert eval_results['f1_score'] == pytest.approx(
|
assert eval_results['f1_score'] == pytest.approx(
|
||||||
(2 / 3 + 2 / 3 + 1 / 2) / 3 * 100.0)
|
(4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0)
|
||||||
|
|
||||||
|
|
||||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||||
|
|
Loading…
Reference in New Issue