diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 4615d4114..0ee15b1aa 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -349,7 +349,7 @@ class CustomDataset(Dataset): self.label_map = {} for i, c in enumerate(self.CLASSES): if c not in class_names: - self.label_map[i] = -1 + self.label_map[i] = 255 else: self.label_map[i] = class_names.index(c) @@ -364,7 +364,7 @@ class CustomDataset(Dataset): palette = [] for old_id, new_id in sorted( self.label_map.items(), key=lambda x: x[1]): - if new_id != -1: + if new_id != 255: palette.append(self.PALETTE[old_id]) palette = type(self.PALETTE)(palette) diff --git a/tests/test_data/test_loading.py b/tests/test_data/test_loading.py index d41d46023..19f495acc 100644 --- a/tests/test_data/test_loading.py +++ b/tests/test_data/test_loading.py @@ -187,7 +187,7 @@ class TestLoading(object): # classes=["A", "C", "D"] which removes class "B". label_map={ 0: 0, - 1: -1, # simulate removing class 1 + 1: 255, # simulate removing class 1 2: 1, 3: 2 }, @@ -204,7 +204,7 @@ class TestLoading(object): true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255 true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0 - true_mask[2:4, 6:8] = -1 # 2s are reduced to class 1 which is removed + true_mask[2:4, 6:8] = 255 # 2s are reduced to class 1 which is removed true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1 true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2