[Fix] Fix the bug that when all pixels in an image are ignored, the accuracy calculation raises ZeroDivisionError (#1336)

* [Fix] Fix the bug that when all pixels in an image are ignored, the accuracy calculation raises ZeroDivisionError
* use eps
* all close
* add ignore test
* add eps
parent f5e1f2e82d
commit 6665b42159
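For context, a minimal standalone reproduction of the failure mode (an illustrative snippet, not part of the commit): when every pixel in the target carries the ignore_index, the mask filters out all elements, numel() is 0, and the plain Python division raises.

import torch

# All five pixels carry the ignore_index, so the valid-pixel count is 0
# and the division below raises ZeroDivisionError.
target = torch.tensor([2, 2, 2, 2, 2])
ignore_index = 2
total_num = target[target != ignore_index].numel()  # 0
acc = 100.0 / total_num  # ZeroDivisionError: float division by zero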
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn

@@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
         correct = correct & (pred_value > thresh).t()
     correct = correct[:, target != ignore_index]
     res = []
+    eps = torch.finfo(torch.float32).eps
     for k in topk:
-        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(
-            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
+        # Avoid causing ZeroDivisionError when all pixels
+        # of an image are ignored
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
+        total_num = target[target != ignore_index].numel() + eps
+        res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res
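The repair adds torch.finfo(torch.float32).eps to both the numerator and the denominator. A small self-contained sketch of why a fully ignored image then reports 100% accuracy (safe_accuracy is an illustrative helper, not the repository's API):

import torch

def safe_accuracy(correct, target, ignore_index):
    # eps in both numerator and denominator keeps the division defined;
    # with zero valid pixels the ratio degenerates to eps / eps = 1,
    # i.e. a reported accuracy of 100%.
    eps = torch.finfo(torch.float32).eps
    keep = target != ignore_index
    correct_k = correct[keep].float().sum() + eps
    total_num = target[keep].numel() + eps
    return correct_k * (100.0 / total_num)

target = torch.tensor([2, 2, 2, 2, 2])
correct = torch.tensor([False, False, True, True, False])
print(safe_accuracy(correct, target, ignore_index=2))  # tensor(100.)

When at least one pixel is valid, eps perturbs the result only negligibly, which is why the tests below compare with torch.allclose instead of exact equality.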
@@ -56,50 +56,56 @@ def test_accuracy():
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=None)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for ignore_index with a wrong prediction of that index
     true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for ignore_index 1 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 75
+    assert torch.allclose(acc, torch.tensor(75.0))

     # test for ignore_index 4 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=4)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 80
+    assert torch.allclose(acc, torch.tensor(80.0))

+    # test for ignoring all the pixels
+    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=2)
+    acc = accuracy(pred, true_label)
+    assert torch.allclose(acc, torch.tensor(100.0))
+
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for top1 with score thresh=0.8
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, thresh=0.8)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 40
+    assert torch.allclose(acc, torch.tensor(40.0))

     # test for top2
     accuracy = Accuracy(topk=2)
     label = torch.Tensor([3, 2, 0, 0, 2]).long()
     acc = accuracy(pred, label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))

     # test for both top1 and top2
     accuracy = Accuracy(topk=(1, 2))
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     acc = accuracy(pred, true_label)
     for a in acc:
-        assert a.item() == 100
+        assert torch.allclose(a, torch.tensor(100.0))

     # topk is larger than pred class number
     with pytest.raises(AssertionError):
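The assertion style changes because the eps terms can nudge a computed accuracy a hair away from the exact integer, where an exact .item() comparison would fail. A tiny illustration with a made-up perturbation:

import torch

# Illustrative only: stand-in for an eps-perturbed accuracy value.
acc = torch.tensor(75.0) + 1e-5
assert acc.item() != 75                         # exact equality breaks
assert torch.allclose(acc, torch.tensor(75.0))  # tolerance-based check passes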