diff --git a/mmcls/datasets/pipelines/__init__.py b/mmcls/datasets/pipelines/__init__.py index f84cddb0c..424b3f507 100644 --- a/mmcls/datasets/pipelines/__init__.py +++ b/mmcls/datasets/pipelines/__init__.py @@ -1,4 +1,4 @@ -from .auto_augment import Rotate, Shear, Translate +from .auto_augment import Invert, Rotate, Shear, Translate from .compose import Compose from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, Transpose, to_tensor) @@ -10,5 +10,5 @@ __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', - 'RandomGrayscale', 'Shear', 'Translate', 'Rotate' + 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert' ] diff --git a/mmcls/datasets/pipelines/auto_augment.py b/mmcls/datasets/pipelines/auto_augment.py index 9e357a2d2..8ee7388f2 100644 --- a/mmcls/datasets/pipelines/auto_augment.py +++ b/mmcls/datasets/pipelines/auto_augment.py @@ -89,6 +89,7 @@ class Shear(object): @PIPELINES.register_module() class Translate(object): """Translate images. + Args: magnitude (int | float): The magnitude used for translate. Note that the offset is calculated by magnitude * size in the corresponding @@ -258,3 +259,32 @@ class Rotate(object): repr_str += f'random_negative_prob={self.random_negative_prob}, ' repr_str += f'interpolation={self.interpolation})' return repr_str + + +@PIPELINES.register_module() +class Invert(object): + """Invert images. + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. + """ + + def __init__(self, prob=0.5): + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + + self.prob = prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + for key in results.get('img_fields', ['img']): + img = results[key] + img_inverted = mmcv.iminvert(img) + results[key] = img_inverted.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str diff --git a/tests/test_pipelines/test_auto_augment.py b/tests/test_pipelines/test_auto_augment.py index 23cf765b5..c84bf5a8c 100644 --- a/tests/test_pipelines/test_auto_augment.py +++ b/tests/test_pipelines/test_auto_augment.py @@ -308,3 +308,30 @@ def test_rotate(): rotated_img = np.stack([rotated_img, rotated_img, rotated_img], axis=-1) assert (results['img'] == rotated_img).all() assert (results['img'] == results['img2']).all() + + +def test_invert(): + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='Invert', prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when prob=0, therefore no invert + results = construct_toy_data() + transform = dict(type='Invert', prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=1 + results = construct_toy_data() + transform = dict(type='Invert', prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + inverted_img = np.array( + [[254, 253, 252, 251], [250, 249, 248, 247], [246, 245, 244, 243]], + dtype=np.uint8) + inverted_img = np.stack([inverted_img, inverted_img, inverted_img], + axis=-1) + assert (results['img'] == inverted_img).all() + assert (results['img'] == results['img2']).all()