fix acc and iou compute nan problem (#116)
* fix acc and iou compute nan problem * fix acc and iou compute nan problem * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * Update mmseg/core/evaluation/mean_iou.py * Update mean_iou.py * Update mean_iou.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>pull/1801/head
parent
276e9ca75e
commit
cac4138f99
|
@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
|
|||
return area_intersect, area_union, area_pred_label, area_label
|
||||
|
||||
|
||||
def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
||||
def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
|
||||
"""Calculate Intersection and Union (IoU)
|
||||
|
||||
Args:
|
||||
|
@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
|||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
|
||||
num_classes (int): Number of categories
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
nan_to_num (int, optional): If specified, NaN values will be replaced
|
||||
by the numbers defined by the user. Default: None.
|
||||
|
||||
Returns:
|
||||
float: Overall accuracy on all images.
|
||||
|
@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
|||
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
||||
acc = total_area_intersect / total_area_label
|
||||
iou = total_area_intersect / total_area_union
|
||||
|
||||
if nan_to_num is not None:
|
||||
return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
|
||||
np.nan_to_num(iou, nan=nan_to_num)
|
||||
return all_acc, acc, iou
|
||||
|
|
|
@ -54,3 +54,10 @@ def test_mean_iou():
|
|||
assert all_acc == all_acc_l
|
||||
assert np.allclose(acc, acc_l)
|
||||
assert np.allclose(iou, iou_l)
|
||||
|
||||
results = np.random.randint(0, 5, size=pred_size)
|
||||
label = np.random.randint(0, 4, size=pred_size)
|
||||
all_acc, acc, iou = mean_iou(
|
||||
results, label, num_classes, ignore_index=255, nan_to_num=-1)
|
||||
assert acc[-1] == -1
|
||||
assert iou[-1] == -1
|
||||
|
|
Loading…
Reference in New Issue