Refactor `RandomErasing`
parent
4f28b9dd63
commit
c0feadf546
|
@ -7,7 +7,8 @@ from typing import Dict, Sequence
|
|||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from .compose import Compose
|
||||
|
@ -432,7 +433,7 @@ class RandomGrayscale(object):
|
|||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomErasing(object):
|
||||
class RandomErasing(BaseTransform):
|
||||
"""Randomly selects a rectangle region in an image and erase pixels.
|
||||
|
||||
Args:
|
||||
|
@ -515,6 +516,7 @@ class RandomErasing(object):
|
|||
self.fill_std = fill_std
|
||||
|
||||
def _fill_pixels(self, img, top, left, h, w):
|
||||
"""Fill pixels to the patch of image."""
|
||||
if self.mode == 'const':
|
||||
patch = np.empty((h, w, 3), dtype=np.uint8)
|
||||
patch[:, :] = np.array(self.fill_color, dtype=np.uint8)
|
||||
|
@ -529,7 +531,27 @@ class RandomErasing(object):
|
|||
img[top:top + h, left:left + w] = patch
|
||||
return img
|
||||
|
||||
def __call__(self, results):
|
||||
@cache_randomness
|
||||
def random_disable(self):
|
||||
"""Randomly disable the transform."""
|
||||
return np.random.rand() > self.erase_prob
|
||||
|
||||
@cache_randomness
|
||||
def random_patch(self, img_h, img_w):
|
||||
"""Randomly generate patch the erase."""
|
||||
log_aspect_range = np.log(
|
||||
np.array(self.aspect_range, dtype=np.float32))
|
||||
aspect_ratio = np.exp(np.random.uniform(*log_aspect_range))
|
||||
area = img_h * img_w
|
||||
area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio)
|
||||
|
||||
h = min(int(round(np.sqrt(area * aspect_ratio))), img_h)
|
||||
w = min(int(round(np.sqrt(area / aspect_ratio))), img_w)
|
||||
top = np.random.randint(0, img_h - h) if img_h > h else 0
|
||||
left = np.random.randint(0, img_w - w) if img_w > w else 0
|
||||
return top, left, h, w
|
||||
|
||||
def transform(self, results):
|
||||
"""
|
||||
Args:
|
||||
results (dict): Results dict from pipeline
|
||||
|
@ -537,26 +559,17 @@ class RandomErasing(object):
|
|||
Returns:
|
||||
dict: Results after the transformation.
|
||||
"""
|
||||
for key in results.get('img_fields', ['img']):
|
||||
if np.random.rand() > self.erase_prob:
|
||||
continue
|
||||
img = results[key]
|
||||
img_h, img_w = img.shape[:2]
|
||||
if self.random_disable():
|
||||
return results
|
||||
|
||||
# convert to log aspect to ensure equal probability of aspect ratio
|
||||
log_aspect_range = np.log(
|
||||
np.array(self.aspect_range, dtype=np.float32))
|
||||
aspect_ratio = np.exp(np.random.uniform(*log_aspect_range))
|
||||
area = img_h * img_w
|
||||
area *= np.random.uniform(self.min_area_ratio, self.max_area_ratio)
|
||||
img = results['img']
|
||||
img_h, img_w = img.shape[:2]
|
||||
|
||||
h = min(int(round(np.sqrt(area * aspect_ratio))), img_h)
|
||||
w = min(int(round(np.sqrt(area / aspect_ratio))), img_w)
|
||||
top = np.random.randint(0, img_h - h) if img_h > h else 0
|
||||
left = np.random.randint(0, img_w - w) if img_w > w else 0
|
||||
img = self._fill_pixels(img, top, left, h, w)
|
||||
# convert to log aspect to ensure equal probability of aspect ratio
|
||||
img = self._fill_pixels(img, *self.random_patch(img_h, img_w))
|
||||
|
||||
results['img'] = img
|
||||
|
||||
results[key] = img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -7,6 +9,19 @@ import mmcls.datasets # noqa: F401,F403
|
|||
from mmcls.registry import TRANSFORMS
|
||||
|
||||
|
||||
def construct_toy_data():
|
||||
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
|
||||
dtype=np.uint8)
|
||||
img = np.stack([img, img, img], axis=-1)
|
||||
results = dict()
|
||||
# image
|
||||
results['ori_img'] = img
|
||||
results['img'] = copy.deepcopy(img)
|
||||
results['ori_shape'] = img.shape
|
||||
results['img_shape'] = img.shape
|
||||
return results
|
||||
|
||||
|
||||
class TestResizeEdge(TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
|
@ -47,3 +62,136 @@ class TestResizeEdge(TestCase):
|
|||
self.assertEqual(
|
||||
repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, '
|
||||
'interpolation=bilinear)')
|
||||
|
||||
|
||||
class TestRandomErasing(TestCase):
|
||||
|
||||
def test_initialize(self):
|
||||
# test erase_prob assertion
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', erase_prob=-1.)
|
||||
TRANSFORMS.build(cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', erase_prob=1)
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
# test area_ratio assertion
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', min_area_ratio=-1.)
|
||||
TRANSFORMS.build(cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', max_area_ratio=1)
|
||||
TRANSFORMS.build(cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
# min_area_ratio should be smaller than max_area_ratio
|
||||
cfg = dict(
|
||||
type='RandomErasing', min_area_ratio=0.6, max_area_ratio=0.4)
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
# test aspect_range assertion
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', aspect_range='str')
|
||||
TRANSFORMS.build(cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', aspect_range=-1)
|
||||
TRANSFORMS.build(cfg)
|
||||
with self.assertRaises(AssertionError):
|
||||
# In aspect_range (min, max), min should be smaller than max.
|
||||
cfg = dict(type='RandomErasing', aspect_range=[1.6, 0.6])
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
# test mode assertion
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', mode='unknown')
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
# test fill_std assertion
|
||||
with self.assertRaises(AssertionError):
|
||||
cfg = dict(type='RandomErasing', fill_std='unknown')
|
||||
TRANSFORMS.build(cfg)
|
||||
|
||||
# test implicit conversion of aspect_range
|
||||
cfg = dict(type='RandomErasing', aspect_range=0.5)
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
assert random_erasing.aspect_range == (0.5, 2.)
|
||||
|
||||
cfg = dict(type='RandomErasing', aspect_range=2.)
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
assert random_erasing.aspect_range == (0.5, 2.)
|
||||
|
||||
# test implicit conversion of fill_color
|
||||
cfg = dict(type='RandomErasing', fill_color=15)
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
assert random_erasing.fill_color == [15, 15, 15]
|
||||
|
||||
# test implicit conversion of fill_std
|
||||
cfg = dict(type='RandomErasing', fill_std=0.5)
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
assert random_erasing.fill_std == [0.5, 0.5, 0.5]
|
||||
|
||||
def test_transform(self):
|
||||
# test when erase_prob=0.
|
||||
results = construct_toy_data()
|
||||
cfg = dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.,
|
||||
mode='const',
|
||||
fill_color=(255, 255, 255))
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
results = random_erasing(results)
|
||||
np.testing.assert_array_equal(results['img'], results['ori_img'])
|
||||
|
||||
# test mode 'const'
|
||||
results = construct_toy_data()
|
||||
cfg = dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=1.,
|
||||
mode='const',
|
||||
fill_color=(255, 255, 255))
|
||||
with patch('numpy.random', np.random.RandomState(0)):
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
results = random_erasing(results)
|
||||
expect_out = np.array(
|
||||
[[1, 255, 3, 4], [5, 255, 7, 8], [9, 10, 11, 12]],
|
||||
dtype=np.uint8)
|
||||
expect_out = np.stack([expect_out] * 3, axis=-1)
|
||||
np.testing.assert_array_equal(results['img'], expect_out)
|
||||
|
||||
# test mode 'rand' with normal distribution
|
||||
results = construct_toy_data()
|
||||
cfg = dict(type='RandomErasing', erase_prob=1., mode='rand')
|
||||
with patch('numpy.random', np.random.RandomState(0)):
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
results = random_erasing(results)
|
||||
expect_out = results['ori_img']
|
||||
expect_out[:2, 1] = [[159, 98, 76], [14, 69, 122]]
|
||||
np.testing.assert_array_equal(results['img'], expect_out)
|
||||
|
||||
# test mode 'rand' with uniform distribution
|
||||
results = construct_toy_data()
|
||||
cfg = dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=1.,
|
||||
mode='rand',
|
||||
fill_std=(10, 255, 0))
|
||||
with patch('numpy.random', np.random.RandomState(0)):
|
||||
random_erasing = TRANSFORMS.build(cfg)
|
||||
results = random_erasing(results)
|
||||
|
||||
expect_out = results['ori_img']
|
||||
expect_out[:2, 1] = [[113, 255, 128], [126, 83, 128]]
|
||||
np.testing.assert_array_equal(results['img'], expect_out)
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.5,
|
||||
mode='const',
|
||||
aspect_range=(0.3, 1.3),
|
||||
fill_color=(255, 255, 255))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(
|
||||
repr(transform),
|
||||
'RandomErasing(erase_prob=0.5, min_area_ratio=0.02, '
|
||||
'max_area_ratio=0.4, aspect_range=(0.3, 1.3), mode=const, '
|
||||
'fill_color=(255, 255, 255), fill_std=None)')
|
||||
|
|
Loading…
Reference in New Issue