[Fix]Add label_map and reduce_zero_label in metainfo of dataset and deprecate reduce_zero_label in load annotation

This commit is contained in:
zhengmiao 2022-06-09 12:23:36 +00:00
parent 5e7d7626a8
commit f59ef99b00
7 changed files with 36 additions and 10 deletions

View File

@ -122,6 +122,10 @@ class CustomDataset(BaseDataset):
# Get label map for custom classes
new_classes = self._metainfo.get('classes', None)
self.label_map = self.get_label_map(new_classes)
self._metainfo.update(
dict(
label_map=self.label_map,
reduce_zero_label=self.reduce_zero_label))
# Update palette based on label map or generate palette
# if it is not defined
@ -240,7 +244,8 @@ class CustomDataset(BaseDataset):
seg_map = img_name + self.seg_map_suffix
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_info['seg_field'] = []
data_info['reduce_zero_label'] = self.reduce_zero_label
data_info['seg_fields'] = []
data_list.append(data_info)
else:
img_dir = self.data_prefix['img_path']
@ -254,7 +259,8 @@ class CustomDataset(BaseDataset):
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_info['seg_field'] = []
data_info['reduce_zero_label'] = self.reduce_zero_label
data_info['seg_fields'] = []
data_list.append(data_info)
data_list = sorted(data_list, key=lambda x: x['img_path'])
return data_list

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
import numpy as np
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
@ -40,9 +42,9 @@ class LoadAnnotations(MMCV_LoadAnnotations):
- gt_seg_map (np.uint8)
Args:
reduce_zero_label (bool): Whether reduce all label value by 1.
Usually used for datasets where 0 is background label.
Defaults to False.
reduce_zero_label (bool, optional): Whether reduce all label value
by 1. Usually used for datasets where 0 is background label.
Defaults to None.
imdecode_backend (str): The image decoding backend type. The backend
argument for :func:``mmcv.imfrombytes``.
See :fun:``mmcv.imfrombytes`` for details.
@ -54,7 +56,7 @@ class LoadAnnotations(MMCV_LoadAnnotations):
def __init__(
self,
reduce_zero_label=False,
reduce_zero_label=None,
file_client_args=dict(backend='disk'),
imdecode_backend='pillow',
) -> None:
@ -66,6 +68,11 @@ class LoadAnnotations(MMCV_LoadAnnotations):
imdecode_backend=imdecode_backend,
file_client_args=file_client_args)
self.reduce_zero_label = reduce_zero_label
if self.reduce_zero_label is not None:
warnings.warn('`reduce_zero_label` will be deprecated, '
'if you would like to ignore the zero label, please '
'set `reduce_zero_label=True` when dataset '
'initialized')
self.file_client_args = file_client_args.copy()
self.imdecode_backend = imdecode_backend
@ -93,6 +100,12 @@ class LoadAnnotations(MMCV_LoadAnnotations):
for old_id, new_id in results['label_map'].items():
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
# reduce zero_label
if self.reduce_zero_label is None:
self.reduce_zero_label = results['reduce_zero_label']
assert self.reduce_zero_label == results['reduce_zero_label'], \
'Initialize dataset with `reduce_zero_label` as ' \
f'{results["reduce_zero_label"]} but when load annotation ' \
f'the `reduce_zero_label` is {self.reduce_zero_label}'
if self.reduce_zero_label:
# avoid using underflow conversion
gt_semantic_seg[gt_semantic_seg == 0] = 255

View File

@ -47,13 +47,14 @@ class TestLoading(object):
def test_load_seg(self):
seg_path = osp.join(self.data_prefix, 'seg.png')
results = dict(seg_map_path=seg_path, seg_fields=[])
results = dict(
seg_map_path=seg_path, reduce_zero_label=True, seg_fields=[])
transform = LoadAnnotations()
results = transform(copy.deepcopy(results))
assert results['gt_seg_map'].shape == (288, 512)
assert results['gt_seg_map'].dtype == np.uint8
assert repr(transform) == transform.__class__.__name__ + \
"(reduce_zero_label=False,imdecode_backend='pillow')" + \
"(reduce_zero_label=True,imdecode_backend='pillow')" + \
"file_client_args={'backend': 'disk'})"
# reduce_zero_label
@ -89,6 +90,7 @@ class TestLoading(object):
3: 1,
4: 0
},
reduce_zero_label=False,
seg_fields=[])
load_imgs = LoadImageFromFile()
@ -118,6 +120,7 @@ class TestLoading(object):
3: 2,
4: 1
},
reduce_zero_label=False,
seg_fields=[])
load_imgs = LoadImageFromFile()
@ -138,7 +141,11 @@ class TestLoading(object):
np.testing.assert_array_equal(gt_array, true_mask)
# test no custom classes
results = dict(img_path=img_path, seg_map_path=gt_path, seg_fields=[])
results = dict(
img_path=img_path,
seg_map_path=gt_path,
reduce_zero_label=False,
seg_fields=[])
load_imgs = LoadImageFromFile()
results = load_imgs(copy.deepcopy(results))

View File

@ -49,7 +49,7 @@ def test_photo_metric_distortion():
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
pipeline = PhotoMetricDistortion()
pipeline = PhotoMetricDistortion(saturation_range=(1., 1.))
results = pipeline(results)
assert not ((results['img'] == img).all())