diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index bb879a717..eca4155bb 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -122,6 +122,10 @@ class CustomDataset(BaseDataset): # Get label map for custom classes new_classes = self._metainfo.get('classes', None) self.label_map = self.get_label_map(new_classes) + self._metainfo.update( + dict( + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label)) # Update palette based on label map or generate palette # if it is not defined @@ -240,7 +244,8 @@ class CustomDataset(BaseDataset): seg_map = img_name + self.seg_map_suffix data_info['seg_map_path'] = osp.join(ann_dir, seg_map) data_info['label_map'] = self.label_map - data_info['seg_field'] = [] + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] data_list.append(data_info) else: img_dir = self.data_prefix['img_path'] @@ -254,7 +259,8 @@ class CustomDataset(BaseDataset): seg_map = img.replace(self.img_suffix, self.seg_map_suffix) data_info['seg_map_path'] = osp.join(ann_dir, seg_map) data_info['label_map'] = self.label_map - data_info['seg_field'] = [] + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] data_list.append(data_info) data_list = sorted(data_list, key=lambda x: x['img_path']) return data_list diff --git a/mmseg/datasets/pipelines/loading.py b/mmseg/datasets/pipelines/loading.py index f18fd268f..fbdaaca31 100644 --- a/mmseg/datasets/pipelines/loading.py +++ b/mmseg/datasets/pipelines/loading.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings + import mmcv import numpy as np from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations @@ -40,9 +42,9 @@ class LoadAnnotations(MMCV_LoadAnnotations): - gt_seg_map (np.uint8) Args: - reduce_zero_label (bool): Whether reduce all label value by 1. - Usually used for datasets where 0 is background label. - Defaults to False. + reduce_zero_label (bool, optional): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to None. imdecode_backend (str): The image decoding backend type. The backend argument for :func:``mmcv.imfrombytes``. See :fun:``mmcv.imfrombytes`` for details. @@ -54,7 +56,7 @@ class LoadAnnotations(MMCV_LoadAnnotations): def __init__( self, - reduce_zero_label=False, + reduce_zero_label=None, file_client_args=dict(backend='disk'), imdecode_backend='pillow', ) -> None: @@ -66,6 +68,11 @@ class LoadAnnotations(MMCV_LoadAnnotations): imdecode_backend=imdecode_backend, file_client_args=file_client_args) self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') self.file_client_args = file_client_args.copy() self.imdecode_backend = imdecode_backend @@ -93,6 +100,12 @@ class LoadAnnotations(MMCV_LoadAnnotations): for old_id, new_id in results['label_map'].items(): gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' if self.reduce_zero_label: # avoid using underflow conversion gt_semantic_seg[gt_semantic_seg == 0] = 255 diff --git a/tests/test_data/test_formatting.py b/tests/test_datasets/test_formatting.py similarity index 100% rename from tests/test_data/test_formatting.py rename to tests/test_datasets/test_formatting.py diff --git a/tests/test_data/test_loading.py b/tests/test_datasets/test_loading.py similarity index 92% rename from tests/test_data/test_loading.py rename to tests/test_datasets/test_loading.py index 6bf554259..937dcf6df 100644 --- a/tests/test_data/test_loading.py +++ b/tests/test_datasets/test_loading.py @@ -47,13 +47,14 @@ class TestLoading(object): def test_load_seg(self): seg_path = osp.join(self.data_prefix, 'seg.png') - results = dict(seg_map_path=seg_path, seg_fields=[]) + results = dict( + seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[]) transform = LoadAnnotations() results = transform(copy.deepcopy(results)) assert results['gt_seg_map'].shape == (288, 512) assert results['gt_seg_map'].dtype == np.uint8 assert repr(transform) == transform.__class__.__name__ + \ - "(reduce_zero_label=False,imdecode_backend='pillow')" + \ + "(reduce_zero_label=True,imdecode_backend='pillow')" + \ "file_client_args={'backend': 'disk'})" # reduce_zero_label @@ -89,6 +90,7 @@ class TestLoading(object): 3: 1, 4: 0 }, + reduce_zero_label=False, seg_fields=[]) load_imgs = LoadImageFromFile() @@ -118,6 +120,7 @@ class TestLoading(object): 3: 2, 4: 1 }, + reduce_zero_label=False, seg_fields=[]) load_imgs = LoadImageFromFile() @@ -138,7 +141,11 @@ class TestLoading(object): np.testing.assert_array_equal(gt_array, true_mask) # test no custom classes - results = dict(img_path=img_path, seg_map_path=gt_path, seg_fields=[]) + results = dict( + img_path=img_path, + seg_map_path=gt_path, + reduce_zero_label=False, + seg_fields=[]) load_imgs = LoadImageFromFile() results = load_imgs(copy.deepcopy(results)) diff --git a/tests/test_datasets/test_pipelines/test_transforms.py b/tests/test_datasets/test_pipelines/test_transforms.py index 0321b0169..87c053b3f 100644 --- a/tests/test_datasets/test_pipelines/test_transforms.py +++ b/tests/test_datasets/test_pipelines/test_transforms.py @@ -49,7 +49,7 @@ def test_photo_metric_distortion(): results['pad_shape'] = img.shape results['scale_factor'] = 1.0 - pipeline = PhotoMetricDistortion() + pipeline = PhotoMetricDistortion(saturation_range=(1., 1.)) results = pipeline(results) assert not ((results['img'] == img).all()) diff --git a/tests/test_data/test_transform.py b/tests/test_datasets/test_transform.py similarity index 100% rename from tests/test_data/test_transform.py rename to tests/test_datasets/test_transform.py diff --git a/tests/test_data/test_tta.py b/tests/test_datasets/test_tta.py similarity index 100% rename from tests/test_data/test_tta.py rename to tests/test_datasets/test_tta.py