fix acc and iou compute nan problem (#116)
* fix acc and iou compute nan problem * fix acc and iou compute nan problem * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * add nan_to_num args for mean_iou * Update mmseg/core/evaluation/mean_iou.py * Update mean_iou.py * Update mean_iou.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>pull/1801/head
parent
276e9ca75e
commit
cac4138f99
mmseg/core/evaluation
tests
|
@ -34,7 +34,7 @@ def intersect_and_union(pred_label, label, num_classes, ignore_index):
|
||||||
return area_intersect, area_union, area_pred_label, area_label
|
return area_intersect, area_union, area_pred_label, area_label
|
||||||
|
|
||||||
|
|
||||||
def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
|
||||||
"""Calculate Intersection and Union (IoU)
|
"""Calculate Intersection and Union (IoU)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -42,6 +42,8 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
||||||
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
|
gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
|
||||||
num_classes (int): Number of categories
|
num_classes (int): Number of categories
|
||||||
ignore_index (int): Index that will be ignored in evaluation.
|
ignore_index (int): Index that will be ignored in evaluation.
|
||||||
|
nan_to_num (int, optional): If specified, NaN values will be replaced
|
||||||
|
by the numbers defined by the user. Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: Overall accuracy on all images.
|
float: Overall accuracy on all images.
|
||||||
|
@ -66,5 +68,7 @@ def mean_iou(results, gt_seg_maps, num_classes, ignore_index):
|
||||||
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
||||||
acc = total_area_intersect / total_area_label
|
acc = total_area_intersect / total_area_label
|
||||||
iou = total_area_intersect / total_area_union
|
iou = total_area_intersect / total_area_union
|
||||||
|
if nan_to_num is not None:
|
||||||
|
return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
|
||||||
|
np.nan_to_num(iou, nan=nan_to_num)
|
||||||
return all_acc, acc, iou
|
return all_acc, acc, iou
|
||||||
|
|
|
@ -54,3 +54,10 @@ def test_mean_iou():
|
||||||
assert all_acc == all_acc_l
|
assert all_acc == all_acc_l
|
||||||
assert np.allclose(acc, acc_l)
|
assert np.allclose(acc, acc_l)
|
||||||
assert np.allclose(iou, iou_l)
|
assert np.allclose(iou, iou_l)
|
||||||
|
|
||||||
|
results = np.random.randint(0, 5, size=pred_size)
|
||||||
|
label = np.random.randint(0, 4, size=pred_size)
|
||||||
|
all_acc, acc, iou = mean_iou(
|
||||||
|
results, label, num_classes, ignore_index=255, nan_to_num=-1)
|
||||||
|
assert acc[-1] == -1
|
||||||
|
assert iou[-1] == -1
|
||||||
|
|
Loading…
Reference in New Issue