# mmcv/tests/test_transforms/test_transforms_wrapper.py
#
# 586 lines
# 18 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import pytest
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
RandomChoice, TransformBroadcaster)
@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform that adds a fixed ``addend`` to ``results['value']``.

    ``results['value']`` may be a scalar or a nested list/dict of scalars;
    every leaf is shifted by the addend.

    Args:
        addend: Number added to each leaf value. Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Shift every leaf of ``results['value']`` by ``addend``.

        A ``UserWarning`` is emitted when the top-level value is a list or
        a dict, so tests can observe how wrappers packed their inputs.

        Args:
            results: Dict holding a ``'value'`` entry; updated in place.
            addend: Number added to each leaf.

        Returns:
            The same ``results`` dict with ``'value'`` replaced.
        """
        value = results['value']
        if isinstance(value, list):
            warnings.warn('value is a list', UserWarning)
        if isinstance(value, dict):
            warnings.warn('value is a dict', UserWarning)

        def _shift(node):
            # Rebuild containers recursively; leaves get the addend.
            if isinstance(node, dict):
                return {key: _shift(item) for key, item in node.items()}
            if isinstance(node, list):
                return [_shift(item) for item in node]
            return node + addend

        results['value'] = _shift(value)
        return results

    def transform(self, results):
        """Apply the configured ``self.addend``."""
        return self.add(results, self.addend)

    def __repr__(self) -> str:
        # No separator or parentheses: the repr assertions elsewhere in this
        # file expect exactly e.g. 'AddToValueaddend = 1'.
        return f'{self.__class__.__name__}addend = {self.addend}'
@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform that adds random addends to ``results['value']``.

    Each repetition draws one addend from ``get_random_addend``, which is
    decorated with ``cache_randomness`` so ``cache_random_params`` can
    freeze it.

    Args:
        repeat: Number of random addends applied in sequence. Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Draw a random addend via ``np.random.rand()``."""
        return np.random.rand()

    def transform(self, results):
        """Apply ``self.repeat`` addends, one ``get_random_addend()`` each."""
        for _step in range(self.repeat):
            addend = self.get_random_addend()
            results = self.add(results, addend=addend)
        return results

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}repeat = {self.repeat}'
@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform that writes ``num_1 + num_2`` to ``results['sum']``.

    Missing operands are skipped. If only one operand is present, its value
    is used as-is. If neither is present, the sum is ``np.nan``.
    """

    def transform(self, results):
        has_first = 'num_1' in results
        has_second = 'num_2' in results
        if has_first and has_second:
            total = results['num_1'] + results['num_2']
        elif has_first:
            total = results['num_1']
        elif has_second:
            total = results['num_2']
        else:
            total = np.nan
        results['sum'] = total
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__
def test_compose():
    """Test building ``Compose`` and its handling of ``None`` outputs."""
    # Case 1: build from cfg
    pipeline = Compose([dict(type='AddToValue')])
    _ = str(pipeline)
    # the registry builds AddToValue() whose default addend is 0
    assert pipeline(dict(value=1))['value'] == 1

    # Case 2: build from transform list
    pipeline = Compose([AddToValue(addend=2)])
    assert pipeline(dict(value=1))['value'] == 3

    # Case 3: invalid build arguments
    with pytest.raises(TypeError):
        Compose([[dict(type='AddToValue')]])

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    pipeline = Compose([DummyTransform()])
    assert pipeline({}) is None
    # Transforms after one that returns None must be skipped. Otherwise
    # AddToValue would index into None and raise.
    pipeline = Compose([DummyTransform(), AddToValue()])
    assert pipeline({}) is None
def test_cache_random_parameters():
    """Check that ``cache_random_params`` freezes ``@cache_randomness`` calls."""
    transform = RandomAddToValue()

    # Case 1: the decorated method is registered, and its value is reused
    # across calls inside the context
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness
    with cache_random_params(transform):
        first = transform(dict(value=0))
        second = transform(dict(value=0))
        np.testing.assert_equal(first['value'], second['value'])

    # Case 2: outside the context, each call draws a new random addend
    first = transform(dict(value=0))
    second = transform(dict(value=0))
    assert first['value'] != second['value']

    # Case 3: a cached random method may be invoked zero times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        transform(dict(value=0))

    # Case 4: invoking a cached random method more than once is an error
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError), cache_random_params(transform):
        transform(dict(value=0))

    # Case 5: caching also reaches transforms nested inside a wrapper
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        first = transform(dict(value=0))
        second = transform(dict(value=0))
        np.testing.assert_equal(first['value'], second['value'])
def test_key_mapper():
    """Test ``KeyMapper`` input mapping, output remapping and helpers."""
    # Case 0: only remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})
    results = dict(value=0)
    results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_out'], 1)

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})
    results = dict(value=0, v_in=1)
    results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)
    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)
    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)
    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)
    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)
    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys -- a missing source key is skipped
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)
    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)
    results = pipeline(dict(a=1))
    np.testing.assert_equal(results['sum'], 1)
    results = pipeline(dict(b=2))
    np.testing.assert_equal(results['sum'], 2)

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Case 10: manually set keys ignored
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2=...),  # num_2 (b) will be ignored
        auto_remap=False,
        # allow_nonexist_keys will not affect manually ignored keys
        allow_nonexist_keys=False)
    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))
    # __iter__
    for _ in pipeline:
        pass
    # __repr__
    # NOTE(review): the whitespace after '\n' must match Compose.__repr__
    # exactly; confirm it was not collapsed when this file was copied.
    assert repr(pipeline) == (
        'KeyMapper(transforms = Compose(\n ' + 'AddToValueaddend = 1' +
        '\n), mapping = {\'value\': \'v_in\'}, ' +
        'remapping = {\'value\': \'v_out\'}, auto_remap = False, ' +
        'allow_nonexist_keys = False)')
def test_transform_broadcaster():
    """Test ``TransformBroadcaster`` over lists, key groups and partial maps."""
    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])
    results = pipeline(results)
    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)
    results = pipeline(results)
    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)
    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)
    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: apply to all keys (no mapping / remapping)
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()], mapping=None, remapping=None)
    results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])
    results = pipeline(results)
    np.testing.assert_equal(results['sum'], [5, 7, 9])

    # Case 5: inconsistent sequence length -- only the call should raise,
    # so construction is kept outside the `raises` block
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='list_1', num_2='list_2'),
        auto_remap=False)
    with pytest.raises(ValueError):
        pipeline(dict(list_1=[1, 2], list_2=[1, 2, 3]))

    # Case 6: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)
    results = dict(values=[0, 0])
    results = pipeline(results)
    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Case 7: partial broadcasting of inputs -- `...` leaves num_2 unmapped
    # for the second group, so 'b' is b_1 alone
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)
    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)
    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 3)

    # Case 8: partial broadcasting of outputs -- `...` in the remapping
    # drops the second group's 'sum'
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', ...]),
        auto_remap=False)
    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)
    np.testing.assert_equal(results['a'], 3)
    assert 'b' not in results

    # Test repr (of the Case 8 pipeline)
    # NOTE(review): the whitespace after '\n' must match Compose.__repr__
    # exactly; confirm it was not collapsed when this file was copied.
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' + ' SumTwoValues' +
        '\n), mapping = {\'num_1\': [\'a_1\', \'b_1\'], ' +
        '\'num_2\': [\'a_2\', \'b_2\']}, ' +
        'remapping = {\'sum\': [\'a\', Ellipsis]}, auto_remap = False, ' +
        'allow_nonexist_keys = False, share_random_params = False)')
def test_random_choice():
    """Test ``RandomChoice`` alone and nested in ``TransformBroadcaster``."""
    # Case 1: given probability -- only the first branch can be chosen
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability -- either branch may be chosen
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]])
    results = pipeline(dict(value=1))
    assert results['value'] in (2.0, 3.0)

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]]),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)
    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(map(lambda x: x == values[0], values))
    # and that the shared branch is one of the two configured ones
    assert values[0] in (1.0, 2.0)

    # repr
    # NOTE(review): the whitespace after '\n' must match Compose.__repr__
    # exactly; confirm it was not collapsed when this file was copied.
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' +
        ' RandomChoice(transforms = [Compose(\n' +
        ' AddToValueaddend = 1.0' + '\n), Compose(\n' +
        ' AddToValueaddend = 2.0' + '\n)]prob = None)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')
def test_random_apply():
    """Exercise ``RandomApply`` standalone and inside a broadcaster."""
    # Case 1: prob=1.0 always applies the wrapped transforms, prob=0.0 never
    for prob, expected in ((1.0, 2.0), (0.0, 1.0)):
        wrapper = RandomApply(transforms=[AddToValue(addend=1.0)], prob=prob)
        output = wrapper(dict(value=1))
        np.testing.assert_equal(output['value'], expected)

    # Case 2: nested RandomApply in TransformBroadcaster
    inner = RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)
    pipeline = TransformBroadcaster(
        transforms=[inner],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)
    output = pipeline(dict(values=[0] * 10))
    # share_random_params=True: the apply/skip decision is shared, so every
    # element must end up with the same value
    values = output['values']
    assert all(value == values[0] for value in values)

    # __iter__
    _ = [step for step in pipeline]

    # repr
    expected_repr = (
        'TransformBroadcaster(transforms = Compose(\n' +
        ' RandomApply(transforms = Compose(\n' +
        ' AddToValueaddend = 1' + '\n), prob = 0.5)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')
    assert repr(pipeline) == expected_repr
def test_utils():
    """Test ``cache_randomness`` / ``avoid_cache_randomness`` validation.

    Covers the valid decorator usage, decorator misuse that must raise
    ``TypeError``/``RuntimeError``, and the rule that
    ``avoid_cache_randomness`` does not carry over to subclasses.
    """
    # Test cache_randomness: normal case -- a regular method taking `self`
    # can be decorated and then cached via cache_random_params
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type -- decorating a
    # staticmethod object (note the decorator order) must raise TypeError
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # Test cache_randomness: invalid function argument list -- the first
    # parameter is deliberately named `cls` rather than `self`; this case
    # expects TypeError, so do not rename it
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: invalid mixture with cache_randomness --
    # a class marked avoid_cache_randomness must not also define
    # @cache_randomness methods (RuntimeError when the class is decorated)
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: raise error in cache_random_params --
    # defining the class is fine; entering cache_random_params on an
    # instance is what is expected to raise
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            def transform(self, results):
                return results

        transform = DummyTransform()
        with cache_random_params(transform):
            pass

    # Test avoid_cache_randomness: non-inheritable -- a subclass of a class
    # marked avoid_cache_randomness may be cached without error
    @avoid_cache_randomness
    class DummyBaseTransform(BaseTransform):

        def transform(self, results):
            return results

    class DummyTransform(DummyBaseTransform):
        pass

    transform = DummyTransform()
    with cache_random_params(transform):
        pass