mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix]Add label_map and reduce_zero_label in metainfo of dataset and deprecate reduce_zero_label in load annotation
This commit is contained in:
parent
5e7d7626a8
commit
f59ef99b00
@ -122,6 +122,10 @@ class CustomDataset(BaseDataset):
|
||||
# Get label map for custom classes
|
||||
new_classes = self._metainfo.get('classes', None)
|
||||
self.label_map = self.get_label_map(new_classes)
|
||||
self._metainfo.update(
|
||||
dict(
|
||||
label_map=self.label_map,
|
||||
reduce_zero_label=self.reduce_zero_label))
|
||||
|
||||
# Update palette based on label map or generate palette
|
||||
# if it is not defined
|
||||
@ -240,7 +244,8 @@ class CustomDataset(BaseDataset):
|
||||
seg_map = img_name + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['seg_field'] = []
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
else:
|
||||
img_dir = self.data_prefix['img_path']
|
||||
@ -254,7 +259,8 @@ class CustomDataset(BaseDataset):
|
||||
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['seg_field'] = []
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
@ -40,9 +42,9 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool): Whether reduce all label value by 1.
|
||||
Usually used for datasets where 0 is background label.
|
||||
Defaults to False.
|
||||
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||
by 1. Usually used for datasets where 0 is background label.
|
||||
Defaults to None.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
@ -54,7 +56,7 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=False,
|
||||
reduce_zero_label=None,
|
||||
file_client_args=dict(backend='disk'),
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
@ -66,6 +68,11 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
imdecode_backend=imdecode_backend,
|
||||
file_client_args=file_client_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
if self.reduce_zero_label is not None:
|
||||
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||
'if you would like to ignore the zero label, please '
|
||||
'set `reduce_zero_label=True` when dataset '
|
||||
'initialized')
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.imdecode_backend = imdecode_backend
|
||||
|
||||
@ -93,6 +100,12 @@ class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||
'Initialize dataset with `reduce_zero_label` as ' \
|
||||
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
|
@ -47,13 +47,14 @@ class TestLoading(object):
|
||||
|
||||
def test_load_seg(self):
|
||||
seg_path = osp.join(self.data_prefix, 'seg.png')
|
||||
results = dict(seg_map_path=seg_path, seg_fields=[])
|
||||
results = dict(
|
||||
seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[])
|
||||
transform = LoadAnnotations()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['gt_seg_map'].shape == (288, 512)
|
||||
assert results['gt_seg_map'].dtype == np.uint8
|
||||
assert repr(transform) == transform.__class__.__name__ + \
|
||||
"(reduce_zero_label=False,imdecode_backend='pillow')" + \
|
||||
"(reduce_zero_label=True,imdecode_backend='pillow')" + \
|
||||
"file_client_args={'backend': 'disk'})"
|
||||
|
||||
# reduce_zero_label
|
||||
@ -89,6 +90,7 @@ class TestLoading(object):
|
||||
3: 1,
|
||||
4: 0
|
||||
},
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
@ -118,6 +120,7 @@ class TestLoading(object):
|
||||
3: 2,
|
||||
4: 1
|
||||
},
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
@ -138,7 +141,11 @@ class TestLoading(object):
|
||||
np.testing.assert_array_equal(gt_array, true_mask)
|
||||
|
||||
# test no custom classes
|
||||
results = dict(img_path=img_path, seg_map_path=gt_path, seg_fields=[])
|
||||
results = dict(
|
||||
img_path=img_path,
|
||||
seg_map_path=gt_path,
|
||||
reduce_zero_label=False,
|
||||
seg_fields=[])
|
||||
|
||||
load_imgs = LoadImageFromFile()
|
||||
results = load_imgs(copy.deepcopy(results))
|
@ -49,7 +49,7 @@ def test_photo_metric_distortion():
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
|
||||
pipeline = PhotoMetricDistortion()
|
||||
pipeline = PhotoMetricDistortion(saturation_range=(1., 1.))
|
||||
results = pipeline(results)
|
||||
|
||||
assert not ((results['img'] == img).all())
|
||||
|
Loading…
x
Reference in New Issue
Block a user