Add RandomFlip
parent
45812e87bd
commit
6968ad5b3b
|
@ -52,7 +52,7 @@ At the end of the pipeline, we use `Collect` to only retain the necessary items
|
|||
- update: img, img_shape
|
||||
|
||||
`RandomFlip`
|
||||
- add: flip
|
||||
- add: flip, flip_direction
|
||||
- update: img
|
||||
|
||||
|
||||
|
|
|
@ -2,11 +2,12 @@ from .compose import Compose
|
|||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
from .loading import LoadImageFromFile
|
||||
from .transforms import (CenterCrop, RandomCrop, RandomGrayscale,
|
||||
from .transforms import (CenterCrop, RandomCrop, RandomFlip, RandomGrayscale,
|
||||
RandomResizedCrop, Resize)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||
'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale'
|
||||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale'
|
||||
]
|
||||
|
|
|
@ -280,6 +280,48 @@ class RandomGrayscale(object):
|
|||
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomFlip(object):
|
||||
"""Flip the image randomly.
|
||||
|
||||
Flip the image randomly based on flip probaility and flip direction.
|
||||
|
||||
Args:
|
||||
flip_prob (float): probability of the image being flipped. Default: 0.5
|
||||
direction (str, optional): The flipping direction. Options are
|
||||
'horizontal' and 'vertical'. Default: 'horizontal'.
|
||||
"""
|
||||
|
||||
def __init__(self, flip_prob=0.5, direction='horizontal'):
|
||||
assert 0 <= flip_prob <= 1
|
||||
assert direction in ['horizontal', 'vertical']
|
||||
self.flip_prob = flip_prob
|
||||
self.direction = direction
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to flip image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Flipped results, 'flip', 'flip_direction' keys are added into
|
||||
result dict.
|
||||
"""
|
||||
flip = True if np.random.rand() < self.flip_prob else False
|
||||
results['flip'] = flip
|
||||
results['flip_direction'] = self.direction
|
||||
if results['flip']:
|
||||
# flip image
|
||||
for key in results.get('img_fields', ['img']):
|
||||
results[key] = mmcv.imflip(
|
||||
results[key], direction=results['flip_direction'])
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(flip_prob={self.flip_prob})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Resize(object):
|
||||
"""Resize images.
|
||||
|
|
|
@ -668,3 +668,80 @@ def test_randomgrayscale():
|
|||
img_pil = composed_transform(in_img_pil)
|
||||
assert_array_equal(np.array(img_pil), np.array(in_img_pil))
|
||||
assert np.array(img_pil).shape == (10, 10)
|
||||
|
||||
|
||||
def test_randomflip():
|
||||
# test assertion if flip probability is smaller than 0
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomFlip', flip_prob=-1)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if flip probability is larger than 1
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomFlip', flip_prob=2)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if direction is not horizontal and vertical
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomFlip', direction='random')
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion if direction is not lowercase
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomFlip', direction='Horizontal')
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# read test image
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img2'] = copy.deepcopy(img)
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
results['img_fields'] = ['img', 'img2']
|
||||
|
||||
def reset_results(results, original_img):
|
||||
results['img'] = copy.deepcopy(original_img)
|
||||
results['img2'] = copy.deepcopy(original_img)
|
||||
results['img_shape'] = original_img.shape
|
||||
results['ori_shape'] = original_img.shape
|
||||
return results
|
||||
|
||||
# test RandomFlip when flip_prob is 0
|
||||
transform = dict(type='RandomFlip', flip_prob=0)
|
||||
flip_module = build_from_cfg(transform, PIPELINES)
|
||||
results = flip_module(results)
|
||||
assert np.equal(results['img'], original_img).all()
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
|
||||
# test RandomFlip when flip_prob is 1
|
||||
transform = dict(type='RandomFlip', flip_prob=1)
|
||||
flip_module = build_from_cfg(transform, PIPELINES)
|
||||
results = flip_module(results)
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
|
||||
# compare hotizontal flip with torchvision
|
||||
transform = dict(type='RandomFlip', flip_prob=1, direction='horizontal')
|
||||
flip_module = build_from_cfg(transform, PIPELINES)
|
||||
results = reset_results(results, original_img)
|
||||
results = flip_module(results)
|
||||
flip_module = transforms.RandomHorizontalFlip(p=1)
|
||||
pil_img = Image.fromarray(original_img)
|
||||
flipped_img = flip_module(pil_img)
|
||||
flipped_img = np.array(flipped_img)
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
assert np.equal(results['img'], flipped_img).all()
|
||||
|
||||
# compare vertical flip with torchvision
|
||||
transform = dict(type='RandomFlip', flip_prob=1, direction='vertical')
|
||||
flip_module = build_from_cfg(transform, PIPELINES)
|
||||
results = reset_results(results, original_img)
|
||||
results = flip_module(results)
|
||||
flip_module = transforms.RandomVerticalFlip(p=1)
|
||||
pil_img = Image.fromarray(original_img)
|
||||
flipped_img = flip_module(pil_img)
|
||||
flipped_img = np.array(flipped_img)
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
assert np.equal(results['img'], flipped_img).all()
|
||||
|
|
Loading…
Reference in New Issue