From 6968ad5b3bda9d0f302842b9781cead035eaa19e Mon Sep 17 00:00:00 2001 From: yanglei Date: Wed, 8 Jul 2020 23:54:49 +0800 Subject: [PATCH] Add RandomFlip --- docs/tutorials/data_pipeline.md | 2 +- mmcls/datasets/pipelines/__init__.py | 5 +- mmcls/datasets/pipelines/transforms.py | 42 ++++++++++++++ tests/test_pipelines/test_transform.py | 77 ++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/data_pipeline.md b/docs/tutorials/data_pipeline.md index 1e50a538..e274cd63 100644 --- a/docs/tutorials/data_pipeline.md +++ b/docs/tutorials/data_pipeline.md @@ -52,7 +52,7 @@ At the end of the pipeline, we use `Collect` to only retain the necessary items - update: img, img_shape `RandomFlip` -- add: flip +- add: flip, flip_direction - update: img diff --git a/mmcls/datasets/pipelines/__init__.py b/mmcls/datasets/pipelines/__init__.py index 8ee80335..78df8a3a 100644 --- a/mmcls/datasets/pipelines/__init__.py +++ b/mmcls/datasets/pipelines/__init__.py @@ -2,11 +2,12 @@ from .compose import Compose from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, Transpose, to_tensor) from .loading import LoadImageFromFile -from .transforms import (CenterCrop, RandomCrop, RandomGrayscale, +from .transforms import (CenterCrop, RandomCrop, RandomFlip, RandomGrayscale, RandomResizedCrop, Resize) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', - 'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale' + 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', + 'RandomGrayscale' ] diff --git a/mmcls/datasets/pipelines/transforms.py b/mmcls/datasets/pipelines/transforms.py index 5bbc8365..08a9e410 100644 --- a/mmcls/datasets/pipelines/transforms.py +++ b/mmcls/datasets/pipelines/transforms.py @@ -280,6 +280,48 @@ class RandomGrayscale(object): return self.__class__.__name__ + 
f'(gray_prob={self.gray_prob})' +@PIPELINES.register_module() +class RandomFlip(object): + """Flip the image randomly. + + Flip the image randomly based on flip probability and flip direction. + + Args: + flip_prob (float): probability of the image being flipped. Default: 0.5 + direction (str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. + """ + + def __init__(self, flip_prob=0.5, direction='horizontal'): + assert 0 <= flip_prob <= 1 + assert direction in ['horizontal', 'vertical'] + self.flip_prob = flip_prob + self.direction = direction + + def __call__(self, results): + """Call function to flip image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added into + result dict. + """ + flip = True if np.random.rand() < self.flip_prob else False + results['flip'] = flip + results['flip_direction'] = self.direction + if results['flip']: + # flip image + for key in results.get('img_fields', ['img']): + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(flip_prob={self.flip_prob})' + + @PIPELINES.register_module() class Resize(object): """Resize images. 
diff --git a/tests/test_pipelines/test_transform.py b/tests/test_pipelines/test_transform.py index 04f27523..ca752599 100644 --- a/tests/test_pipelines/test_transform.py +++ b/tests/test_pipelines/test_transform.py @@ -668,3 +668,80 @@ def test_randomgrayscale(): img_pil = composed_transform(in_img_pil) assert_array_equal(np.array(img_pil), np.array(in_img_pil)) assert np.array(img_pil).shape == (10, 10) + + +def test_randomflip(): + # test assertion if flip probability is smaller than 0 + with pytest.raises(AssertionError): + transform = dict(type='RandomFlip', flip_prob=-1) + build_from_cfg(transform, PIPELINES) + + # test assertion if flip probability is larger than 1 + with pytest.raises(AssertionError): + transform = dict(type='RandomFlip', flip_prob=2) + build_from_cfg(transform, PIPELINES) + + # test assertion if direction is neither horizontal nor vertical + with pytest.raises(AssertionError): + transform = dict(type='RandomFlip', direction='random') + build_from_cfg(transform, PIPELINES) + + # test assertion if direction is not lowercase + with pytest.raises(AssertionError): + transform = dict(type='RandomFlip', direction='Horizontal') + build_from_cfg(transform, PIPELINES) + + # read test image + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + original_img = copy.deepcopy(img) + results['img'] = img + results['img2'] = copy.deepcopy(img) + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + results['img_fields'] = ['img', 'img2'] + + def reset_results(results, original_img): + results['img'] = copy.deepcopy(original_img) + results['img2'] = copy.deepcopy(original_img) + results['img_shape'] = original_img.shape + results['ori_shape'] = original_img.shape + return results + + # test RandomFlip when flip_prob is 0 + transform = dict(type='RandomFlip', flip_prob=0) + flip_module = build_from_cfg(transform, PIPELINES) + results = flip_module(results) + assert np.equal(results['img'], 
original_img).all() + assert np.equal(results['img'], results['img2']).all() + + # test RandomFlip when flip_prob is 1 + transform = dict(type='RandomFlip', flip_prob=1) + flip_module = build_from_cfg(transform, PIPELINES) + results = flip_module(results) + assert np.equal(results['img'], results['img2']).all() + + # compare horizontal flip with torchvision + transform = dict(type='RandomFlip', flip_prob=1, direction='horizontal') + flip_module = build_from_cfg(transform, PIPELINES) + results = reset_results(results, original_img) + results = flip_module(results) + flip_module = transforms.RandomHorizontalFlip(p=1) + pil_img = Image.fromarray(original_img) + flipped_img = flip_module(pil_img) + flipped_img = np.array(flipped_img) + assert np.equal(results['img'], results['img2']).all() + assert np.equal(results['img'], flipped_img).all() + + # compare vertical flip with torchvision + transform = dict(type='RandomFlip', flip_prob=1, direction='vertical') + flip_module = build_from_cfg(transform, PIPELINES) + results = reset_results(results, original_img) + results = flip_module(results) + flip_module = transforms.RandomVerticalFlip(p=1) + pil_img = Image.fromarray(original_img) + flipped_img = flip_module(pil_img) + flipped_img = np.array(flipped_img) + assert np.equal(results['img'], results['img2']).all() + assert np.equal(results['img'], flipped_img).all()