[Fix] Switch order of `reduce_zero_label` and applying `label_map` (#2500)

## Motivation

I want to fix a bug through this PR. The bug occurs when two options --
`reduce_zero_label=True`, and custom classes are used.
`reduce_zero_label` remaps the GT seg labels by remapping the zero-class
to 255 which is ignored. Conceptually, this should occur *before* the
`label_map` is applied, which maps *already reduced labels*. However,
currently, the `label_map` is applied before the zero label is reduced.

## Modification

The modification is simple:
- I've just interchanged the order of the two operations by moving 4
lines from bottom to top.
- I've added a test that passes when the fix is introduced, and fails on
the original `master` branch.

## BC-breaking (Optional)

I do not anticipate this change braking any backward-compatibility.

## Checklist

- [x] Pre-commit or other linting tools are used to fix the potential
lint issues.
  - _I've fixed all linting/pre-commit errors._
- [x] The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
  - _I've added a unit test._ 
- [x] If the modification has potential influence on downstream
projects, this PR should be tested with downstream projects, like MMDet
or MMDet3D.
  - _I don't think this change affects MMDet or MMDet3D._
- [x] The documentation has been modified accordingly, like docstring or
example tutorials.
- _This change fixes an existing bug and doesn't require modifying any
documentation/docstring._
pull/2515/head
Siddharth Ancha 2023-01-19 02:01:40 -05:00 committed by GitHub
parent 6cb7fe0c51
commit 5d49918b3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 10 deletions

View File

@ -63,14 +63,14 @@ def intersect_and_union(pred_label,
else:
label = torch.from_numpy(label)
if label_map is not None:
label_copy = label.clone()
for old_id, new_id in label_map.items():
label[label_copy == old_id] = new_id
if reduce_zero_label:
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
if label_map is not None:
label_copy = label.clone()
for old_id, new_id in label_map.items():
label[label_copy == old_id] = new_id
mask = (label != ignore_index)
pred_label = pred_label[mask]

View File

@ -133,6 +133,12 @@ class LoadAnnotations(object):
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze().astype(np.uint8)
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = 255
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == 254] = 255
# modify if custom classes
if results.get('label_map', None) is not None:
# Add deep copy to solve bug of repeatedly
@ -141,12 +147,6 @@ class LoadAnnotations(object):
gt_semantic_seg_copy = gt_semantic_seg.copy()
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = 255
gt_semantic_seg = gt_semantic_seg - 1
gt_semantic_seg[gt_semantic_seg == 254] = 255
results['gt_semantic_seg'] = gt_semantic_seg
results['seg_fields'].append('gt_semantic_seg')
return results

View File

@ -177,6 +177,61 @@ class TestLoading(object):
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test with removing a class and reducing zero label simultaneously
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
# since reduce_zero_label is True, there are only 4 real classes.
# if the full set of classes is ["A", "B", "C", "D"], the
# following label map simulates the dataset option
# classes=["A", "C", "D"] which removes class "B".
label_map={
0: 0,
1: -1, # simulate removing class 1
2: 1,
3: 2
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
# reduce zero label
load_anns = LoadAnnotations(reduce_zero_label=True)
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255
true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0
true_mask[2:4, 6:8] = -1 # 2s are reduced to class 1 which is removed
true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1
true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test no custom classes
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, test_gt)
# test no custom classes
results = dict(
img_info=dict(filename=img_path),