mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Support custom palette (#157)
* Fix split * Update tests/test_data/test_dataset.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
93a2456def
commit
e7240c8cf1
@ -60,6 +60,10 @@ class CustomDataset(Dataset):
|
|||||||
Default: False
|
Default: False
|
||||||
classes (str | Sequence[str], optional): Specify classes to load.
|
classes (str | Sequence[str], optional): Specify classes to load.
|
||||||
If is None, ``cls.CLASSES`` will be used. Default: None.
|
If is None, ``cls.CLASSES`` will be used. Default: None.
|
||||||
|
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||||
|
The palette of segmentation map. If None is given, and
|
||||||
|
self.PALETTE is None, random palette will be generated.
|
||||||
|
Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CLASSES = None
|
CLASSES = None
|
||||||
@ -77,7 +81,8 @@ class CustomDataset(Dataset):
|
|||||||
test_mode=False,
|
test_mode=False,
|
||||||
ignore_index=255,
|
ignore_index=255,
|
||||||
reduce_zero_label=False,
|
reduce_zero_label=False,
|
||||||
classes=None):
|
classes=None,
|
||||||
|
palette=None):
|
||||||
self.pipeline = Compose(pipeline)
|
self.pipeline = Compose(pipeline)
|
||||||
self.img_dir = img_dir
|
self.img_dir = img_dir
|
||||||
self.img_suffix = img_suffix
|
self.img_suffix = img_suffix
|
||||||
@ -89,7 +94,8 @@ class CustomDataset(Dataset):
|
|||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.reduce_zero_label = reduce_zero_label
|
self.reduce_zero_label = reduce_zero_label
|
||||||
self.label_map = None
|
self.label_map = None
|
||||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
|
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||||
|
classes, palette)
|
||||||
|
|
||||||
# join paths if data_root is specified
|
# join paths if data_root is specified
|
||||||
if self.data_root is not None:
|
if self.data_root is not None:
|
||||||
@ -241,7 +247,7 @@ class CustomDataset(Dataset):
|
|||||||
|
|
||||||
return gt_seg_maps
|
return gt_seg_maps
|
||||||
|
|
||||||
def get_classes_and_palette(self, classes=None):
|
def get_classes_and_palette(self, classes=None, palette=None):
|
||||||
"""Get class names of current dataset.
|
"""Get class names of current dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -250,6 +256,9 @@ class CustomDataset(Dataset):
|
|||||||
string, take it as a file name. The file contains the name of
|
string, take it as a file name. The file contains the name of
|
||||||
classes where each line contains one class name. If classes is
|
classes where each line contains one class name. If classes is
|
||||||
a tuple or list, override the CLASSES defined by the dataset.
|
a tuple or list, override the CLASSES defined by the dataset.
|
||||||
|
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||||
|
The palette of segmentation map. If None is given, random
|
||||||
|
palette will be generated. Default: None
|
||||||
"""
|
"""
|
||||||
if classes is None:
|
if classes is None:
|
||||||
self.custom_classes = False
|
self.custom_classes = False
|
||||||
@ -278,11 +287,11 @@ class CustomDataset(Dataset):
|
|||||||
else:
|
else:
|
||||||
self.label_map[i] = classes.index(c)
|
self.label_map[i] = classes.index(c)
|
||||||
|
|
||||||
palette = self.get_palette_for_custom_classes()
|
palette = self.get_palette_for_custom_classes(class_names, palette)
|
||||||
|
|
||||||
return class_names, palette
|
return class_names, palette
|
||||||
|
|
||||||
def get_palette_for_custom_classes(self):
|
def get_palette_for_custom_classes(self, class_names, palette=None):
|
||||||
|
|
||||||
if self.label_map is not None:
|
if self.label_map is not None:
|
||||||
# return subset of palette
|
# return subset of palette
|
||||||
@ -293,6 +302,9 @@ class CustomDataset(Dataset):
|
|||||||
palette.append(self.PALETTE[old_id])
|
palette.append(self.PALETTE[old_id])
|
||||||
palette = type(self.PALETTE)(palette)
|
palette = type(self.PALETTE)(palette)
|
||||||
|
|
||||||
|
elif palette is None:
|
||||||
|
if self.PALETTE is None:
|
||||||
|
palette = np.random.randint(0, 255, size=(len(class_names), 3))
|
||||||
else:
|
else:
|
||||||
palette = self.PALETTE
|
palette = self.PALETTE
|
||||||
|
|
||||||
|
@ -231,3 +231,33 @@ def test_custom_classes_override_default(dataset, classes):
|
|||||||
test_mode=True)
|
test_mode=True)
|
||||||
|
|
||||||
assert custom_dataset.CLASSES == original_classes
|
assert custom_dataset.CLASSES == original_classes
|
||||||
|
|
||||||
|
|
||||||
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||||
|
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||||
|
MagicMock(side_effect=lambda idx: idx))
|
||||||
|
def test_custom_dataset_random_palette_is_generated():
|
||||||
|
dataset = CustomDataset(
|
||||||
|
pipeline=[],
|
||||||
|
img_dir=MagicMock(),
|
||||||
|
split=MagicMock(),
|
||||||
|
classes=('bus', 'car'),
|
||||||
|
test_mode=True)
|
||||||
|
assert len(dataset.PALETTE) == 2
|
||||||
|
for class_color in dataset.PALETTE:
|
||||||
|
assert len(class_color) == 3
|
||||||
|
assert all(x >= 0 and x <= 255 for x in class_color)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||||
|
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||||
|
MagicMock(side_effect=lambda idx: idx))
|
||||||
|
def test_custom_dataset_custom_palette():
|
||||||
|
dataset = CustomDataset(
|
||||||
|
pipeline=[],
|
||||||
|
img_dir=MagicMock(),
|
||||||
|
split=MagicMock(),
|
||||||
|
classes=('bus', 'car'),
|
||||||
|
palette=[[100, 100, 100], [200, 200, 200]],
|
||||||
|
test_mode=True)
|
||||||
|
assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user