mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Add support for custom classes (#71)
* Support for custom classes * Fix test * Fix pre-commit * Add pipeline logic for custom classes * Fix minor issues, fix test * Fix issues from PR review * Fix tests * Remove palette as str * Rename old_to_new_ids to label_map * Test for load_anns * Remove get_palette function * fixed temp * Add subset of palette, remove palette as arg * minor update Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
e2371a196e
commit
7c6fa48411
@ -58,6 +58,8 @@ class CustomDataset(Dataset):
|
||||
ignore_index (int): The label index to be ignored. Default: 255
|
||||
reduce_zero_label (bool): Whether to mark label zero as ignored.
|
||||
Default: False
|
||||
classes (str | Sequence[str], optional): Specify classes to load.
|
||||
If is None, ``cls.CLASSES`` will be used. Default: None.
|
||||
"""
|
||||
|
||||
CLASSES = None
|
||||
@ -74,7 +76,8 @@ class CustomDataset(Dataset):
|
||||
data_root=None,
|
||||
test_mode=False,
|
||||
ignore_index=255,
|
||||
reduce_zero_label=False):
|
||||
reduce_zero_label=False,
|
||||
classes=None):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.img_dir = img_dir
|
||||
self.img_suffix = img_suffix
|
||||
@ -85,6 +88,8 @@ class CustomDataset(Dataset):
|
||||
self.test_mode = test_mode
|
||||
self.ignore_index = ignore_index
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.label_map = None
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(classes)
|
||||
|
||||
# join paths if data_root is specified
|
||||
if self.data_root is not None:
|
||||
@ -160,6 +165,8 @@ class CustomDataset(Dataset):
|
||||
def pre_pipeline(self, results):
|
||||
"""Prepare results dict for pipeline."""
|
||||
results['seg_fields'] = []
|
||||
if self.custom_classes:
|
||||
results['label_map'] = self.label_map
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get training/test data after pipeline.
|
||||
@ -220,6 +227,10 @@ class CustomDataset(Dataset):
|
||||
for img_info in self.img_infos:
|
||||
gt_seg_map = mmcv.imread(
|
||||
img_info['ann']['seg_map'], flag='unchanged', backend='pillow')
|
||||
# modify if custom classes
|
||||
if self.label_map is not None:
|
||||
for old_id, new_id in self.label_map.items():
|
||||
gt_seg_map[gt_seg_map == old_id] = new_id
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_seg_map[gt_seg_map == 0] = 255
|
||||
@ -230,6 +241,63 @@ class CustomDataset(Dataset):
|
||||
|
||||
return gt_seg_maps
|
||||
|
||||
def get_classes_and_palette(self, classes=None):
|
||||
"""Get class names of current dataset.
|
||||
|
||||
Args:
|
||||
classes (Sequence[str] | str | None): If classes is None, use
|
||||
default CLASSES defined by builtin dataset. If classes is a
|
||||
string, take it as a file name. The file contains the name of
|
||||
classes where each line contains one class name. If classes is
|
||||
a tuple or list, override the CLASSES defined by the dataset.
|
||||
"""
|
||||
if classes is None:
|
||||
self.custom_classes = False
|
||||
return self.CLASSES, self.PALETTE
|
||||
|
||||
self.custom_classes = True
|
||||
if isinstance(classes, str):
|
||||
# take it as a file path
|
||||
class_names = mmcv.list_from_file(classes)
|
||||
elif isinstance(classes, (tuple, list)):
|
||||
class_names = classes
|
||||
else:
|
||||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||
|
||||
if self.CLASSES:
|
||||
if not set(classes).issubset(self.CLASSES):
|
||||
raise ValueError('classes is not a subset of CLASSES.')
|
||||
|
||||
# dictionary, its keys are the old label ids and its values
|
||||
# are the new label ids.
|
||||
# used for changing pixel labels in load_annotations.
|
||||
self.label_map = {}
|
||||
for i, c in enumerate(self.CLASSES):
|
||||
if c not in class_names:
|
||||
self.label_map[i] = -1
|
||||
else:
|
||||
self.label_map[i] = classes.index(c)
|
||||
|
||||
palette = self.get_palette_for_custom_classes()
|
||||
|
||||
return class_names, palette
|
||||
|
||||
def get_palette_for_custom_classes(self):
|
||||
|
||||
if self.label_map is not None:
|
||||
# return subset of palette
|
||||
palette = []
|
||||
for old_id, new_id in sorted(
|
||||
self.label_map.items(), key=lambda x: x[1]):
|
||||
if new_id != -1:
|
||||
palette.append(self.PALETTE[old_id])
|
||||
palette = type(self.PALETTE)(palette)
|
||||
|
||||
else:
|
||||
palette = self.PALETTE
|
||||
|
||||
return palette
|
||||
|
||||
def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
|
@ -132,6 +132,10 @@ class LoadAnnotations(object):
|
||||
gt_semantic_seg = mmcv.imfrombytes(
|
||||
img_bytes, flag='unchanged',
|
||||
backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg == old_id] = new_id
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
|
@ -5,8 +5,9 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmseg.core.evaluation import get_classes, get_palette
|
||||
from mmseg.datasets import (ADE20KDataset, CityscapesDataset, ConcatDataset,
|
||||
CustomDataset, PascalVOCDataset, RepeatDataset)
|
||||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
||||
ConcatDataset, CustomDataset, PascalVOCDataset,
|
||||
RepeatDataset)
|
||||
|
||||
|
||||
def test_classes():
|
||||
@ -171,3 +172,62 @@ def test_custom_dataset():
|
||||
assert 'mIoU' in eval_results
|
||||
assert 'mAcc' in eval_results
|
||||
assert 'aAcc' in eval_results
|
||||
|
||||
|
||||
@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
|
||||
@patch('mmseg.datasets.CustomDataset.__getitem__',
|
||||
MagicMock(side_effect=lambda idx: idx))
|
||||
@pytest.mark.parametrize('dataset, classes', [
|
||||
('ADE20KDataset', ('wall', 'building')),
|
||||
('CityscapesDataset', ('road', 'sidewalk')),
|
||||
('CustomDataset', ('bus', 'car')),
|
||||
('PascalVOCDataset', ('aeroplane', 'bicycle')),
|
||||
])
|
||||
def test_custom_classes_override_default(dataset, classes):
|
||||
|
||||
dataset_class = DATASETS.get(dataset)
|
||||
|
||||
original_classes = dataset_class.CLASSES
|
||||
|
||||
# Test setting classes as a tuple
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=classes,
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES != original_classes
|
||||
assert custom_dataset.CLASSES == classes
|
||||
|
||||
# Test setting classes as a list
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=list(classes),
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES != original_classes
|
||||
assert custom_dataset.CLASSES == list(classes)
|
||||
|
||||
# Test overriding not a subset
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=[classes[0]],
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES != original_classes
|
||||
assert custom_dataset.CLASSES == [classes[0]]
|
||||
|
||||
# Test default behavior
|
||||
custom_dataset = dataset_class(
|
||||
pipeline=[],
|
||||
img_dir=MagicMock(),
|
||||
split=MagicMock(),
|
||||
classes=None,
|
||||
test_mode=True)
|
||||
|
||||
assert custom_dataset.CLASSES == original_classes
|
||||
|
@ -1,6 +1,8 @@
|
||||
import copy
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile
|
||||
@ -98,3 +100,99 @@ class TestLoading(object):
|
||||
# this image is saved by PIL
|
||||
assert results['gt_semantic_seg'].shape == (288, 512)
|
||||
assert results['gt_semantic_seg'].dtype == np.uint8
|
||||
|
||||
def test_load_seg_custom_classes(self):
|
||||
|
||||
test_img = np.random.rand(10, 10)
|
||||
test_gt = np.zeros_like(test_img)
|
||||
test_gt[2:4, 2:4] = 1
|
||||
test_gt[2:4, 6:8] = 2
|
||||
test_gt[6:8, 2:4] = 3
|
||||
test_gt[6:8, 6:8] = 4
|
||||
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
img_path = osp.join(tmp_dir.name, 'img.jpg')
|
||||
gt_path = osp.join(tmp_dir.name, 'gt.png')
|
||||
|
||||
mmcv.imwrite(test_img, img_path)
|
||||
mmcv.imwrite(test_gt, gt_path)
|
||||
|
||||
# test only train with label with id 3
|
||||
results = dict(
|
||||
img_info=dict(filename=img_path),
|
||||
ann_info=dict(seg_map=gt_path),
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 0,
|
||||
2: 0,
|
||||
3: 1,
|
||||
4: 0
|
||||
},
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_semantic_seg']
|
||||
|
||||
true_mask = np.zeros_like(gt_array)
|
||||
true_mask[6:8, 2:4] = 1
|
||||
|
||||
assert results['seg_fields'] == ['gt_semantic_seg']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test only train with label with id 4 and 3
|
||||
results = dict(
|
||||
img_info=dict(filename=img_path),
|
||||
ann_info=dict(seg_map=gt_path),
|
||||
label_map={
|
||||
0: 0,
|
||||
1: 0,
|
||||
2: 0,
|
||||
3: 2,
|
||||
4: 1
|
||||
},
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_semantic_seg']
|
||||
|
||||
true_mask = np.zeros_like(gt_array)
|
||||
true_mask[6:8, 2:4] = 2
|
||||
true_mask[6:8, 6:8] = 1
|
||||
|
||||
assert results['seg_fields'] == ['gt_semantic_seg']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test no custom classes
|
||||
results = dict(
|
||||
img_info=dict(filename=img_path),
|
||||
ann_info=dict(seg_map=gt_path),
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
||||
|
||||
load_anns = LoadAnnotations()
|
||||
results = load_anns(copy.deepcopy(results))
|
||||
|
||||
gt_array = results['gt_semantic_seg']
|
||||
|
||||
assert results['seg_fields'] == ['gt_semantic_seg']
|
||||
assert gt_array.shape == (10, 10)
|
||||
assert gt_array.dtype == np.uint8
|
||||
np.testing.assert_array_equal(gt_array, test_gt)
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user