Add RandomFlip

pull/2/head
yanglei 2020-07-08 23:54:49 +08:00 committed by yl-1993
parent 45812e87bd
commit 6968ad5b3b
4 changed files with 123 additions and 3 deletions

View File

@ -52,7 +52,7 @@ At the end of the pipeline, we use `Collect` to only retain the necessary items
- update: img, img_shape
`RandomFlip`
- add: flip
- add: flip, flip_direction
- update: img

View File

@ -2,11 +2,12 @@ from .compose import Compose
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor)
from .loading import LoadImageFromFile
from .transforms import (CenterCrop, RandomCrop, RandomGrayscale,
from .transforms import (CenterCrop, RandomCrop, RandomFlip, RandomGrayscale,
RandomResizedCrop, Resize)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale'
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale'
]

View File

@ -280,6 +280,48 @@ class RandomGrayscale(object):
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
@PIPELINES.register_module()
class RandomFlip(object):
    """Flip the image randomly.

    Each call draws one random number; when it is below ``flip_prob`` every
    image listed in ``results['img_fields']`` (or just ``results['img']`` if
    that key is absent) is flipped along ``direction`` with ``mmcv.imflip``.

    Args:
        flip_prob (float): Probability of the image being flipped. Must lie
            in ``[0, 1]``. Default: 0.5
        direction (str, optional): The flipping direction. Options are
            'horizontal' and 'vertical' (lowercase only).
            Default: 'horizontal'.

    Raises:
        AssertionError: If ``flip_prob`` is outside ``[0, 1]`` or
            ``direction`` is not one of the supported options.
    """

    def __init__(self, flip_prob=0.5, direction='horizontal'):
        # Kept as assertions (rather than ValueError) because callers and
        # the unit tests rely on AssertionError being raised here.
        assert 0 <= flip_prob <= 1, \
            f'flip_prob must be in [0, 1], but got {flip_prob}'
        assert direction in ['horizontal', 'vertical'], \
            f"direction must be 'horizontal' or 'vertical', " \
            f'but got {direction!r}'
        self.flip_prob = flip_prob
        self.direction = direction

    def __call__(self, results):
        """Call function to flip image.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results, 'flip', 'flip_direction' keys are added
                into result dict. 'flip_direction' is recorded even when
                no flip is performed.
        """
        # bool() keeps 'flip' a plain Python bool instead of numpy.bool_
        flip = bool(np.random.rand() < self.flip_prob)
        results['flip'] = flip
        results['flip_direction'] = self.direction
        if flip:
            # the same decision is applied to every registered image field
            for key in results.get('img_fields', ['img']):
                results[key] = mmcv.imflip(
                    results[key], direction=self.direction)
        return results

    def __repr__(self):
        # include direction so differently configured transforms are
        # distinguishable when a pipeline is printed
        return self.__class__.__name__ + \
            f'(flip_prob={self.flip_prob}, direction={self.direction})'
@PIPELINES.register_module()
class Resize(object):
"""Resize images.

View File

@ -668,3 +668,80 @@ def test_randomgrayscale():
img_pil = composed_transform(in_img_pil)
assert_array_equal(np.array(img_pil), np.array(in_img_pil))
assert np.array(img_pil).shape == (10, 10)
def test_randomflip():
    """Check RandomFlip argument validation and its flipping behaviour."""
    # building must fail for a probability outside [0, 1], for an unknown
    # direction, and for a direction that is not lowercase
    invalid_kwargs = [
        dict(flip_prob=-1),
        dict(flip_prob=2),
        dict(direction='random'),
        dict(direction='Horizontal'),
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            build_from_cfg(dict(type='RandomFlip', **kwargs), PIPELINES)

    # read test image
    original_img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')

    def make_results():
        # every case starts from untouched copies so that one case cannot
        # leak flipped images into the next
        return dict(
            img=copy.deepcopy(original_img),
            img2=copy.deepcopy(original_img),
            img_shape=original_img.shape,
            ori_shape=original_img.shape,
            img_fields=['img', 'img2'])

    # test RandomFlip when flip_prob is 0: nothing may change
    transform = dict(type='RandomFlip', flip_prob=0)
    flip_module = build_from_cfg(transform, PIPELINES)
    results = flip_module(make_results())
    assert not results['flip']
    assert results['flip_direction'] == 'horizontal'
    assert np.equal(results['img'], original_img).all()
    assert np.equal(results['img'], results['img2']).all()

    # test RandomFlip when flip_prob is 1: all image fields must really be
    # flipped (default direction is horizontal, i.e. reversed columns)
    transform = dict(type='RandomFlip', flip_prob=1)
    flip_module = build_from_cfg(transform, PIPELINES)
    results = flip_module(make_results())
    assert results['flip']
    assert results['flip_direction'] == 'horizontal'
    assert np.equal(results['img'], original_img[:, ::-1]).all()
    assert np.equal(results['img'], results['img2']).all()

    # compare horizontal and vertical flip with torchvision
    torchvision_flips = [
        ('horizontal', transforms.RandomHorizontalFlip(p=1)),
        ('vertical', transforms.RandomVerticalFlip(p=1)),
    ]
    for direction, reference_flip in torchvision_flips:
        transform = dict(type='RandomFlip', flip_prob=1, direction=direction)
        flip_module = build_from_cfg(transform, PIPELINES)
        results = flip_module(make_results())

        pil_img = Image.fromarray(original_img)
        flipped_img = np.array(reference_flip(pil_img))

        assert results['flip']
        assert results['flip_direction'] == direction
        assert np.equal(results['img'], results['img2']).all()
        assert np.equal(results['img'], flipped_img).all()