diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index f055faee2..7e42d6622 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -60,6 +60,10 @@ class CustomDataset(Dataset): Default: False classes (str | Sequence[str], optional): Specify classes to load. If is None, ``cls.CLASSES`` will be used. Default: None. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, and + self.PALETTE is None, random palette will be generated. + Default: None """ CLASSES = None @@ -77,7 +81,8 @@ class CustomDataset(Dataset): test_mode=False, ignore_index=255, reduce_zero_label=False, - classes=None): + classes=None, + palette=None): self.pipeline = Compose(pipeline) self.img_dir = img_dir self.img_suffix = img_suffix @@ -89,7 +94,8 @@ class CustomDataset(Dataset): self.ignore_index = ignore_index self.reduce_zero_label = reduce_zero_label self.label_map = None - self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes) + self.CLASSES, self.PALETTE = self.get_classes_and_palette( + classes, palette) # join paths if data_root is specified if self.data_root is not None: @@ -241,7 +247,7 @@ class CustomDataset(Dataset): return gt_seg_maps - def get_classes_and_palette(self, classes=None): + def get_classes_and_palette(self, classes=None, palette=None): """Get class names of current dataset. Args: @@ -250,6 +256,9 @@ class CustomDataset(Dataset): string, take it as a file name. The file contains the name of classes where each line contains one class name. If classes is a tuple or list, override the CLASSES defined by the dataset. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, random + palette will be generated. Default: None """ if classes is None: self.custom_classes = False @@ -278,11 +287,11 @@ class CustomDataset(Dataset): else: self.label_map[i] = classes.index(c) - palette = self.get_palette_for_custom_classes() + palette = self.get_palette_for_custom_classes(class_names, palette) return class_names, palette - def get_palette_for_custom_classes(self): + def get_palette_for_custom_classes(self, class_names, palette=None): if self.label_map is not None: # return subset of palette @@ -293,8 +302,11 @@ class CustomDataset(Dataset): palette.append(self.PALETTE[old_id]) palette = type(self.PALETTE)(palette) - else: - palette = self.PALETTE + elif palette is None: + if self.PALETTE is None: + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + else: + palette = self.PALETTE return palette diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index cb178b2b0..d7e44f50e 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -231,3 +231,33 @@ def test_custom_classes_override_default(dataset, classes): test_mode=True) assert custom_dataset.CLASSES == original_classes + + +@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) +@patch('mmseg.datasets.CustomDataset.__getitem__', + MagicMock(side_effect=lambda idx: idx)) +def test_custom_dataset_random_palette_is_generated(): + dataset = CustomDataset( + pipeline=[], + img_dir=MagicMock(), + split=MagicMock(), + classes=('bus', 'car'), + test_mode=True) + assert len(dataset.PALETTE) == 2 + for class_color in dataset.PALETTE: + assert len(class_color) == 3 + assert all(x >= 0 and x <= 255 for x in class_color) + + +@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock) +@patch('mmseg.datasets.CustomDataset.__getitem__', + MagicMock(side_effect=lambda idx: idx)) +def test_custom_dataset_custom_palette(): + dataset = CustomDataset( + pipeline=[], + img_dir=MagicMock(), + split=MagicMock(), + classes=('bus', 'car'), + palette=[[100, 100, 100], [200, 200, 200]], + test_mode=True) + assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])