mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Switch order of reduce_zero_label
and applying label_map
in 1.x (#2517)
This is an almost exact duplicate of #2500 (that was made to the `master` branch) now applied to the `1.x` branch. --- ## Motivation I want to fix a bug through this PR. The bug occurs when two options -- `reduce_zero_label=True`, and custom classes are used. `reduce_zero_label` remaps the GT seg labels by remapping the zero-class to 255 which is ignored. Conceptually, this should occur *before* the `label_map` is applied, which maps *already reduced labels*. However, currently, the `label_map` is applied before the zero label is reduced. ## Modification The modification is simple: - I've just interchanged the order of the two operations by moving a few lines from bottom to top. - I've added a test that passes when the fix is introduced, and fails on the original `master` branch. ## BC-breaking (Optional) I do not anticipate this change braking any backward-compatibility. ## Checklist - [x] Pre-commit or other linting tools are used to fix the potential lint issues. - _I've fixed all linting/pre-commit errors._ - [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. - _I've added a unit test._ - [x] If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. - _I don't think this change affects MMDet or MMDet3D._ - [x] The documentation has been modified accordingly, like docstring or example tutorials. - _This change fixes an existing bug and doesn't require modifying any documentation/docstring._
This commit is contained in:
parent
6b53ec0a1f
commit
74e8b89b17
@ -96,14 +96,6 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
img_bytes, flag='unchanged',
|
||||
backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
@ -116,6 +108,14 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
|
@ -144,6 +144,43 @@ class TestLoading:
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test with removing a class and reducing zero label simultaneously
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
seg_map_path=gt_path,
|
||||
# since reduce_zero_label is True, there are only 4 real classes.
|
||||
# if the full set of classes is ["A", "B", "C", "D"], the
|
||||
# following label map simulates the dataset option
|
||||
# classes=["A", "C", "D"] which removes class "B".
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 255, # simulate removing class 1
|
||||
2: 1,
|
||||
3: 2
|
||||
},
|
||||
reduce_zero_label=True, # reduce zero label
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
# reduce zero label
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_seg_map']
|
||||
|
||||
true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255
|
||||
true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0
|
||||
true_mask[2:4, 6:8] = 255 # 2s are reduced to class 1 which is removed
|
||||
true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1
|
||||
true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2
|
||||
|
||||
assert results['seg_fields'] == ['gt_seg_map']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test no custom classes
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
|
Loading…
x
Reference in New Issue
Block a user