[Feature] Add RandomRotate transform (#215)
* add RandomRotate for transforms * change rotation function to mmcv.imrotate * refactor * add unittest * fixed test * fixed docstring * fixed test * add more test * fixed repr * rename to prob * fixed unittest Co-authored-by: hkzhang95 <GodBlessZhk@outlook.com> pull/1801/head
parent
be94bbf1cc
commit
500babf958
|
@ -9,7 +9,7 @@ train_pipeline = [
|
|||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -9,7 +9,7 @@ train_pipeline = [
|
|||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -7,7 +7,7 @@ train_pipeline = [
|
|||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -12,7 +12,7 @@ train_pipeline = [
|
|||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -9,7 +9,7 @@ train_pipeline = [
|
|||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import deprecated_api_warning
|
||||
from numpy import random
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
@ -232,16 +233,17 @@ class RandomFlip(object):
|
|||
method.
|
||||
|
||||
Args:
|
||||
flip_ratio (float, optional): The flipping probability. Default: None.
|
||||
prob (float, optional): The flipping probability. Default: None.
|
||||
direction(str, optional): The flipping direction. Options are
|
||||
'horizontal' and 'vertical'. Default: 'horizontal'.
|
||||
"""
|
||||
|
||||
def __init__(self, flip_ratio=None, direction='horizontal'):
|
||||
self.flip_ratio = flip_ratio
|
||||
@deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
|
||||
def __init__(self, prob=None, direction='horizontal'):
|
||||
self.prob = prob
|
||||
self.direction = direction
|
||||
if flip_ratio is not None:
|
||||
assert flip_ratio >= 0 and flip_ratio <= 1
|
||||
if prob is not None:
|
||||
assert prob >= 0 and prob <= 1
|
||||
assert direction in ['horizontal', 'vertical']
|
||||
|
||||
def __call__(self, results):
|
||||
|
@ -257,7 +259,7 @@ class RandomFlip(object):
|
|||
"""
|
||||
|
||||
if 'flip' not in results:
|
||||
flip = True if np.random.rand() < self.flip_ratio else False
|
||||
flip = True if np.random.rand() < self.prob else False
|
||||
results['flip'] = flip
|
||||
if 'flip_direction' not in results:
|
||||
results['flip_direction'] = self.direction
|
||||
|
@ -274,7 +276,7 @@ class RandomFlip(object):
|
|||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
|
||||
return self.__class__.__name__ + f'(prob={self.prob})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
|
@ -463,6 +465,89 @@ class RandomCrop(object):
|
|||
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class RandomRotate(object):
    """Randomly rotate the image and its segmentation maps.

    With probability ``prob`` the image stored under ``results['img']`` and
    every map listed in ``results['seg_fields']`` are rotated by the same
    angle, drawn uniformly from the configured degree range.

    Args:
        prob (float): The rotation probability. Must lie in [0, 1].
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``). A scalar
            degree must be positive; a sequence must have exactly two
            entries.
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False

    Raises:
        AssertionError: If ``prob`` is outside [0, 1], a scalar ``degree`` is
            not positive, or ``degree`` does not have exactly two entries.
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert prob >= 0 and prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            # a scalar is shorthand for the symmetric range around zero
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
            f'tuple of (min, max)'
        # Stored under the same name as the constructor argument; the old
        # misspelled ``pal_val`` name stays readable/writable via the
        # property below.
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @property
    def pal_val(self):
        """Backward-compatible alias for ``pad_val`` (historical typo)."""
        return self.pad_val

    @pal_val.setter
    def pal_val(self, value):
        self.pad_val = value

    def __call__(self, results):
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline. ``'img'`` is
                read and replaced; every key listed under ``'seg_fields'``
                (if present) is read and replaced as well.

        Returns:
            dict: Rotated results (the same dict object, updated in place).
        """

        rotate = np.random.rand() < self.prob
        # The angle is drawn unconditionally, so one uniform sample is
        # consumed on every call even when no rotation is applied.
        degree = np.random.uniform(min(*self.degree), max(*self.degree))
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pad_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs with the same angle; nearest interpolation is
            # presumably used so that label ids are not blended
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pad_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class SegRescale(object):
|
||||
"""Rescale semantic segmentation maps.
|
||||
|
|
|
@ -69,7 +69,7 @@ def test_custom_dataset():
|
|||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
|
|
|
@ -94,18 +94,17 @@ def test_resize():
|
|||
|
||||
|
||||
def test_flip():
|
||||
# test assertion for invalid flip_ratio
|
||||
# test assertion for invalid prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='RandomFlip', flip_ratio=1.5)
|
||||
transform = dict(type='RandomFlip', prob=1.5)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid direction
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandomFlip', flip_ratio=1, direction='horizonta')
|
||||
transform = dict(type='RandomFlip', prob=1, direction='horizonta')
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
transform = dict(type='RandomFlip', flip_ratio=1)
|
||||
transform = dict(type='RandomFlip', prob=1)
|
||||
flip_module = build_from_cfg(transform, PIPELINES)
|
||||
|
||||
results = dict()
|
||||
|
@ -197,6 +196,47 @@ def test_pad():
|
|||
assert img_shape[1] % 32 == 0
|
||||
|
||||
|
||||
def test_rotate():
    """Check RandomRotate argument validation, repr and output shapes."""
    # a negative scalar degree must be rejected
    with pytest.raises(AssertionError):
        bad_scalar_cfg = dict(type='RandomRotate', prob=0.5, degree=-10)
        build_from_cfg(bad_scalar_cfg, PIPELINES)
    # a degree sequence with three entries must be rejected
    with pytest.raises(AssertionError):
        bad_tuple_cfg = dict(
            type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
        build_from_cfg(bad_tuple_cfg, PIPELINES)

    rotate_cfg = dict(type='RandomRotate', degree=10., prob=1.)
    rotate = build_from_cfg(rotate_cfg, PIPELINES)

    # the scalar degree is expanded to a symmetric range in the repr
    expected_repr = (f'RandomRotate('
                     f'prob={1.}, '
                     f'degree=({-10.}, {10.}), '
                     f'pad_val={0}, '
                     f'seg_pad_val={255}, '
                     f'center={None}, '
                     f'auto_bound={False})')
    assert str(rotate) == expected_repr

    data_dir = osp.dirname(__file__)
    img = mmcv.imread(osp.join(data_dir, '../data/color.jpg'), 'color')
    height, width, _ = img.shape
    seg = np.array(Image.open(osp.join(data_dir, '../data/seg.png')))

    results = {
        'img': img,
        'gt_semantic_seg': seg,
        'seg_fields': ['gt_semantic_seg'],
        'img_shape': img.shape,
        'ori_shape': img.shape,
        # Set initial values for default meta_keys
        'pad_shape': img.shape,
        'scale_factor': 1.0,
    }

    results = rotate(results)
    # with the default auto_bound=False the spatial size must be unchanged
    assert results['img'].shape[:2] == (height, width)
    assert results['gt_semantic_seg'].shape[:2] == (height, width)
|
||||
|
||||
|
||||
def test_normalize():
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
|
|
Loading…
Reference in New Issue