Fix mIoU calculatiton range (#471)
* Fix fence(IoU) = 0 when training on PascalContextDataset59; * Add a test case in test_metrics() of tests/test_metrics.py to test the bug caused by torch.histc; * Update tests/test_metrics.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>pull/495/head
parent
789d1a142b
commit
fb24bf54b6
|
@ -57,11 +57,11 @@ def intersect_and_union(pred_label,
|
||||||
|
|
||||||
intersect = pred_label[pred_label == label]
|
intersect = pred_label[pred_label == label]
|
||||||
area_intersect = torch.histc(
|
area_intersect = torch.histc(
|
||||||
intersect.float(), bins=(num_classes), min=0, max=num_classes)
|
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||||
area_pred_label = torch.histc(
|
area_pred_label = torch.histc(
|
||||||
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
|
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||||
area_label = torch.histc(
|
area_label = torch.histc(
|
||||||
label.float(), bins=(num_classes), min=0, max=num_classes)
|
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||||
area_union = area_pred_label + area_label - area_intersect
|
area_union = area_pred_label + area_label - area_intersect
|
||||||
return area_intersect, area_union, area_pred_label, area_label
|
return area_intersect, area_union, area_pred_label, area_label
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,11 @@ def test_metrics():
|
||||||
ignore_index = 255
|
ignore_index = 255
|
||||||
results = np.random.randint(0, num_classes, size=pred_size)
|
results = np.random.randint(0, num_classes, size=pred_size)
|
||||||
label = np.random.randint(0, num_classes, size=pred_size)
|
label = np.random.randint(0, num_classes, size=pred_size)
|
||||||
|
|
||||||
|
# Test the availability of arg: ignore_index.
|
||||||
label[:, 2, 5:10] = ignore_index
|
label[:, 2, 5:10] = ignore_index
|
||||||
|
|
||||||
|
# Test the correctness of the implementation of mIoU calculation.
|
||||||
all_acc, acc, iou = eval_metrics(
|
all_acc, acc, iou = eval_metrics(
|
||||||
results, label, num_classes, ignore_index, metrics='mIoU')
|
results, label, num_classes, ignore_index, metrics='mIoU')
|
||||||
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
|
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
|
||||||
|
@ -72,7 +76,7 @@ def test_metrics():
|
||||||
assert all_acc == all_acc_l
|
assert all_acc == all_acc_l
|
||||||
assert np.allclose(acc, acc_l)
|
assert np.allclose(acc, acc_l)
|
||||||
assert np.allclose(iou, iou_l)
|
assert np.allclose(iou, iou_l)
|
||||||
|
# Test the correctness of the implementation of mDice calculation.
|
||||||
all_acc, acc, dice = eval_metrics(
|
all_acc, acc, dice = eval_metrics(
|
||||||
results, label, num_classes, ignore_index, metrics='mDice')
|
results, label, num_classes, ignore_index, metrics='mDice')
|
||||||
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
|
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
|
||||||
|
@ -80,7 +84,7 @@ def test_metrics():
|
||||||
assert all_acc == all_acc_l
|
assert all_acc == all_acc_l
|
||||||
assert np.allclose(acc, acc_l)
|
assert np.allclose(acc, acc_l)
|
||||||
assert np.allclose(dice, dice_l)
|
assert np.allclose(dice, dice_l)
|
||||||
|
# Test the correctness of the implementation of joint calculation.
|
||||||
all_acc, acc, iou, dice = eval_metrics(
|
all_acc, acc, iou, dice = eval_metrics(
|
||||||
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
|
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
|
||||||
assert all_acc == all_acc_l
|
assert all_acc == all_acc_l
|
||||||
|
@ -88,6 +92,8 @@ def test_metrics():
|
||||||
assert np.allclose(iou, iou_l)
|
assert np.allclose(iou, iou_l)
|
||||||
assert np.allclose(dice, dice_l)
|
assert np.allclose(dice, dice_l)
|
||||||
|
|
||||||
|
# Test the correctness of calculation when arg: num_classes is larger
|
||||||
|
# than the maximum value of input maps.
|
||||||
results = np.random.randint(0, 5, size=pred_size)
|
results = np.random.randint(0, 5, size=pred_size)
|
||||||
label = np.random.randint(0, 4, size=pred_size)
|
label = np.random.randint(0, 4, size=pred_size)
|
||||||
all_acc, acc, iou = eval_metrics(
|
all_acc, acc, iou = eval_metrics(
|
||||||
|
@ -121,6 +127,17 @@ def test_metrics():
|
||||||
assert dice[-1] == -1
|
assert dice[-1] == -1
|
||||||
assert iou[-1] == -1
|
assert iou[-1] == -1
|
||||||
|
|
||||||
|
# Test the bug which is caused by torch.histc.
|
||||||
|
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
|
||||||
|
# When the arg:bins is set to be same as arg:max,
|
||||||
|
# some channels of mIoU may be nan.
|
||||||
|
results = np.array([np.repeat(31, 59)])
|
||||||
|
label = np.array([np.arange(59)])
|
||||||
|
num_classes = 59
|
||||||
|
all_acc, acc, iou = eval_metrics(
|
||||||
|
results, label, num_classes, ignore_index=255, metrics='mIoU')
|
||||||
|
assert not np.any(np.isnan(iou))
|
||||||
|
|
||||||
|
|
||||||
def test_mean_iou():
|
def test_mean_iou():
|
||||||
pred_size = (10, 30, 30)
|
pred_size = (10, 30, 30)
|
||||||
|
@ -182,7 +199,7 @@ def test_filename_inputs():
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
pred_size = (10, 512, 1024)
|
pred_size = (10, 30, 30)
|
||||||
num_classes = 19
|
num_classes = 19
|
||||||
ignore_index = 255
|
ignore_index = 255
|
||||||
results = np.random.randint(0, num_classes, size=pred_size)
|
results = np.random.randint(0, num_classes, size=pred_size)
|
||||||
|
|
Loading…
Reference in New Issue