Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)

[Fix] Make accuracy take into account ignore_index (#1259)

* make accuracy take into account ignore_index
* add UT for accuracy

parent 0934a57f4f
commit 346f70da5f
@@ -261,5 +261,6 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
                     weight=seg_weight,
                     ignore_index=self.ignore_index)

-        loss['acc_seg'] = accuracy(seg_logit, seg_label)
+        loss['acc_seg'] = accuracy(
+            seg_logit, seg_label, ignore_index=self.ignore_index)
         return loss
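For context, a minimal sketch (not part of the commit) of what this change means for acc_seg. It assumes accuracy is importable from mmseg.models.losses and that 255 is the ignored label, as in typical mmsegmentation configs; pixels carrying ignore_index no longer drag the reported accuracy down.

# Minimal sketch: effect of ignore_index on acc_seg (assumed import path).
import torch
from mmseg.models.losses import accuracy

# 1 image, 3 classes, 2x2 pixels; per-pixel argmax is [[0, 1], [2, 2]]
seg_logit = torch.tensor([[[[2.0, 0.1], [0.1, 0.1]],
                           [[0.1, 2.0], [0.1, 0.1]],
                           [[0.1, 0.1], [2.0, 2.0]]]])
seg_label = torch.tensor([[[0, 1], [255, 0]]])  # bottom-left pixel is ignored

# Before this fix the 255 pixel could never be predicted and always lowered acc_seg.
# With the fix it is excluded: 2 correct out of the 3 counted pixels ~= 66.67.
print(accuracy(seg_logit, seg_label, ignore_index=255))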
@@ -264,7 +264,8 @@ class PointHead(BaseCascadeDecodeHead):
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)

-        loss['acc_point'] = accuracy(point_logits, point_label)
+        loss['acc_point'] = accuracy(
+            point_logits, point_label, ignore_index=self.ignore_index)
         return loss

     def get_points_train(self, seg_logits, uncertainty_func, cfg):
@@ -2,12 +2,13 @@
 import torch.nn as nn


-def accuracy(pred, target, topk=1, thresh=None):
+def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
     """Calculate accuracy according to the prediction and target.

     Args:
         pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
         target (torch.Tensor): The target of each prediction, shape (N, , ...)
+        ignore_index (int | None): The label index to be ignored. Default: None
         topk (int | tuple[int], optional): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
@@ -43,17 +44,19 @@ def accuracy(pred, target, topk=1, thresh=None):
     if thresh is not None:
         # Only prediction values larger than thresh are counted as correct
         correct = correct & (pred_value > thresh).t()
+    correct = correct[:, target != ignore_index]
     res = []
     for k in topk:
         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(correct_k.mul_(100.0 / target.numel()))
+        res.append(
+            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
     return res[0] if return_single else res


 class Accuracy(nn.Module):
     """Accuracy calculation module."""

-    def __init__(self, topk=(1, ), thresh=None):
+    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
         """Module to calculate the accuracy.

         Args:
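The core of the fix is the pair of masked expressions above: the ignored columns are dropped from correct, and the denominator counts only the kept targets. A standalone sketch in plain torch, covering only the top-1 case and using a hypothetical helper name, mirrors that logic without depending on mmseg:

# Standalone sketch of the masking logic above (top-1 only, plain torch).
import torch

def top1_acc_with_ignore(pred, target, ignore_index=None):
    """Percentage of non-ignored samples whose argmax matches the target."""
    pred_label = pred.argmax(dim=1)
    keep = torch.ones_like(target, dtype=torch.bool) if ignore_index is None \
        else target != ignore_index
    correct = (pred_label == target) & keep  # ignored positions never count as correct
    return correct.float().sum() * 100.0 / keep.sum()

pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                     [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                     [0.0, 0.0, 0.99, 0]])  # argmax per row: [2, 3, 0, 1, 2]
target = torch.Tensor([2, 0, 0, 1, 2]).long()
print(top1_acc_with_ignore(pred, target, ignore_index=1))  # 3 of 4 counted -> 75.0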
@@ -65,6 +68,7 @@ class Accuracy(nn.Module):
         super().__init__()
         self.topk = topk
         self.thresh = thresh
+        self.ignore_index = ignore_index

     def forward(self, pred, target):
         """Forward function to calculate accuracy.
@@ -76,4 +80,5 @@ class Accuracy(nn.Module):
         Returns:
             tuple[float]: The accuracies under different topk criterions.
         """
-        return accuracy(pred, target, self.topk, self.thresh)
+        return accuracy(pred, target, self.topk, self.thresh,
+                        self.ignore_index)
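A brief usage sketch of the module-level wrapper after this change; the import path mmseg.models.losses is assumed here, and the numbers are worked out by hand rather than taken from the commit:

# Usage sketch: the Accuracy module now forwards ignore_index to accuracy().
import torch
from mmseg.models.losses import Accuracy

acc_fn = Accuracy(topk=1, ignore_index=255)
pred = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # argmax: [1, 0, 1]
target = torch.tensor([1, 255, 0]).long()                  # middle sample ignored
print(acc_fn(pred, target))  # 1 correct of the 2 counted samples -> tensor([50.])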
@@ -52,6 +52,30 @@ def test_accuracy():
     pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                          [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                          [0.0, 0.0, 0.99, 0]])
+    # test for ignore_index
+    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=None)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index with a wrong prediction of that index
+    true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index 1 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 75
+
+    # test for ignore_index 4 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=4)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 80
+
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
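As a sanity check on the expected values above: the argmax of each row of pred is [2, 3, 0, 1, 2], so each number follows once ignored positions are dropped from both numerator and denominator. A quick standalone check in plain torch (not the UT itself), mirroring the 75 and 80 cases:

# Quick check of the new expected values (plain torch).
import torch

pred_label = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                           [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                           [0.0, 0.0, 0.99, 0]]).argmax(dim=1)  # -> [2, 3, 0, 1, 2]
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()

keep = true_label != 1  # ignore_index=1 drops position 3
print((pred_label[keep] == true_label[keep]).float().mean() * 100)  # 3/4 -> 75.0

keep = true_label != 4  # no label equals 4, so nothing is dropped
print((pred_label[keep] == true_label[keep]).float().mean() * 100)  # 4/5 -> 80.0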