mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
Add RandomFlip
This commit is contained in:
parent
45812e87bd
commit
6968ad5b3b
@ -52,7 +52,7 @@ At the end of the pipeline, we use `Collect` to only retain the necessary items
|
|||||||
- update: img, img_shape
|
- update: img, img_shape
|
||||||
|
|
||||||
`RandomFlip`
|
`RandomFlip`
|
||||||
- add: flip
|
- add: flip, flip_direction
|
||||||
- update: img
|
- update: img
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,11 +2,12 @@ from .compose import Compose
|
|||||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||||
Transpose, to_tensor)
|
Transpose, to_tensor)
|
||||||
from .loading import LoadImageFromFile
|
from .loading import LoadImageFromFile
|
||||||
from .transforms import (CenterCrop, RandomCrop, RandomGrayscale,
|
from .transforms import (CenterCrop, RandomCrop, RandomFlip, RandomGrayscale,
|
||||||
RandomResizedCrop, Resize)
|
RandomResizedCrop, Resize)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||||
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
|
||||||
'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale'
|
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||||
|
'RandomGrayscale'
|
||||||
]
|
]
|
||||||
|
@ -280,6 +280,48 @@ class RandomGrayscale(object):
|
|||||||
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class RandomFlip(object):
|
||||||
|
"""Flip the image randomly.
|
||||||
|
|
||||||
|
Flip the image randomly based on flip probaility and flip direction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flip_prob (float): probability of the image being flipped. Default: 0.5
|
||||||
|
direction (str, optional): The flipping direction. Options are
|
||||||
|
'horizontal' and 'vertical'. Default: 'horizontal'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, flip_prob=0.5, direction='horizontal'):
|
||||||
|
assert 0 <= flip_prob <= 1
|
||||||
|
assert direction in ['horizontal', 'vertical']
|
||||||
|
self.flip_prob = flip_prob
|
||||||
|
self.direction = direction
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
"""Call function to flip image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict from loading pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Flipped results, 'flip', 'flip_direction' keys are added into
|
||||||
|
result dict.
|
||||||
|
"""
|
||||||
|
flip = True if np.random.rand() < self.flip_prob else False
|
||||||
|
results['flip'] = flip
|
||||||
|
results['flip_direction'] = self.direction
|
||||||
|
if results['flip']:
|
||||||
|
# flip image
|
||||||
|
for key in results.get('img_fields', ['img']):
|
||||||
|
results[key] = mmcv.imflip(
|
||||||
|
results[key], direction=results['flip_direction'])
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + f'(flip_prob={self.flip_prob})'
|
||||||
|
|
||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class Resize(object):
|
class Resize(object):
|
||||||
"""Resize images.
|
"""Resize images.
|
||||||
|
@ -668,3 +668,80 @@ def test_randomgrayscale():
|
|||||||
img_pil = composed_transform(in_img_pil)
|
img_pil = composed_transform(in_img_pil)
|
||||||
assert_array_equal(np.array(img_pil), np.array(in_img_pil))
|
assert_array_equal(np.array(img_pil), np.array(in_img_pil))
|
||||||
assert np.array(img_pil).shape == (10, 10)
|
assert np.array(img_pil).shape == (10, 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_randomflip():
|
||||||
|
# test assertion if flip probability is smaller than 0
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=-1)
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test assertion if flip probability is larger than 1
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=2)
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test assertion if direction is not horizontal and vertical
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomFlip', direction='random')
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# test assertion if direction is not lowercase
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='RandomFlip', direction='Horizontal')
|
||||||
|
build_from_cfg(transform, PIPELINES)
|
||||||
|
|
||||||
|
# read test image
|
||||||
|
results = dict()
|
||||||
|
img = mmcv.imread(
|
||||||
|
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||||
|
original_img = copy.deepcopy(img)
|
||||||
|
results['img'] = img
|
||||||
|
results['img2'] = copy.deepcopy(img)
|
||||||
|
results['img_shape'] = img.shape
|
||||||
|
results['ori_shape'] = img.shape
|
||||||
|
results['img_fields'] = ['img', 'img2']
|
||||||
|
|
||||||
|
def reset_results(results, original_img):
|
||||||
|
results['img'] = copy.deepcopy(original_img)
|
||||||
|
results['img2'] = copy.deepcopy(original_img)
|
||||||
|
results['img_shape'] = original_img.shape
|
||||||
|
results['ori_shape'] = original_img.shape
|
||||||
|
return results
|
||||||
|
|
||||||
|
# test RandomFlip when flip_prob is 0
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=0)
|
||||||
|
flip_module = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = flip_module(results)
|
||||||
|
assert np.equal(results['img'], original_img).all()
|
||||||
|
assert np.equal(results['img'], results['img2']).all()
|
||||||
|
|
||||||
|
# test RandomFlip when flip_prob is 1
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=1)
|
||||||
|
flip_module = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = flip_module(results)
|
||||||
|
assert np.equal(results['img'], results['img2']).all()
|
||||||
|
|
||||||
|
# compare hotizontal flip with torchvision
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=1, direction='horizontal')
|
||||||
|
flip_module = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = reset_results(results, original_img)
|
||||||
|
results = flip_module(results)
|
||||||
|
flip_module = transforms.RandomHorizontalFlip(p=1)
|
||||||
|
pil_img = Image.fromarray(original_img)
|
||||||
|
flipped_img = flip_module(pil_img)
|
||||||
|
flipped_img = np.array(flipped_img)
|
||||||
|
assert np.equal(results['img'], results['img2']).all()
|
||||||
|
assert np.equal(results['img'], flipped_img).all()
|
||||||
|
|
||||||
|
# compare vertical flip with torchvision
|
||||||
|
transform = dict(type='RandomFlip', flip_prob=1, direction='vertical')
|
||||||
|
flip_module = build_from_cfg(transform, PIPELINES)
|
||||||
|
results = reset_results(results, original_img)
|
||||||
|
results = flip_module(results)
|
||||||
|
flip_module = transforms.RandomVerticalFlip(p=1)
|
||||||
|
pil_img = Image.fromarray(original_img)
|
||||||
|
flipped_img = flip_module(pil_img)
|
||||||
|
flipped_img = np.array(flipped_img)
|
||||||
|
assert np.equal(results['img'], results['img2']).all()
|
||||||
|
assert np.equal(results['img'], flipped_img).all()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user