Add support for custom classes (#71)

* Support for custom classes

* Fix test

* Fix pre-commit

* Add pipeline logic for custom classes

* Fix minor issues, fix test

* Fix issues from PR review

* Fix tests

* Remove palette as str

* Rename old_to_new_ids to label_map

* Test for load_anns

* Remove get_palette function

* fixed temp

* Add subset of palette, remove palette as arg

* minor update

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
Iago González 2020-09-16 15:33:01 +02:00 committed by GitHub
parent e2371a196e
commit 7c6fa48411
4 changed files with 233 additions and 3 deletions

View File

@ -58,6 +58,8 @@ class CustomDataset(Dataset):
ignore_index (int): The label index to be ignored. Default: 255 ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored. reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
""" """
CLASSES = None CLASSES = None
@ -74,7 +76,8 @@ class CustomDataset(Dataset):
data_root=None, data_root=None,
test_mode=False, test_mode=False,
ignore_index=255, ignore_index=255,
reduce_zero_label=False): reduce_zero_label=False,
classes=None):
self.pipeline = Compose(pipeline) self.pipeline = Compose(pipeline)
self.img_dir = img_dir self.img_dir = img_dir
self.img_suffix = img_suffix self.img_suffix = img_suffix
@ -85,6 +88,8 @@ class CustomDataset(Dataset):
self.test_mode = test_mode self.test_mode = test_mode
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label self.reduce_zero_label = reduce_zero_label
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
# join paths if data_root is specified # join paths if data_root is specified
if self.data_root is not None: if self.data_root is not None:
@ -160,6 +165,8 @@ class CustomDataset(Dataset):
def pre_pipeline(self, results): def pre_pipeline(self, results):
"""Prepare results dict for pipeline.""" """Prepare results dict for pipeline."""
results['seg_fields'] = [] results['seg_fields'] = []
if self.custom_classes:
results['label_map'] = self.label_map
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get training/test data after pipeline. """Get training/test data after pipeline.
@ -220,6 +227,10 @@ class CustomDataset(Dataset):
for img_info in self.img_infos: for img_info in self.img_infos:
gt_seg_map = mmcv.imread( gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow') img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
# modify if custom classes
if self.label_map is not None:
for old_id, new_id in self.label_map.items():
gt_seg_map[gt_seg_map == old_id] = new_id
if self.reduce_zero_label: if self.reduce_zero_label:
# avoid using underflow conversion # avoid using underflow conversion
gt_seg_map[gt_seg_map == 0] = 255 gt_seg_map[gt_seg_map == 0] = 255
@ -230,6 +241,63 @@ class CustomDataset(Dataset):
return gt_seg_maps return gt_seg_maps
def get_classes_and_palette(self, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
"""
if classes is None:
self.custom_classes = False
return self.CLASSES, self.PALETTE
self.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES:
if not set(classes).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values
# are the new label ids.
# used for changing pixel labels in load_annotations.
self.label_map = {}
for i, c in enumerate(self.CLASSES):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = classes.index(c)
palette = self.get_palette_for_custom_classes()
return class_names, palette
def get_palette_for_custom_classes(self):
if self.label_map is not None:
# return subset of palette
palette = []
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
if new_id != -1:
palette.append(self.PALETTE[old_id])
palette = type(self.PALETTE)(palette)
else:
palette = self.PALETTE
return palette
def evaluate(self, results, metric='mIoU', logger=None, **kwargs): def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset. """Evaluate the dataset.

View File

@ -132,6 +132,10 @@ class LoadAnnotations(object):
gt_semantic_seg = mmcv.imfrombytes( gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged', img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze().astype(np.uint8) backend=self.imdecode_backend).squeeze().astype(np.uint8)
# modify if custom classes
if results.get('label_map', None) is not None:
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
# reduce zero_label # reduce zero_label
if self.reduce_zero_label: if self.reduce_zero_label:
# avoid using underflow conversion # avoid using underflow conversion

View File

@ -5,8 +5,9 @@ import numpy as np
import pytest import pytest
from mmseg.core.evaluation import get_classes, get_palette from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset, from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
CustomDataset, PascalVOCDataset, RepeatDataset) ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)
def test_classes(): def test_classes():
@ -171,3 +172,62 @@ def test_custom_dataset():
assert 'mIoU' in eval_results assert 'mIoU' in eval_results
assert 'mAcc' in eval_results assert 'mAcc' in eval_results
assert 'aAcc' in eval_results assert 'aAcc' in eval_results
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
('ADE20KDataset', ('wall', 'building')),
('CityscapesDataset', ('road', 'sidewalk')),
('CustomDataset', ('bus', 'car')),
('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
dataset_class = DATASETS.get(dataset)
original_classes = dataset_class.CLASSES
# Test setting classes as a tuple
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=classes,
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == classes
# Test setting classes as a list
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=list(classes),
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == list(classes)
# Test overriding not a subset
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=[classes[0]],
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == [classes[0]]
# Test default behavior
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=None,
test_mode=True)
assert custom_dataset.CLASSES == original_classes

View File

@ -1,6 +1,8 @@
import copy import copy
import os.path as osp import os.path as osp
import tempfile
import mmcv
import numpy as np import numpy as np
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
@ -98,3 +100,99 @@ class TestLoading(object):
# this image is saved by PIL # this image is saved by PIL
assert results['gt_semantic_seg'].shape == (288, 512) assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8 assert results['gt_semantic_seg'].dtype == np.uint8
def test_load_seg_custom_classes(self):
test_img = np.random.rand(10, 10)
test_gt = np.zeros_like(test_img)
test_gt[2:4, 2:4] = 1
test_gt[2:4, 6:8] = 2
test_gt[6:8, 2:4] = 3
test_gt[6:8, 6:8] = 4
tmp_dir = tempfile.TemporaryDirectory()
img_path = osp.join(tmp_dir.name, 'img.jpg')
gt_path = osp.join(tmp_dir.name, 'gt.png')
mmcv.imwrite(test_img, img_path)
mmcv.imwrite(test_gt, gt_path)
# test only train with label with id 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 1,
4: 0
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test only train with label with id 4 and 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 2,
4: 1
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 2
true_mask[6:8, 6:8] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test no custom classes
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, test_gt)
tmp_dir.cleanup()