From 7c6fa484110f855e517d0d54967c2610c947fffb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iago=20Gonz=C3=A1lez?=
Date: Wed, 16 Sep 2020 15:33:01 +0200
Subject: [PATCH] Add support for custom classes (#71)

* Support for custom classes

* Fix test

* Fix pre-commit

* Add pipeline logic for custom classes

* Fix minor issues, fix test

* Fix issues from PR review

* Fix tests

* Remove palette as str

* Rename old_to_new_ids to label_map

* Test for load_anns

* Remove get_palette function

* fixed temp

* Add subset of palette, remove palette as arg

* minor update

Co-authored-by: Jiarui XU
---
 mmseg/datasets/custom.py            | 70 ++++++++++++++++++++-
 mmseg/datasets/pipelines/loading.py |  4 ++
 tests/test_data/test_dataset.py     | 64 ++++++++++++++++++-
 tests/test_data/test_loading.py     | 98 +++++++++++++++++++++++++++++
 4 files changed, 233 insertions(+), 3 deletions(-)

diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py
index 92d17c525..91d7b0b5e 100644
--- a/mmseg/datasets/custom.py
+++ b/mmseg/datasets/custom.py
@@ -58,6 +58,8 @@ class CustomDataset(Dataset):
         ignore_index (int): The label index to be ignored. Default: 255
         reduce_zero_label (bool): Whether to mark label zero as ignored.
             Default: False
+        classes (str | Sequence[str], optional): Specify classes to load.
+            If it is None, ``cls.CLASSES`` will be used. Default: None.
     """
 
     CLASSES = None
@@ -74,7 +76,8 @@ class CustomDataset(Dataset):
                  data_root=None,
                  test_mode=False,
                  ignore_index=255,
-                 reduce_zero_label=False):
+                 reduce_zero_label=False,
+                 classes=None):
         self.pipeline = Compose(pipeline)
         self.img_dir = img_dir
         self.img_suffix = img_suffix
@@ -85,6 +88,8 @@ class CustomDataset(Dataset):
         self.test_mode = test_mode
         self.ignore_index = ignore_index
         self.reduce_zero_label = reduce_zero_label
+        self.label_map = None
+        self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
 
         # join paths if data_root is specified
         if self.data_root is not None:
@@ -160,6 +165,8 @@ class CustomDataset(Dataset):
     def pre_pipeline(self, results):
         """Prepare results dict for pipeline."""
         results['seg_fields'] = []
+        if self.custom_classes:
+            results['label_map'] = self.label_map
 
     def __getitem__(self, idx):
         """Get training/test data after pipeline.
@@ -220,6 +227,10 @@ class CustomDataset(Dataset):
         for img_info in self.img_infos:
             gt_seg_map = mmcv.imread(
                 img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
+            # modify if custom classes
+            if self.label_map is not None:
+                for old_id, new_id in self.label_map.items():
+                    gt_seg_map[gt_seg_map == old_id] = new_id
             if self.reduce_zero_label:
                 # avoid using underflow conversion
                 gt_seg_map[gt_seg_map == 0] = 255
@@ -230,6 +241,63 @@ class CustomDataset(Dataset):
 
         return gt_seg_maps
 
+    def get_classes_and_palette(self, classes=None):
+        """Get class names of current dataset.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+        """
+        if classes is None:
+            self.custom_classes = False
+            return self.CLASSES, self.PALETTE
+
+        self.custom_classes = True
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmcv.list_from_file(classes)
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        else:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        if self.CLASSES:
+            if not set(class_names).issubset(self.CLASSES):
+                raise ValueError('classes is not a subset of CLASSES.')
+
+            # dictionary, its keys are the old label ids and its values
+            # are the new label ids.
+            # used for changing pixel labels in load_annotations.
+            self.label_map = {}
+            for i, c in enumerate(self.CLASSES):
+                if c not in class_names:
+                    self.label_map[i] = -1
+                else:
+                    self.label_map[i] = class_names.index(c)
+
+        palette = self.get_palette_for_custom_classes()
+
+        return class_names, palette
+
+    def get_palette_for_custom_classes(self):
+
+        if self.label_map is not None:
+            # return subset of palette
+            palette = []
+            for old_id, new_id in sorted(
+                    self.label_map.items(), key=lambda x: x[1]):
+                if new_id != -1:
+                    palette.append(self.PALETTE[old_id])
+            palette = type(self.PALETTE)(palette)
+
+        else:
+            palette = self.PALETTE
+
+        return palette
+
     def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
         """Evaluate the dataset.
 
diff --git a/mmseg/datasets/pipelines/loading.py b/mmseg/datasets/pipelines/loading.py
index 978626910..a98ddf20b 100644
--- a/mmseg/datasets/pipelines/loading.py
+++ b/mmseg/datasets/pipelines/loading.py
@@ -132,6 +132,10 @@ class LoadAnnotations(object):
         gt_semantic_seg = mmcv.imfrombytes(
             img_bytes, flag='unchanged',
             backend=self.imdecode_backend).squeeze().astype(np.uint8)
+        # modify if custom classes
+        if results.get('label_map', None) is not None:
+            for old_id, new_id in results['label_map'].items():
+                gt_semantic_seg[gt_semantic_seg == old_id] = new_id
         # reduce zero_label
         if self.reduce_zero_label:
             # avoid using underflow conversion
diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py
index ee6d2c47a..cb178b2b0 100644
--- a/tests/test_data/test_dataset.py
+++ b/tests/test_data/test_dataset.py
@@ -5,8 +5,9 @@ import numpy as np
 import pytest
 
 from mmseg.core.evaluation import get_classes, get_palette
-from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
-                            CustomDataset, PascalVOCDataset, RepeatDataset)
+from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
+                            ConcatDataset, CustomDataset, PascalVOCDataset,
+                            RepeatDataset)
 
 
 def test_classes():
@@ -171,3 +172,62 @@ def test_custom_dataset():
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
+
+
+@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
+@patch('mmseg.datasets.CustomDataset.__getitem__',
+       MagicMock(side_effect=lambda idx: idx))
+@pytest.mark.parametrize('dataset, classes', [
+    ('ADE20KDataset', ('wall', 'building')),
+    ('CityscapesDataset', ('road', 'sidewalk')),
+    ('CustomDataset', ('bus', 'car')),
+    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
+])
+def test_custom_classes_override_default(dataset, classes):
+
+    dataset_class = DATASETS.get(dataset)
+
+    original_classes = dataset_class.CLASSES
+
+    # Test setting classes as a tuple
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=classes,
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == classes
+
+    # Test setting classes as a list
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=list(classes),
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == list(classes)
+
+    # Test overriding with a single class
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=[classes[0]],
+        test_mode=True)
+
+    assert custom_dataset.CLASSES != original_classes
+    assert custom_dataset.CLASSES == [classes[0]]
+
+    # Test default behavior
+    custom_dataset = dataset_class(
+        pipeline=[],
+        img_dir=MagicMock(),
+        split=MagicMock(),
+        classes=None,
+        test_mode=True)
+
+    assert custom_dataset.CLASSES == original_classes
diff --git a/tests/test_data/test_loading.py b/tests/test_data/test_loading.py
index 653b3daf4..e8aa5d313 100644
--- a/tests/test_data/test_loading.py
+++ b/tests/test_data/test_loading.py
@@ -1,6 +1,8 @@
 import copy
 import os.path as osp
+import tempfile
 
+import mmcv
 import numpy as np
 
 from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
@@ -98,3 +100,99 @@ class TestLoading(object):
         # this image is saved by PIL
         assert results['gt_semantic_seg'].shape == (288, 512)
         assert results['gt_semantic_seg'].dtype == np.uint8
+
+    def test_load_seg_custom_classes(self):
+
+        test_img = np.random.rand(10, 10)
+        test_gt = np.zeros_like(test_img)
+        test_gt[2:4, 2:4] = 1
+        test_gt[2:4, 6:8] = 2
+        test_gt[6:8, 2:4] = 3
+        test_gt[6:8, 6:8] = 4
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        img_path = osp.join(tmp_dir.name, 'img.jpg')
+        gt_path = osp.join(tmp_dir.name, 'gt.png')
+
+        mmcv.imwrite(test_img, img_path)
+        mmcv.imwrite(test_gt, gt_path)
+
+        # test training with only the label with id 3
+        results = dict(
+            img_info=dict(filename=img_path),
+            ann_info=dict(seg_map=gt_path),
+            label_map={
+                0: 0,
+                1: 0,
+                2: 0,
+                3: 1,
+                4: 0
+            },
+            seg_fields=[])
+
+        load_imgs = LoadImageFromFile()
+        results = load_imgs(copy.deepcopy(results))
+
+        load_anns = LoadAnnotations()
+        results = load_anns(copy.deepcopy(results))
+
+        gt_array = results['gt_semantic_seg']
+
+        true_mask = np.zeros_like(gt_array)
+        true_mask[6:8, 2:4] = 1
+
+        assert results['seg_fields'] == ['gt_semantic_seg']
+        assert gt_array.shape == (10, 10)
+        assert gt_array.dtype == np.uint8
+        np.testing.assert_array_equal(gt_array, true_mask)
+
+        # test training with only the labels with ids 4 and 3
+        results = dict(
+            img_info=dict(filename=img_path),
+            ann_info=dict(seg_map=gt_path),
+            label_map={
+                0: 0,
+                1: 0,
+                2: 0,
+                3: 2,
+                4: 1
+            },
+            seg_fields=[])
+
+        load_imgs = LoadImageFromFile()
+        results = load_imgs(copy.deepcopy(results))
+
+        load_anns = LoadAnnotations()
+        results = load_anns(copy.deepcopy(results))
+
+        gt_array = results['gt_semantic_seg']
+
+        true_mask = np.zeros_like(gt_array)
+        true_mask[6:8, 2:4] = 2
+        true_mask[6:8, 6:8] = 1
+
+        assert results['seg_fields'] == ['gt_semantic_seg']
+        assert gt_array.shape == (10, 10)
+        assert gt_array.dtype == np.uint8
+        np.testing.assert_array_equal(gt_array, true_mask)
+
+        # test no custom classes
+        results = dict(
+            img_info=dict(filename=img_path),
+            ann_info=dict(seg_map=gt_path),
+            seg_fields=[])
+
+        load_imgs = LoadImageFromFile()
+        results = load_imgs(copy.deepcopy(results))
+
+        load_anns = LoadAnnotations()
+        results = load_anns(copy.deepcopy(results))
+
+        gt_array = results['gt_semantic_seg']
+
+        assert results['seg_fields'] == ['gt_semantic_seg']
+        assert gt_array.shape == (10, 10)
+        assert gt_array.dtype == np.uint8
+        np.testing.assert_array_equal(gt_array, test_gt)
+
+        tmp_dir.cleanup()
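
A minimal usage sketch of the new ``classes`` argument, mirroring the unit tests above. It assumes an mmsegmentation installation that already includes this patch; ``load_annotations`` is mocked out, as in the tests, so the snippet runs without any dataset on disk, and the printed values are only illustrative.

from unittest.mock import MagicMock, patch

from mmseg.datasets import ADE20KDataset

# Bypass directory scanning so no real ADE20K data is required.
with patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock()):
    dataset = ADE20KDataset(
        pipeline=[],
        img_dir=MagicMock(),           # never touched: load_annotations is mocked
        classes=('wall', 'building'),  # keep only the first two ADE20K classes
        test_mode=True)

# CLASSES is overridden, PALETTE is reduced to the matching subset, and
# label_map sends every dropped class id to -1.
print(dataset.CLASSES)       # ('wall', 'building')
print(len(dataset.PALETTE))  # 2
print(dataset.label_map)     # {0: 0, 1: 1, 2: -1, 3: -1, ...}

In a training config, the same effect would typically be achieved by adding ``classes=('wall', 'building')`` (or the path to a text file with one class name per line) to the corresponding dataset dict, since the dataset builder forwards those keys to the constructor.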