mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add Rgb2Gray transform (#227)
* add transformer Rgb2Gray * restore * fix self.weights * restore * fix code * restore * fix syntax error * restore
This commit is contained in:
parent
500babf958
commit
0588426eaa
@ -548,6 +548,61 @@ class RandomRotate(object):
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class RGB2Gray(object):
    """Turn a multi-channel (e.g. RGB) image into a grayscale image.

    Every pixel is reduced to a single value by multiplying its channels
    with ``weights`` and summing the products. That single gray channel is
    then replicated along the channel axis ``out_channels`` times. If
    ``out_channels`` is None, it is replicated ``len(weights)`` times, which
    equals the number of input channels.

    Args:
        out_channels (int, optional): Number of channels of the output
            image. Must be positive when given. Default: None.
        weights (tuple[float]): Per-channel coefficients used to collapse
            the channels into one. Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        assert out_channels is None or out_channels > 0
        assert isinstance(weights, tuple)
        assert all(isinstance(w, (float, int)) for w in weights)
        self.out_channels = out_channels
        self.weights = weights

    def __call__(self, results):
        """Replace ``results['img']`` with its grayscale version.

        Args:
            results (dict): Result dict from the loading pipeline. It must
                hold a 3-dimensional ``'img'`` array whose last axis has
                one entry per weight.

        Returns:
            dict: The same dict, with ``'img'`` replaced by the grayscale
                image and ``'img_shape'`` set to its shape.
        """
        image = results['img']
        assert image.ndim == 3
        assert image.shape[2] == len(self.weights)

        # Shape (1, 1, C) so the weights broadcast over height and width.
        coeffs = np.array(self.weights).reshape(1, 1, -1)
        gray = (image * coeffs).sum(axis=2, keepdims=True)

        # Without an explicit channel count, fall back to one output
        # channel per weight.
        if self.out_channels is None:
            num_repeats = coeffs.shape[2]
        else:
            num_repeats = self.out_channels
        gray = gray.repeat(num_repeats, axis=2)

        results['img'] = gray
        results['img_shape'] = gray.shape
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(out_channels={self.out_channels}, '
                f'weights={self.weights})')
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class SegRescale(object):
|
||||
"""Rescale semantic segmentation maps.
|
||||
|
@ -263,6 +263,73 @@ def test_normalize():
|
||||
assert np.allclose(results['img'], converted_img)
|
||||
|
||||
|
||||
def test_rgb2gray():
    """Check RGB2Gray argument validation, its repr and its output shapes."""

    def load_sample():
        """Build a fresh results dict from the sample image and seg map.

        Returns the dict together with the shape of the loaded image.
        """
        color = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        seg_map = np.array(
            Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
        sample = dict(
            img=color,
            gt_semantic_seg=seg_map,
            seg_fields=['gt_semantic_seg'],
            img_shape=color.shape,
            ori_shape=color.shape,
            # Set initial values for default meta_keys
            pad_shape=color.shape,
            scale_factor=1.0)
        return sample, color.shape

    # out_channels must be greater than 0
    with pytest.raises(AssertionError):
        build_from_cfg(dict(type='RGB2Gray', out_channels=-1), PIPELINES)
    # weights must be a tuple[float]
    with pytest.raises(AssertionError):
        build_from_cfg(
            dict(type='RGB2Gray', out_channels=1, weights=1.1), PIPELINES)

    # out_channels is None: the channel count of the input is kept
    gray = build_from_cfg(dict(type='RGB2Gray'), PIPELINES)
    assert str(gray) == ('RGB2Gray(out_channels=None, '
                         'weights=(0.299, 0.587, 0.114))')

    sample, (h, w, c) = load_sample()
    sample = gray(sample)
    assert sample['img'].shape == (h, w, c)
    assert sample['img_shape'] == (h, w, c)
    assert sample['ori_shape'] == (h, w, c)

    # out_channels = 2: output has two channels, ori_shape is untouched
    gray = build_from_cfg(dict(type='RGB2Gray', out_channels=2), PIPELINES)
    assert str(gray) == ('RGB2Gray(out_channels=2, '
                         'weights=(0.299, 0.587, 0.114))')

    sample, (h, w, c) = load_sample()
    sample = gray(sample)
    assert sample['img'].shape == (h, w, 2)
    assert sample['img_shape'] == (h, w, 2)
    assert sample['ori_shape'] == (h, w, c)
|
||||
|
||||
|
||||
def test_seg_rescale():
|
||||
results = dict()
|
||||
seg = np.array(
|
||||
|
Loading…
x
Reference in New Issue
Block a user