Fix mIoU calculatiton range (#471)

* Fix fence(IoU) = 0 when training on PascalContextDataset59;

* Add a test case in test_metrics() of tests/test_metrics.py to test the bug caused by torch.histc;

* Update tests/test_metrics.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
pull/495/head
sennnnn 2021-04-14 23:37:23 +08:00 committed by GitHub
parent 789d1a142b
commit fb24bf54b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 6 deletions

View File

@ -57,11 +57,11 @@ def intersect_and_union(pred_label,
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=0, max=num_classes)
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_label = torch.histc(
label.float(), bins=(num_classes), min=0, max=num_classes)
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label

View File

@ -64,7 +64,11 @@ def test_metrics():
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)
# Test the availability of arg: ignore_index.
label[:, 2, 5:10] = ignore_index
# Test the correctness of the implementation of mIoU calculation.
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index, metrics='mIoU')
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
@ -72,7 +76,7 @@ def test_metrics():
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
# Test the correctness of the implementation of mDice calculation.
all_acc, acc, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics='mDice')
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
@ -80,7 +84,7 @@ def test_metrics():
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(dice, dice_l)
# Test the correctness of the implementation of joint calculation.
all_acc, acc, iou, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
assert all_acc == all_acc_l
@ -88,6 +92,8 @@ def test_metrics():
assert np.allclose(iou, iou_l)
assert np.allclose(dice, dice_l)
# Test the correctness of calculation when arg: num_classes is larger
# than the maximum value of input maps.
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
all_acc, acc, iou = eval_metrics(
@ -121,6 +127,17 @@ def test_metrics():
assert dice[-1] == -1
assert iou[-1] == -1
# Test the bug which is caused by torch.histc.
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
# When the arg:bins is set to be same as arg:max,
# some channels of mIoU may be nan.
results = np.array([np.repeat(31, 59)])
label = np.array([np.arange(59)])
num_classes = 59
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index=255, metrics='mIoU')
assert not np.any(np.isnan(iou))
def test_mean_iou():
pred_size = (10, 30, 30)
@ -182,7 +199,7 @@ def test_filename_inputs():
filenames.append(filename)
return filenames
pred_size = (10, 512, 1024)
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)