add invert pipeline (#168)

This commit is contained in:
LXXXXR 2021-03-02 16:46:57 +08:00 committed by GitHub
parent b1fa298a66
commit c8033ece8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from .auto_augment import Rotate, Shear, Translate from .auto_augment import Invert, Rotate, Shear, Translate
from .compose import Compose from .compose import Compose
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor) Transpose, to_tensor)
@ -10,5 +10,5 @@ __all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale', 'Shear', 'Translate', 'Rotate' 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert'
] ]

View File

@ -89,6 +89,7 @@ class Shear(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class Translate(object): class Translate(object):
"""Translate images. """Translate images.
Args: Args:
magnitude (int | float): The magnitude used for translate. Note that magnitude (int | float): The magnitude used for translate. Note that
the offset is calculated by magnitude * size in the corresponding the offset is calculated by magnitude * size in the corresponding
@ -258,3 +259,32 @@ class Rotate(object):
repr_str += f'random_negative_prob={self.random_negative_prob}, ' repr_str += f'random_negative_prob={self.random_negative_prob}, '
repr_str += f'interpolation={self.interpolation})' repr_str += f'interpolation={self.interpolation})'
return repr_str return repr_str
@PIPELINES.register_module()
class Invert(object):
"""Invert images.
Args:
prob (float): The probability for performing invert therefore should
be in range [0, 1]. Defaults to 0.5.
"""
def __init__(self, prob=0.5):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.prob = prob
def __call__(self, results):
if np.random.rand() > self.prob:
return results
for key in results.get('img_fields', ['img']):
img = results[key]
img_inverted = mmcv.iminvert(img)
results[key] = img_inverted.astype(img.dtype)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob})'
return repr_str

View File

@ -308,3 +308,30 @@ def test_rotate():
rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1) rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1)
assert (results['img'] == rotated_img).all() assert (results['img'] == rotated_img).all()
assert (results['img'] == results['img2']).all() assert (results['img'] == results['img2']).all()
def test_invert():
# test assertion for invalid value of prob
with pytest.raises(AssertionError):
transform = dict(type='Invert', prob=100)
build_from_cfg(transform, PIPELINES)
# test case when prob=0, therefore no invert
results = construct_toy_data()
transform = dict(type='Invert', prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when prob=1
results = construct_toy_data()
transform = dict(type='Invert', prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
inverted_img = np.array(
[[254, 253, 252, 251], [250, 249, 248, 247], [246, 245, 244, 243]],
dtype=np.uint8)
inverted_img = np.stack([inverted_img, inverted_img, inverted_img],
axis=-1)
assert (results['img'] == inverted_img).all()
assert (results['img'] == results['img2']).all()