Fix mIoU calculatiton range (#471)
* Fix fence(IoU) = 0 when training on PascalContextDataset59; * Add a test case in test_metrics() of tests/test_metrics.py to test the bug caused by torch.histc; * Update tests/test_metrics.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>pull/495/head
parent
789d1a142b
commit
fb24bf54b6
|
@ -57,11 +57,11 @@ def intersect_and_union(pred_label,
|
|||
|
||||
intersect = pred_label[pred_label == label]
|
||||
area_intersect = torch.histc(
|
||||
intersect.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||
area_pred_label = torch.histc(
|
||||
pred_label.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||
area_label = torch.histc(
|
||||
label.float(), bins=(num_classes), min=0, max=num_classes)
|
||||
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
||||
area_union = area_pred_label + area_label - area_intersect
|
||||
return area_intersect, area_union, area_pred_label, area_label
|
||||
|
||||
|
|
|
@ -64,7 +64,11 @@ def test_metrics():
|
|||
ignore_index = 255
|
||||
results = np.random.randint(0, num_classes, size=pred_size)
|
||||
label = np.random.randint(0, num_classes, size=pred_size)
|
||||
|
||||
# Test the availability of arg: ignore_index.
|
||||
label[:, 2, 5:10] = ignore_index
|
||||
|
||||
# Test the correctness of the implementation of mIoU calculation.
|
||||
all_acc, acc, iou = eval_metrics(
|
||||
results, label, num_classes, ignore_index, metrics='mIoU')
|
||||
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
|
||||
|
@ -72,7 +76,7 @@ def test_metrics():
|
|||
assert all_acc == all_acc_l
|
||||
assert np.allclose(acc, acc_l)
|
||||
assert np.allclose(iou, iou_l)
|
||||
|
||||
# Test the correctness of the implementation of mDice calculation.
|
||||
all_acc, acc, dice = eval_metrics(
|
||||
results, label, num_classes, ignore_index, metrics='mDice')
|
||||
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
|
||||
|
@ -80,7 +84,7 @@ def test_metrics():
|
|||
assert all_acc == all_acc_l
|
||||
assert np.allclose(acc, acc_l)
|
||||
assert np.allclose(dice, dice_l)
|
||||
|
||||
# Test the correctness of the implementation of joint calculation.
|
||||
all_acc, acc, iou, dice = eval_metrics(
|
||||
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
|
||||
assert all_acc == all_acc_l
|
||||
|
@ -88,6 +92,8 @@ def test_metrics():
|
|||
assert np.allclose(iou, iou_l)
|
||||
assert np.allclose(dice, dice_l)
|
||||
|
||||
# Test the correctness of calculation when arg: num_classes is larger
|
||||
# than the maximum value of input maps.
|
||||
results = np.random.randint(0, 5, size=pred_size)
|
||||
label = np.random.randint(0, 4, size=pred_size)
|
||||
all_acc, acc, iou = eval_metrics(
|
||||
|
@ -121,6 +127,17 @@ def test_metrics():
|
|||
assert dice[-1] == -1
|
||||
assert iou[-1] == -1
|
||||
|
||||
# Test the bug which is caused by torch.histc.
|
||||
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
|
||||
# When the arg:bins is set to be same as arg:max,
|
||||
# some channels of mIoU may be nan.
|
||||
results = np.array([np.repeat(31, 59)])
|
||||
label = np.array([np.arange(59)])
|
||||
num_classes = 59
|
||||
all_acc, acc, iou = eval_metrics(
|
||||
results, label, num_classes, ignore_index=255, metrics='mIoU')
|
||||
assert not np.any(np.isnan(iou))
|
||||
|
||||
|
||||
def test_mean_iou():
|
||||
pred_size = (10, 30, 30)
|
||||
|
@ -182,7 +199,7 @@ def test_filename_inputs():
|
|||
filenames.append(filename)
|
||||
return filenames
|
||||
|
||||
pred_size = (10, 512, 1024)
|
||||
pred_size = (10, 30, 30)
|
||||
num_classes = 19
|
||||
ignore_index = 255
|
||||
results = np.random.randint(0, num_classes, size=pred_size)
|
||||
|
|
Loading…
Reference in New Issue