add invert pipeline (#168)

This commit is contained in:
LXXXXR 2021-03-02 16:46:57 +08:00 committed by GitHub
parent b1fa298a66
commit c8033ece8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from .auto_augment import Rotate, Shear, Translate
from .auto_augment import Invert, Rotate, Shear, Translate
from .compose import Compose
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor)
@ -10,5 +10,5 @@ __all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale', 'Shear', 'Translate', 'Rotate'
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert'
]

View File

@ -89,6 +89,7 @@ class Shear(object):
@PIPELINES.register_module()
class Translate(object):
"""Translate images.
Args:
magnitude (int | float): The magnitude used for translate. Note that
the offset is calculated by magnitude * size in the corresponding
@ -258,3 +259,32 @@ class Rotate(object):
repr_str += f'random_negative_prob={self.random_negative_prob}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str
@PIPELINES.register_module()
class Invert(object):
"""Invert images.
Args:
prob (float): The probability for performing invert therefore should
be in range [0, 1]. Defaults to 0.5.
"""
def __init__(self, prob=0.5):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.prob = prob
def __call__(self, results):
if np.random.rand() > self.prob:
return results
for key in results.get('img_fields', ['img']):
img = results[key]
img_inverted = mmcv.iminvert(img)
results[key] = img_inverted.astype(img.dtype)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob})'
return repr_str

View File

@ -308,3 +308,30 @@ def test_rotate():
rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1)
assert (results['img'] == rotated_img).all()
assert (results['img'] == results['img2']).all()
def test_invert():
# test assertion for invalid value of prob
with pytest.raises(AssertionError):
transform = dict(type='Invert', prob=100)
build_from_cfg(transform, PIPELINES)
# test case when prob=0, therefore no invert
results = construct_toy_data()
transform = dict(type='Invert', prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when prob=1
results = construct_toy_data()
transform = dict(type='Invert', prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
inverted_img = np.array(
[[254, 253, 252, 251], [250, 249, 248, 247], [246, 245, 244, 243]],
dtype=np.uint8)
inverted_img = np.stack([inverted_img, inverted_img, inverted_img],
axis=-1)
assert (results['img'] == inverted_img).all()
assert (results['img'] == results['img2']).all()