[Feature] Support random augmentation (#201)
* support random augmentation * minor fix on posterize * minor fix on posterize * minor fix on cutout * minor fix on cutout * fix bug in solarize add * revised according to commentspull/207/head
parent
2bd28435cf
commit
5195932952
|
@ -1,7 +1,7 @@
|
|||
from .auto_augment import (AutoAugment, AutoContrast, Brightness,
|
||||
ColorTransform, Contrast, Equalize, Invert,
|
||||
Posterize, Rotate, Sharpness, Shear, Solarize,
|
||||
Translate)
|
||||
ColorTransform, Contrast, Cutout, Equalize, Invert,
|
||||
Posterize, RandAugment, Rotate, Sharpness, Shear,
|
||||
Solarize, SolarizeAdd, Translate)
|
||||
from .compose import Compose
|
||||
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
|
@ -16,6 +16,6 @@ __all__ = [
|
|||
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
|
||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'Lighting',
|
||||
'ColorJitter'
|
||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
|
||||
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter'
|
||||
]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import random
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -41,7 +42,7 @@ class AutoAugment(object):
|
|||
self.sub_policy = [Compose(policy) for policy in self.policies]
|
||||
|
||||
def __call__(self, results):
|
||||
sub_policy = np.random.choice(self.sub_policy)
|
||||
sub_policy = random.choice(self.sub_policy)
|
||||
return sub_policy(results)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -50,6 +51,87 @@ class AutoAugment(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandAugment(object):
|
||||
"""Random augmentation.
|
||||
This data augmentation is proposed in `RandAugment: Practical automated
|
||||
data augmentation with a reduced search space
|
||||
<https://arxiv.org/abs/1909.13719>`_.
|
||||
|
||||
Args:
|
||||
policies (list[dict]): The policies of random augmentation. Each
|
||||
policy in ``policies`` is one specific augmentation policy (dict).
|
||||
The policy shall at least have key `type`, indicating the type of
|
||||
augmentation. For those which have magnitude, (given to the fact
|
||||
they are named differently in different augmentation, )
|
||||
`magnitude_key` and `magnitude_range` shall be the magnitude
|
||||
argument (str) and the range of magnitude (tuple in the format or
|
||||
(minval, maxval)), respectively.
|
||||
num_policies (int): Number of policies to select from policies each
|
||||
time.
|
||||
magnitude_level (int | float): Magnitude level for all the augmentation
|
||||
selected.
|
||||
total_level (int | float): Total level for the magnitude. Defaults to
|
||||
30.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policies,
|
||||
num_policies,
|
||||
magnitude_level,
|
||||
total_level=30):
|
||||
assert isinstance(num_policies, int), 'Number of policies must be ' \
|
||||
f'of int type, got {type(num_policies)} instead.'
|
||||
assert isinstance(magnitude_level, (int, float)), \
|
||||
'Magnitude level must be of int or float type, ' \
|
||||
f'got {type(magnitude_level)} instead.'
|
||||
assert isinstance(total_level, (int, float)), 'Total level must be ' \
|
||||
f'of int or float type, got {type(total_level)} instead.'
|
||||
assert isinstance(policies, list) and len(policies) > 0, \
|
||||
'Policies must be a non-empty list.'
|
||||
for policy in policies:
|
||||
assert isinstance(policy, dict) and 'type' in policy, \
|
||||
'Each policy must be a dict with key "type".'
|
||||
|
||||
assert num_policies > 0, 'num_policies must be greater than 0.'
|
||||
assert magnitude_level >= 0, 'magnitude_level must be no less than 0.'
|
||||
assert total_level > 0, 'total_level must be greater than 0.'
|
||||
|
||||
self.num_policies = num_policies
|
||||
self.magnitude_level = magnitude_level
|
||||
self.total_level = total_level
|
||||
self.policies = self._process_policies(policies)
|
||||
|
||||
def _process_policies(self, policies):
|
||||
processed_policies = []
|
||||
for policy in policies:
|
||||
processed_policy = copy.deepcopy(policy)
|
||||
magnitude_key = processed_policy.pop('magnitude_key', None)
|
||||
if magnitude_key is not None:
|
||||
minval, maxval = processed_policy.pop('magnitude_range')
|
||||
magnitude_value = (float(self.magnitude_level) /
|
||||
self.total_level) * float(maxval -
|
||||
minval) + minval
|
||||
processed_policy.update({magnitude_key: magnitude_value})
|
||||
processed_policies.append(processed_policy)
|
||||
return processed_policies
|
||||
|
||||
def __call__(self, results):
|
||||
if self.num_policies == 0:
|
||||
return results
|
||||
sub_policy = random.choices(self.policies, k=self.num_policies)
|
||||
sub_policy = Compose(sub_policy)
|
||||
return sub_policy(results)
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(policies={self.policies}, '
|
||||
repr_str += f'num_policies={self.num_policies}, '
|
||||
repr_str += f'magnitude_level={self.magnitude_level}, '
|
||||
repr_str += f'total_level={self.total_level})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Shear(object):
|
||||
"""Shear images.
|
||||
|
@ -428,25 +510,66 @@ class Solarize(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class SolarizeAdd(object):
|
||||
"""SolarizeAdd images (add a certain value to pixels below a threshold).
|
||||
|
||||
Args:
|
||||
magnitude (int | float): The value to be added to pixels below the thr.
|
||||
thr (int | float): The threshold below which the pixels value will be
|
||||
adjusted.
|
||||
prob (float): The probability for solarizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, magnitude, thr=128, prob=0.5):
|
||||
assert isinstance(magnitude, (int, float)), 'The thr magnitude must '\
|
||||
f'be int or float, but got {type(magnitude)} instead.'
|
||||
assert isinstance(thr, (int, float)), 'The thr type must '\
|
||||
f'be int or float, but got {type(thr)} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.magnitude = magnitude
|
||||
self.thr = thr
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_solarized = np.where(img < self.thr,
|
||||
np.minimum(img + self.magnitude, 255),
|
||||
img)
|
||||
results[key] = img_solarized.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(magnitude={self.magnitude}, '
|
||||
repr_str += f'thr={self.thr}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Posterize(object):
|
||||
"""Posterize images (reduce the number of bits for each color channel).
|
||||
|
||||
Args:
|
||||
bits (int): Number of bits for each pixel in the output img, which
|
||||
should be less or equal to 8.
|
||||
bits (int | float): Number of bits for each pixel in the output img,
|
||||
which should be less or equal to 8.
|
||||
prob (float): The probability for posterizing therefore should be in
|
||||
range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, bits, prob=0.5):
|
||||
assert isinstance(bits, int), 'The bits type must be int, '\
|
||||
f'but got {type(bits)} instead.'
|
||||
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.bits = bits
|
||||
self.bits = int(bits)
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
|
@ -642,3 +765,51 @@ class Sharpness(object):
|
|||
repr_str += f'prob={self.prob}, '
|
||||
repr_str += f'random_negative_prob={self.random_negative_prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Cutout(object):
|
||||
"""Cutout images.
|
||||
|
||||
Args:
|
||||
shape (int | float | tuple(int | float)): Expected cutout shape (h, w).
|
||||
If given as a single value, the value will be used for
|
||||
both h and w.
|
||||
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If
|
||||
it is a tuple, it must have the same length with the image
|
||||
channels. Defaults to 128.
|
||||
prob (float): The probability for performing cutout therefore should
|
||||
be in range [0, 1]. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, shape, pad_val=128, prob=0.5):
|
||||
if isinstance(shape, float):
|
||||
shape = int(shape)
|
||||
elif isinstance(shape, tuple):
|
||||
shape = tuple(int(i) for i in shape)
|
||||
elif not isinstance(shape, int):
|
||||
raise TypeError(
|
||||
'shape must be of '
|
||||
f'type int, float or tuple, got {type(shape)} instead')
|
||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.shape = shape
|
||||
self.pad_val = pad_val
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
if np.random.rand() > self.prob:
|
||||
return results
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
img_cutout = mmcv.cutout(img, self.shape, pad_val=self.pad_val)
|
||||
results[key] = img_cutout.astype(img.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(shape={self.shape}, '
|
||||
repr_str += f'pad_val={self.pad_val}, '
|
||||
repr_str += f'prob={self.prob})'
|
||||
return repr_str
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import random
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -38,6 +39,189 @@ def construct_toy_data_photometric():
|
|||
return results
|
||||
|
||||
|
||||
def test_rand_augment():
|
||||
policies = [
|
||||
dict(
|
||||
type='Translate',
|
||||
magnitude_key='magnitude',
|
||||
magnitude_range=(0, 1),
|
||||
pad_val=128,
|
||||
prob=1.,
|
||||
direction='horizontal'),
|
||||
dict(type='Invert', prob=1.),
|
||||
dict(
|
||||
type='Rotate',
|
||||
magnitude_key='angle',
|
||||
magnitude_range=(0, 30),
|
||||
prob=0.)
|
||||
]
|
||||
# test assertion for num_policies
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1.5,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=-1,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
# test assertion for magnitude_level
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=-1)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
# test assertion for total_level
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=12,
|
||||
total_level=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=12,
|
||||
total_level=-30)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
# test assertion for policies
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=[],
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(AssertionError):
|
||||
invalid_policies = copy.deepcopy(policies)
|
||||
invalid_policies.append(('Wrong_policy'))
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=invalid_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(AssertionError):
|
||||
invalid_policies = copy.deepcopy(policies)
|
||||
invalid_policies[2].pop('type')
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=invalid_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
with pytest.raises(KeyError):
|
||||
invalid_policies = copy.deepcopy(policies)
|
||||
invalid_policies[2].pop('magnitude_range')
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=invalid_policies,
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case where num_policies = 1
|
||||
random.seed(1)
|
||||
np.random.seed(0)
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=12)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply translate
|
||||
img_augmented = np.array(
|
||||
[[128, 128, 1, 2], [128, 128, 5, 6], [128, 128, 9, 10]],
|
||||
dtype=np.uint8)
|
||||
img_augmented = np.stack([img_augmented, img_augmented, img_augmented],
|
||||
axis=-1)
|
||||
assert (results['img'] == img_augmented).all()
|
||||
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=1,
|
||||
magnitude_level=12)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply rotation with prob=0.
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case where num_policies = 2
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply rotate and rotate with prob=0
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=2,
|
||||
magnitude_level=12)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply invert and translate
|
||||
img_augmented = np.array(
|
||||
[[252, 251, 128, 128], [248, 247, 128, 128], [244, 243, 128, 128]],
|
||||
dtype=np.uint8)
|
||||
img_augmented = np.stack([img_augmented, img_augmented, img_augmented],
|
||||
axis=-1)
|
||||
assert (results['img'] == img_augmented).all()
|
||||
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=2,
|
||||
magnitude_level=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply invert and invert
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case where magnitude_level = 0
|
||||
results = construct_toy_data()
|
||||
transform = dict(
|
||||
type='RandAugment',
|
||||
policies=policies,
|
||||
num_policies=2,
|
||||
magnitude_level=0)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
# apply rotate and translate
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
|
||||
def test_shear():
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -467,12 +651,57 @@ def test_solarize():
|
|||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_posterize():
|
||||
# test assertion for invalid type of bits
|
||||
def test_solarize_add():
|
||||
# test assertion for invalid type of magnitude
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Posterize', bits=4.5)
|
||||
transform = dict(type='SolarizeAdd', magnitude=(1, 2))
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid type of thr
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='SolarizeAdd', magnitude=100, thr=(1, 2))
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when prob=0, therefore no solarize
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when thr=0, therefore no solarize
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='SolarizeAdd', magnitude=100, thr=0, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when thr=128, magnitude=100
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_solarized = np.array(
|
||||
[[100, 128, 255], [101, 227, 254], [102, 129, 253]], dtype=np.uint8)
|
||||
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
|
||||
axis=-1)
|
||||
assert (results['img'] == img_solarized).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
# test case when thr=100, magnitude=50
|
||||
results = construct_toy_data_photometric()
|
||||
transform = dict(type='SolarizeAdd', magnitude=50, thr=100, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_solarized = np.array([[50, 128, 255], [51, 127, 254], [52, 129, 253]],
|
||||
dtype=np.uint8)
|
||||
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
|
||||
axis=-1)
|
||||
assert (results['img'] == img_solarized).all()
|
||||
assert (results['img'] == results['img2']).all()
|
||||
|
||||
|
||||
def test_posterize():
|
||||
# test assertion for invalid value of bits
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Posterize', bits=10)
|
||||
|
@ -771,3 +1000,59 @@ def test_sharpness(nb_rand_test=100):
|
|||
_adjust_sharpness(img, 1 + magnitude)[1:-1, 1:-1],
|
||||
rtol=0,
|
||||
atol=1)
|
||||
|
||||
|
||||
def test_cutout():
|
||||
|
||||
# test assertion for invalid type of shape
|
||||
with pytest.raises(TypeError):
|
||||
transform = dict(type='Cutout', shape=None)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test assertion for invalid value of prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Cutout', shape=1, prob=100)
|
||||
build_from_cfg(transform, PIPELINES)
|
||||
|
||||
# test case when prob=0, therefore no cutout
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Cutout', shape=2, prob=0.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when shape=0, therefore no cutout
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Cutout', shape=0, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == results['ori_img']).all()
|
||||
|
||||
# test case when shape=6, therefore the whole img has been cut
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Cutout', shape=6, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
assert (results['img'] == np.ones_like(results['ori_img']) * 128).all()
|
||||
|
||||
# test case when shape is int
|
||||
np.random.seed(0)
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Cutout', shape=1, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_cutout = np.array([[1, 2, 3, 4], [5, 128, 7, 8], [9, 10, 11, 12]],
|
||||
dtype=np.uint8)
|
||||
img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1)
|
||||
assert (results['img'] == img_cutout).all()
|
||||
|
||||
# test case when shape is tuple
|
||||
np.random.seed(0)
|
||||
results = construct_toy_data()
|
||||
transform = dict(type='Cutout', shape=(1, 2), pad_val=0, prob=1.)
|
||||
pipeline = build_from_cfg(transform, PIPELINES)
|
||||
results = pipeline(results)
|
||||
img_cutout = np.array([[1, 2, 3, 4], [5, 0, 0, 8], [9, 10, 11, 12]],
|
||||
dtype=np.uint8)
|
||||
img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1)
|
||||
assert (results['img'] == img_cutout).all()
|
||||
|
|
Loading…
Reference in New Issue