Add support for custom classes (#71)

* Support for custom classes

* Fix test

* Fix pre-commit

* Add pipeline logic for custom classes

* Fix minor issues, fix test

* Fix issues from PR review

* Fix tests

* Remove palette as str

* Rename old_to_new_ids to label_map

* Test for load_anns

* Remove get_palette function

* fixed temp

* Add subset of palette, remove palette as arg

* minor update

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
Iago González 2020-09-16 15:33:01 +02:00 committed by GitHub
parent e2371a196e
commit 7c6fa48411
4 changed files with 233 additions and 3 deletions

View File

@ -58,6 +58,8 @@ class CustomDataset(Dataset):
ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
"""
CLASSES = None
@ -74,7 +76,8 @@ class CustomDataset(Dataset):
data_root=None,
test_mode=False,
ignore_index=255,
reduce_zero_label=False):
reduce_zero_label=False,
classes=None):
self.pipeline = Compose(pipeline)
self.img_dir = img_dir
self.img_suffix = img_suffix
@ -85,6 +88,8 @@ class CustomDataset(Dataset):
self.test_mode = test_mode
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
# join paths if data_root is specified
if self.data_root is not None:
@ -160,6 +165,8 @@ class CustomDataset(Dataset):
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['seg_fields'] = []
if self.custom_classes:
results['label_map'] = self.label_map
def __getitem__(self, idx):
"""Get training/test data after pipeline.
@ -220,6 +227,10 @@ class CustomDataset(Dataset):
for img_info in self.img_infos:
gt_seg_map = mmcv.imread(
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
# modify if custom classes
if self.label_map is not None:
for old_id, new_id in self.label_map.items():
gt_seg_map[gt_seg_map == old_id] = new_id
if self.reduce_zero_label:
# avoid using underflow conversion
gt_seg_map[gt_seg_map == 0] = 255
@ -230,6 +241,63 @@ class CustomDataset(Dataset):
return gt_seg_maps
def get_classes_and_palette(self, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
"""
if classes is None:
self.custom_classes = False
return self.CLASSES, self.PALETTE
self.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES:
if not set(classes).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values
# are the new label ids.
# used for changing pixel labels in load_annotations.
self.label_map = {}
for i, c in enumerate(self.CLASSES):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = classes.index(c)
palette = self.get_palette_for_custom_classes()
return class_names, palette
def get_palette_for_custom_classes(self):
if self.label_map is not None:
# return subset of palette
palette = []
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
if new_id != -1:
palette.append(self.PALETTE[old_id])
palette = type(self.PALETTE)(palette)
else:
palette = self.PALETTE
return palette
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
"""Evaluate the dataset.

View File

@ -132,6 +132,10 @@ class LoadAnnotations(object):
gt_semantic_seg = mmcv.imfrombytes(
img_bytes, flag='unchanged',
backend=self.imdecode_backend).squeeze().astype(np.uint8)
# modify if custom classes
if results.get('label_map', None) is not None:
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
# reduce zero_label
if self.reduce_zero_label:
# avoid using underflow conversion

View File

@ -5,8 +5,9 @@ import numpy as np
import pytest
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
CustomDataset, PascalVOCDataset, RepeatDataset)
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, PascalVOCDataset,
RepeatDataset)
def test_classes():
@ -171,3 +172,62 @@ def test_custom_dataset():
assert 'mIoU' in eval_results
assert 'mAcc' in eval_results
assert 'aAcc' in eval_results
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
('ADE20KDataset', ('wall', 'building')),
('CityscapesDataset', ('road', 'sidewalk')),
('CustomDataset', ('bus', 'car')),
('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
dataset_class = DATASETS.get(dataset)
original_classes = dataset_class.CLASSES
# Test setting classes as a tuple
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=classes,
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == classes
# Test setting classes as a list
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=list(classes),
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == list(classes)
# Test overriding not a subset
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=[classes[0]],
test_mode=True)
assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == [classes[0]]
# Test default behavior
custom_dataset = dataset_class(
pipeline=[],
img_dir=MagicMock(),
split=MagicMock(),
classes=None,
test_mode=True)
assert custom_dataset.CLASSES == original_classes

View File

@ -1,6 +1,8 @@
import copy
import os.path as osp
import tempfile
import mmcv
import numpy as np
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
@ -98,3 +100,99 @@ class TestLoading(object):
# this image is saved by PIL
assert results['gt_semantic_seg'].shape == (288, 512)
assert results['gt_semantic_seg'].dtype == np.uint8
def test_load_seg_custom_classes(self):
test_img = np.random.rand(10, 10)
test_gt = np.zeros_like(test_img)
test_gt[2:4, 2:4] = 1
test_gt[2:4, 6:8] = 2
test_gt[6:8, 2:4] = 3
test_gt[6:8, 6:8] = 4
tmp_dir = tempfile.TemporaryDirectory()
img_path = osp.join(tmp_dir.name, 'img.jpg')
gt_path = osp.join(tmp_dir.name, 'gt.png')
mmcv.imwrite(test_img, img_path)
mmcv.imwrite(test_gt, gt_path)
# test only train with label with id 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 1,
4: 0
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test only train with label with id 4 and 3
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
label_map={
0: 0,
1: 0,
2: 0,
3: 2,
4: 1
},
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
true_mask = np.zeros_like(gt_array)
true_mask[6:8, 2:4] = 2
true_mask[6:8, 6:8] = 1
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, true_mask)
# test no custom classes
results = dict(
img_info=dict(filename=img_path),
ann_info=dict(seg_map=gt_path),
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))
load_anns = LoadAnnotations()
results = load_anns(copy.deepcopy(results))
gt_array = results['gt_semantic_seg']
assert results['seg_fields'] == ['gt_semantic_seg']
assert gt_array.shape == (10, 10)
assert gt_array.dtype == np.uint8
np.testing.assert_array_equal(gt_array, test_gt)
tmp_dir.cleanup()