mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix]Add label_map and reduce_zero_label in metainfo of dataset and deprecate reduce_zero_label in load annotation
This commit is contained in:
parent
5e7d7626a8
commit
f59ef99b00
@ -122,6 +122,10 @@ class CustomDataset(BaseDataset):
|
|||||||
# Get label map for custom classes
|
# Get label map for custom classes
|
||||||
new_classes = self._metainfo.get('classes', None)
|
new_classes = self._metainfo.get('classes', None)
|
||||||
self.label_map = self.get_label_map(new_classes)
|
self.label_map = self.get_label_map(new_classes)
|
||||||
|
self._metainfo.update(
|
||||||
|
dict(
|
||||||
|
label_map=self.label_map,
|
||||||
|
reduce_zero_label=self.reduce_zero_label))
|
||||||
|
|
||||||
# Update palette based on label map or generate palette
|
# Update palette based on label map or generate palette
|
||||||
# if it is not defined
|
# if it is not defined
|
||||||
@ -240,7 +244,8 @@ class CustomDataset(BaseDataset):
|
|||||||
seg_map = img_name + self.seg_map_suffix
|
seg_map = img_name + self.seg_map_suffix
|
||||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||||
data_info['label_map'] = self.label_map
|
data_info['label_map'] = self.label_map
|
||||||
data_info['seg_field'] = []
|
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||||
|
data_info['seg_fields'] = []
|
||||||
data_list.append(data_info)
|
data_list.append(data_info)
|
||||||
else:
|
else:
|
||||||
img_dir = self.data_prefix['img_path']
|
img_dir = self.data_prefix['img_path']
|
||||||
@ -254,7 +259,8 @@ class CustomDataset(BaseDataset):
|
|||||||
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
|
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
|
||||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||||
data_info['label_map'] = self.label_map
|
data_info['label_map'] = self.label_map
|
||||||
data_info['seg_field'] = []
|
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||||
|
data_info['seg_fields'] = []
|
||||||
data_list.append(data_info)
|
data_list.append(data_info)
|
||||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||||
return data_list
|
return data_list
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||||
@ -40,9 +42,9 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
|||||||
- gt_seg_map (np.uint8)
|
- gt_seg_map (np.uint8)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reduce_zero_label (bool): Whether reduce all label value by 1.
|
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||||
Usually used for datasets where 0 is background label.
|
by 1. Usually used for datasets where 0 is background label.
|
||||||
Defaults to False.
|
Defaults to None.
|
||||||
imdecode_backend (str): The image decoding backend type. The backend
|
imdecode_backend (str): The image decoding backend type. The backend
|
||||||
argument for :func:``mmcv.imfrombytes``.
|
argument for :func:``mmcv.imfrombytes``.
|
||||||
See :fun:``mmcv.imfrombytes`` for details.
|
See :fun:``mmcv.imfrombytes`` for details.
|
||||||
@ -54,7 +56,7 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
reduce_zero_label=False,
|
reduce_zero_label=None,
|
||||||
file_client_args=dict(backend='disk'),
|
file_client_args=dict(backend='disk'),
|
||||||
imdecode_backend='pillow',
|
imdecode_backend='pillow',
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -66,6 +68,11 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
|||||||
imdecode_backend=imdecode_backend,
|
imdecode_backend=imdecode_backend,
|
||||||
file_client_args=file_client_args)
|
file_client_args=file_client_args)
|
||||||
self.reduce_zero_label = reduce_zero_label
|
self.reduce_zero_label = reduce_zero_label
|
||||||
|
if self.reduce_zero_label is not None:
|
||||||
|
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||||
|
'if you would like to ignore the zero label, please '
|
||||||
|
'set `reduce_zero_label=True` when dataset '
|
||||||
|
'initialized')
|
||||||
self.file_client_args = file_client_args.copy()
|
self.file_client_args = file_client_args.copy()
|
||||||
self.imdecode_backend = imdecode_backend
|
self.imdecode_backend = imdecode_backend
|
||||||
|
|
||||||
@ -93,6 +100,12 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
|||||||
for old_id, new_id in results['label_map'].items():
|
for old_id, new_id in results['label_map'].items():
|
||||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||||
# reduce zero_label
|
# reduce zero_label
|
||||||
|
if self.reduce_zero_label is None:
|
||||||
|
self.reduce_zero_label = results['reduce_zero_label']
|
||||||
|
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||||
|
'Initialize dataset with `reduce_zero_label` as ' \
|
||||||
|
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||||
|
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||||
if self.reduce_zero_label:
|
if self.reduce_zero_label:
|
||||||
# avoid using underflow conversion
|
# avoid using underflow conversion
|
||||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||||
|
@ -47,13 +47,14 @@ class TestLoading(object):
|
|||||||
|
|
||||||
def test_load_seg(self):
|
def test_load_seg(self):
|
||||||
seg_path = osp.join(self.data_prefix, 'seg.png')
|
seg_path = osp.join(self.data_prefix, 'seg.png')
|
||||||
results = dict(seg_map_path=seg_path, seg_fields=[])
|
results = dict(
|
||||||
|
seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[])
|
||||||
transform = LoadAnnotations()
|
transform = LoadAnnotations()
|
||||||
results = transform(copy.deepcopy(results))
|
results = transform(copy.deepcopy(results))
|
||||||
assert results['gt_seg_map'].shape == (288, 512)
|
assert results['gt_seg_map'].shape == (288, 512)
|
||||||
assert results['gt_seg_map'].dtype == np.uint8
|
assert results['gt_seg_map'].dtype == np.uint8
|
||||||
assert repr(transform) == transform.__class__.__name__ + \
|
assert repr(transform) == transform.__class__.__name__ + \
|
||||||
"(reduce_zero_label=False,imdecode_backend='pillow')" + \
|
"(reduce_zero_label=True,imdecode_backend='pillow')" + \
|
||||||
"file_client_args={'backend': 'disk'})"
|
"file_client_args={'backend': 'disk'})"
|
||||||
|
|
||||||
# reduce_zero_label
|
# reduce_zero_label
|
||||||
@ -89,6 +90,7 @@ class TestLoading(object):
|
|||||||
3: 1,
|
3: 1,
|
||||||
4: 0
|
4: 0
|
||||||
},
|
},
|
||||||
|
reduce_zero_label=False,
|
||||||
seg_fields=[])
|
seg_fields=[])
|
||||||
|
|
||||||
load_imgs = LoadImageFromFile()
|
load_imgs = LoadImageFromFile()
|
||||||
@ -118,6 +120,7 @@ class TestLoading(object):
|
|||||||
3: 2,
|
3: 2,
|
||||||
4: 1
|
4: 1
|
||||||
},
|
},
|
||||||
|
reduce_zero_label=False,
|
||||||
seg_fields=[])
|
seg_fields=[])
|
||||||
|
|
||||||
load_imgs = LoadImageFromFile()
|
load_imgs = LoadImageFromFile()
|
||||||
@ -138,7 +141,11 @@ class TestLoading(object):
|
|||||||
np.testing.assert_array_equal(gt_array, true_mask)
|
np.testing.assert_array_equal(gt_array, true_mask)
|
||||||
|
|
||||||
# test no custom classes
|
# test no custom classes
|
||||||
results = dict(img_path=img_path, seg_map_path=gt_path, seg_fields=[])
|
results = dict(
|
||||||
|
img_path=img_path,
|
||||||
|
seg_map_path=gt_path,
|
||||||
|
reduce_zero_label=False,
|
||||||
|
seg_fields=[])
|
||||||
|
|
||||||
load_imgs = LoadImageFromFile()
|
load_imgs = LoadImageFromFile()
|
||||||
results = load_imgs(copy.deepcopy(results))
|
results = load_imgs(copy.deepcopy(results))
|
@ -49,7 +49,7 @@ def test_photo_metric_distortion():
|
|||||||
results['pad_shape'] = img.shape
|
results['pad_shape'] = img.shape
|
||||||
results['scale_factor'] = 1.0
|
results['scale_factor'] = 1.0
|
||||||
|
|
||||||
pipeline = PhotoMetricDistortion()
|
pipeline = PhotoMetricDistortion(saturation_range=(1., 1.))
|
||||||
results = pipeline(results)
|
results = pipeline(results)
|
||||||
|
|
||||||
assert not ((results['img'] == img).all())
|
assert not ((results['img'] == img).all())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user