mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Support custom palette (#157)
* Fix split * Update tests/test_data/test_dataset.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
93a2456def
commit
e7240c8cf1
@ -60,6 +60,10 @@ class CustomDataset(Dataset):
|
||||
Default: False
|
||||
classes (str | Sequence[str], optional): Specify classes to load.
|
||||
If is None, ``cls.CLASSES`` will be used. Default: None.
|
||||
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||
The palette of segmentation map. If None is given, and
|
||||
self.PALETTE is None, random palette will be generated.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
CLASSES = None
|
||||
@ -77,7 +81,8 @@ class CustomDataset(Dataset):
|
||||
test_mode=False,
|
||||
ignore_index=255,
|
||||
reduce_zero_label=False,
|
||||
classes=None):
|
||||
classes=None,
|
||||
palette=None):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.img_dir = img_dir
|
||||
self.img_suffix = img_suffix
|
||||
@ -89,7 +94,8 @@ class CustomDataset(Dataset):
|
||||
self.ignore_index = ignore_index
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.label_map = None
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||
classes, palette)
|
||||
|
||||
# join paths if data_root is specified
|
||||
if self.data_root is not None:
|
||||
@ -241,7 +247,7 @@ class CustomDataset(Dataset):
|
||||
|
||||
return gt_seg_maps
|
||||
|
||||
def get_classes_and_palette(self, classes=None):
|
||||
def get_classes_and_palette(self, classes=None, palette=None):
|
||||
"""Get class names of current dataset.
|
||||
|
||||
Args:
|
||||
@ -250,6 +256,9 @@ class CustomDataset(Dataset):
|
||||
string, take it as a file name. The file contains the name of
|
||||
classes where each line contains one class name. If classes is
|
||||
a tuple or list, override the CLASSES defined by the dataset.
|
||||
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||
The palette of segmentation map. If None is given, random
|
||||
palette will be generated. Default: None
|
||||
"""
|
||||
if classes is None:
|
||||
self.custom_classes = False
|
||||
@ -278,11 +287,11 @@ class CustomDataset(Dataset):
|
||||
else:
|
||||
self.label_map[i] = classes.index(c)
|
||||
|
||||
palette = self.get_palette_for_custom_classes()
|
||||
palette = self.get_palette_for_custom_classes(class_names, palette)
|
||||
|
||||
return class_names, palette
|
||||
|
||||
def get_palette_for_custom_classes(self):
|
||||
def get_palette_for_custom_classes(self, class_names, palette=None):
|
||||
|
||||
if self.label_map is not None:
|
||||
# return subset of palette
|
||||
@ -293,6 +302,9 @@ class CustomDataset(Dataset):
|
||||
palette.append(self.PALETTE[old_id])
|
||||
palette = type(self.PALETTE)(palette)
|
||||
|
||||
elif palette is None:
|
||||
if self.PALETTE is None:
|
||||
palette = np.random.randint(0, 255, size=(len(class_names), 3))
|
||||
else:
|
||||
palette = self.PALETTE
|
||||
|
||||
|
@ -231,3 +231,33 @@ def test_custom_classes_override_default(dataset, classes):
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES == original_classes
|
||||
|
||||
|
||||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||
MagicMock(side_effect=lambda idx: idx))
|
||||
def test_custom_dataset_random_palette_is_generated():
|
||||
dataset = CustomDataset(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=('bus', 'car'),
|
||||
test_mode=True)
|
||||
assert len(dataset.PALETTE) == 2
|
||||
for class_color in dataset.PALETTE:
|
||||
assert len(class_color) == 3
|
||||
assert all(x >= 0 and x <= 255 for x in class_color)
|
||||
|
||||
|
||||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||
MagicMock(side_effect=lambda idx: idx))
|
||||
def test_custom_dataset_custom_palette():
|
||||
dataset = CustomDataset(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=('bus', 'car'),
|
||||
palette=[[100, 100, 100], [200, 200, 200]],
|
||||
test_mode=True)
|
||||
assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])
|
||||
|
Loading…
x
Reference in New Issue
Block a user