Support custom palette (#157)

* Fix split

* Update tests/test_data/test_dataset.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
David de la Iglesia Castro 2020-09-30 12:02:08 +02:00 committed by GitHub
parent 93a2456def
commit e7240c8cf1
2 changed files with 49 additions and 7 deletions

View File

@ -60,6 +60,10 @@ class CustomDataset(Dataset):
Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, and
self.PALETTE is None, random palette will be generated.
Default: None
"""
CLASSES = None
@ -77,7 +81,8 @@ class CustomDataset(Dataset):
test_mode=False,
ignore_index=255,
reduce_zero_label=False,
classes=None):
classes=None,
palette=None):
self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
@ -89,7 +94,8 @@ class CustomDataset(Dataset):
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
# join paths if data_root is specified
if self.data_root is not None:
@ -241,7 +247,7 @@ class CustomDataset(Dataset):
return gt_seg_maps
def get_classes_and_palette(self, classes=None):
def get_classes_and_palette(self, classes=None, palette=None):
"""Get class names of current dataset.
Args:
@ -250,6 +256,9 @@ class CustomDataset(Dataset):
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, random
palette will be generated. Default: None
"""
if classes is None:
self.custom_classes = False
@ -278,11 +287,11 @@ class CustomDataset(Dataset):
else:
self.label_map[i] = classes.index(c)
palette = self.get_palette_for_custom_classes()
palette = self.get_palette_for_custom_classes(class_names, palette)
return class_names, palette
def get_palette_for_custom_classes(self):
def get_palette_for_custom_classes(self, class_names, palette=None):
if self.label_map is not None:
# return subset of palette
@ -293,8 +302,11 @@ class CustomDataset(Dataset):
palette.append(self.PALETTE[old_id])
palette = type(self.PALETTE)(palette)
else:
palette = self.PALETTE
elif palette is None:
if self.PALETTE is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
else:
palette = self.PALETTE
return palette

View File

@ -231,3 +231,33 @@ def test_custom_classes_override_default(dataset, classes):
test_mode=True)
assert custom_dataset.CLASSES == original_classes
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
dataset = CustomDataset(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=('bus', 'car'),
test_mode=True)
assert len(dataset.PALETTE) == 2
for class_color in dataset.PALETTE:
assert len(class_color) == 3
assert all(x >= 0 and x <= 255 for x in class_color)
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
dataset = CustomDataset(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=('bus', 'car'),
palette=[[100, 100, 100], [200, 200, 200]],
test_mode=True)
assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])