Refactor `RandomErasing`

pull/913/head
mzr1996 2022-05-31 10:47:29 +08:00
parent 4f28b9dd63
commit c0feadf546
2 changed files with 181 additions and 20 deletions

View File

@ -7,7 +7,8 @@ from typing import Dict, Sequence
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.transforms import BaseTransform from mmcv import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmcls.registry import TRANSFORMS from mmcls.registry import TRANSFORMS
from .compose import Compose from .compose import Compose
@ -432,7 +433,7 @@ class RandomGrayscale(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomErasing(object): class RandomErasing(BaseTransform):
"""Randomly selects a rectangle region in an image and erase pixels. """Randomly selects a rectangle region in an image and erase pixels.
Args: Args:
@ -515,6 +516,7 @@ class RandomErasing(object):
self.fill_std = fill_std self.fill_std = fill_std
def _fill_pixels(self, img, top, left, h, w): def _fill_pixels(self, img, top, left, h, w):
"""Fill pixels to the patch of image."""
if self.mode == 'const': if self.mode == 'const':
patch = np.empty((h, w, 3), dtype=np.uint8) patch = np.empty((h, w, 3), dtype=np.uint8)
patch[:, :] = np.array(self.fill_color, dtype=np.uint8) patch[:, :] = np.array(self.fill_color, dtype=np.uint8)
@ -529,7 +531,27 @@ class RandomErasing(object):
img[top:top + h, left:left + w] = patch img[top:top + h, left:left + w] = patch
return img return img
def __call__(self, results): @cache_randomness
def random_disable(self):
    """Decide at random whether this call should skip the erasing.

    Returns:
        bool: True when a fresh ``np.random.rand()`` draw is strictly
        greater than ``self.erase_prob``, i.e. the transform is disabled
        for this call.
    """
    draw = np.random.rand()
    return draw > self.erase_prob
@cache_randomness
def random_patch(self, img_h, img_w):
    """Randomly generate the patch to erase.

    Args:
        img_h (int): Height of the image.
        img_w (int): Width of the image.

    Returns:
        tuple: ``(top, left, h, w)`` of the rectangle to erase, with
        ``h`` and ``w`` clamped to the image height and width.
    """
    # Sample the aspect ratio uniformly in log space to ensure equal
    # probability of aspect ratios on both sides of 1.
    log_low, log_high = np.log(
        np.array(self.aspect_range, dtype=np.float32))
    aspect_ratio = np.exp(np.random.uniform(log_low, log_high))

    # Target patch area as a random fraction of the full image area.
    area_ratio = np.random.uniform(self.min_area_ratio,
                                   self.max_area_ratio)
    area = img_h * img_w * area_ratio

    h = min(int(round(np.sqrt(area * aspect_ratio))), img_h)
    w = min(int(round(np.sqrt(area / aspect_ratio))), img_w)

    # NOTE(review): np.random.randint excludes its upper bound, so the
    # largest offsets (img_h - h, img_w - w) are never drawn -- confirm
    # this is intended before changing it; seeded tests depend on it.
    if img_h > h:
        top = np.random.randint(0, img_h - h)
    else:
        top = 0
    if img_w > w:
        left = np.random.randint(0, img_w - w)
    else:
        left = 0

    return top, left, h, w
def transform(self, results):
    """Erase a randomly chosen rectangle of ``results['img']``.

    The call is a no-op when ``self.random_disable()`` returns True.

    Args:
        results (dict): Results dict from pipeline. The image is read
            from and written back to the ``'img'`` key.

    Returns:
        dict: Results after the transformation.
    """
    if not self.random_disable():
        image = results['img']
        height, width = image.shape[:2]
        # random_patch yields (top, left, h, w) for _fill_pixels.
        patch_box = self.random_patch(height, width)
        results['img'] = self._fill_pixels(image, *patch_box)

    return results
def __repr__(self): def __repr__(self):

View File

@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch
import numpy as np import numpy as np
@ -7,6 +9,19 @@ import mmcls.datasets # noqa: F401,F403
from mmcls.registry import TRANSFORMS from mmcls.registry import TRANSFORMS
def construct_toy_data():
    """Build a small results dict holding a 3x4 three-channel toy image.

    Returns:
        dict: ``ori_img`` and ``img`` hold equal but independent uint8
        arrays of shape (3, 4, 3); ``ori_shape`` and ``img_shape`` hold
        that shape.
    """
    # Values 1..12 laid out row by row, repeated over three channels.
    plane = np.arange(1, 13, dtype=np.uint8).reshape(3, 4)
    img = np.stack([plane] * 3, axis=-1)
    return dict(
        ori_img=img,
        # 'img' is a deep copy so changes to it leave 'ori_img' intact.
        img=copy.deepcopy(img),
        ori_shape=img.shape,
        img_shape=img.shape,
    )
class TestResizeEdge(TestCase): class TestResizeEdge(TestCase):
def test_transform(self): def test_transform(self):
@ -47,3 +62,136 @@ class TestResizeEdge(TestCase):
self.assertEqual( self.assertEqual(
repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, ' repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, '
'interpolation=bilinear)') 'interpolation=bilinear)')
class TestRandomErasing(TestCase):
    """Unit tests for ``RandomErasing`` built through ``TRANSFORMS``."""

    def _erase_seeded(self, **kwargs):
        """Run ``RandomErasing(**kwargs)`` on fresh toy data.

        ``numpy.random`` is patched with ``RandomState(0)`` while the
        transform is built and applied, so the result is reproducible.
        """
        results = construct_toy_data()
        cfg = dict(type='RandomErasing', **kwargs)
        with patch('numpy.random', np.random.RandomState(0)):
            random_erasing = TRANSFORMS.build(cfg)
            return random_erasing(results)

    def test_initialize(self):
        """Invalid arguments are rejected; scalar arguments are expanded."""
        invalid_kwargs = [
            # erase_prob
            dict(erase_prob=-1.),
            dict(erase_prob=1),
            # area ratio
            dict(min_area_ratio=-1.),
            dict(max_area_ratio=1),
            # min_area_ratio should be smaller than max_area_ratio
            dict(min_area_ratio=0.6, max_area_ratio=0.4),
            # aspect_range
            dict(aspect_range='str'),
            dict(aspect_range=-1),
            # In aspect_range (min, max), min should be smaller than max.
            dict(aspect_range=[1.6, 0.6]),
            # mode
            dict(mode='unknown'),
            # fill_std
            dict(fill_std='unknown'),
        ]
        for kwargs in invalid_kwargs:
            cfg = dict(type='RandomErasing', **kwargs)
            with self.assertRaises(AssertionError):
                TRANSFORMS.build(cfg)

        # a scalar aspect_range and its reciprocal give the same range
        for scalar_range in (0.5, 2.):
            random_erasing = TRANSFORMS.build(
                dict(type='RandomErasing', aspect_range=scalar_range))
            assert random_erasing.aspect_range == (0.5, 2.)

        # a scalar fill_color is expanded to three channels
        random_erasing = TRANSFORMS.build(
            dict(type='RandomErasing', fill_color=15))
        assert random_erasing.fill_color == [15, 15, 15]

        # a scalar fill_std is expanded to three channels
        random_erasing = TRANSFORMS.build(
            dict(type='RandomErasing', fill_std=0.5))
        assert random_erasing.fill_std == [0.5, 0.5, 0.5]

    def test_transform(self):
        """Check the erased output for each fill configuration."""
        white = (255, 255, 255)

        # erase_prob=0. must leave the image untouched
        results = construct_toy_data()
        random_erasing = TRANSFORMS.build(
            dict(
                type='RandomErasing',
                erase_prob=0.,
                mode='const',
                fill_color=white))
        results = random_erasing(results)
        np.testing.assert_array_equal(results['img'], results['ori_img'])

        # mode 'const': the patch is painted with fill_color
        results = self._erase_seeded(
            erase_prob=1., mode='const', fill_color=white)
        expect_plane = np.array(
            [[1, 255, 3, 4], [5, 255, 7, 8], [9, 10, 11, 12]],
            dtype=np.uint8)
        expect_out = np.stack([expect_plane] * 3, axis=-1)
        np.testing.assert_array_equal(results['img'], expect_out)

        # mode 'rand' without fill_std
        results = self._erase_seeded(erase_prob=1., mode='rand')
        expect_out = results['ori_img']
        expect_out[:2, 1] = [[159, 98, 76], [14, 69, 122]]
        np.testing.assert_array_equal(results['img'], expect_out)

        # mode 'rand' with an explicit per-channel fill_std
        results = self._erase_seeded(
            erase_prob=1., mode='rand', fill_std=(10, 255, 0))
        expect_out = results['ori_img']
        expect_out[:2, 1] = [[113, 255, 128], [126, 83, 128]]
        np.testing.assert_array_equal(results['img'], expect_out)

    def test_repr(self):
        """``repr`` lists every constructor argument with its value."""
        transform = TRANSFORMS.build(
            dict(
                type='RandomErasing',
                erase_prob=0.5,
                mode='const',
                aspect_range=(0.3, 1.3),
                fill_color=(255, 255, 255)))
        expected = (
            'RandomErasing(erase_prob=0.5, min_area_ratio=0.02, '
            'max_area_ratio=0.4, aspect_range=(0.3, 1.3), mode=const, '
            'fill_color=(255, 255, 255), fill_std=None)')
        self.assertEqual(repr(transform), expected)