fix acc and iou compute nan problem (#116)

* fix acc and iou compute nan problem

* fix acc and iou compute nan problem

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* add nan_to_num args for mean_iou

* Update mmseg/core/evaluation/mean_iou.py

* Update mean_iou.py

* Update mean_iou.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
pull/1801/head
sshuair 2020-09-22 01:04:46 +08:00 committed by GitHub
parent 276e9ca75e
commit cac4138f99
2 changed files with 13 additions and 2 deletions

View File

@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
return area_intersect, area_union, area_pred_label, area_label
def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
"""Calculate Intersection and Union (IoU)
Args:
@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
num_classes (int): Number of categories
ignore_index (int): Index that will be ignored in evaluation.
nan_to_num (int, optional): If specified, NaN values will be replaced
by the numbers defined by the user. Default: None.
Returns:
float: Overall accuracy on all images.
@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
all_acc = total_area_intersect.sum() / total_area_label.sum()
acc = total_area_intersect / total_area_label
iou = total_area_intersect / total_area_union
if nan_to_num is not None:
return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
np.nan_to_num(iou, nan=nan_to_num)
return all_acc, acc, iou

View File

@ -54,3 +54,10 @@ def test_mean_iou():
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
all_acc, acc, iou = mean_iou(
results, label, num_classes, ignore_index=255, nan_to_num=-1)
assert acc[-1] == -1
assert iou[-1] == -1