# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
                                   cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)


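# The dummy transforms below are registered in TRANSFORMS so that the tests
# can build them either from config dicts (e.g. dict(type='AddToValue')) or
# from transform instances.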
@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform to add a given addend to results['value']."""

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        augend = results['value']

        if isinstance(augend, list):
            warnings.warn('value is a list', UserWarning)
        if isinstance(augend, dict):
            warnings.warn('value is a dict', UserWarning)

        def _add_to_value(augend, addend):
            if isinstance(augend, list):
                return [_add_to_value(v, addend) for v in augend]
            if isinstance(augend, dict):
                return {k: _add_to_value(v, addend) for k, v in augend.items()}
            return augend + addend

        results['value'] = _add_to_value(results['value'], addend)
        return results

    def transform(self, results):
        return self.add(results, self.addend)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'addend = {self.addend}'
        return repr_str


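# RandomAddToValue draws its addend from a cache_randomness-decorated method,
# so the caching tests below can control whether repeated calls reuse the same
# random value.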
@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform to add a random addend to results['value']."""

    def __init__(self, repeat=1) -> None:
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        return np.random.rand()

    def transform(self, results):
        for _ in range(self.repeat):
            results = self.add(results, addend=self.get_random_addend())
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'repeat = {self.repeat}'
        return repr_str


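# SumTwoValues tolerates missing inputs: it sums whichever of num_1/num_2 are
# present and falls back to NaN when both are absent, which the KeyMapper and
# TransformBroadcaster tests below rely on.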
@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform to test transform wrappers."""

    def transform(self, results):
        if 'num_1' in results and 'num_2' in results:
            results['sum'] = results['num_1'] + results['num_2']
        elif 'num_1' in results:
            results['sum'] = results['num_1']
        elif 'num_2' in results:
            results['sum'] = results['num_2']
        else:
            results['sum'] = np.nan
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str


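# Compose accepts either config dicts or transform instances, and aborts the
# pipeline (returning None) as soon as any transform returns None. A minimal
# config-driven sketch (assuming AddToValue is registered as above):
#   pipeline = Compose([dict(type='AddToValue', addend=1)])
#   results = pipeline(dict(value=0))  # results['value'] == 1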
def test_compose():

    # Case 1: build from cfg
    pipeline = [dict(type='AddToValue')]
    pipeline = Compose(pipeline)
    _ = str(pipeline)

    # Case 2: build from transform list
    pipeline = [AddToValue()]
    pipeline = Compose(pipeline)

    # Case 3: invalid build arguments
    pipeline = [[dict(type='AddToValue')]]
    with pytest.raises(TypeError):
        pipeline = Compose(pipeline)

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    pipeline = Compose([DummyTransform()])
    results = pipeline({})
    assert results is None


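# Within a cache_random_params context, a cache_randomness-decorated method is
# expected to return the same cached value across repeated transform calls;
# outside the context it draws a fresh value every time.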
def test_cache_random_parameters():

    transform = RandomAddToValue()

    # Case 1: cache random parameters
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 2: do not cache random parameters
    results_1 = transform(dict(value=0))
    results_2 = transform(dict(value=0))
    with pytest.raises(AssertionError):
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 3: allow the random method to be invoked 0 times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        _ = transform(dict(value=0))

    # Case 4: do NOT allow invoking the random method more than once
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            _ = transform(dict(value=0))

    # Case 5: apply on nested transforms
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])


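# KeyMapper remaps keys in the results dict around its inner transforms:
# `mapping` collects outer keys into the keys the transforms expect, and
# `remapping` (or auto_remap=True) writes the outputs back. Roughly:
#   KeyMapper(transforms=[AddToValue(addend=1)],
#             mapping={'value': 'v_in'}, remapping={'value': 'v_out'})
# reads 'v_in', runs the transforms on it as 'value', and stores the result in
# 'v_out' (see Case 1 below).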
def test_key_mapper():
    # Case 0: only remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})

    results = dict(value=0)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_out'], 1)

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: remapping must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    results = pipeline(dict(a=1))
    np.testing.assert_equal(results['sum'], 1)

    # Case 9: use the wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Case 10: manually set keys to be ignored
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2=...),  # num_2 (b) will be ignored
        auto_remap=False,
        # allow_nonexist_keys will not affect manually ignored keys
        allow_nonexist_keys=False)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    assert repr(pipeline) == (
        'KeyMapper(transforms = Compose(\n    ' + 'AddToValueaddend = 1' +
        '\n), mapping = {\'value\': \'v_in\'}, ' +
        'remapping = {\'value\': \'v_out\'}, auto_remap = False, ' +
        'allow_nonexist_keys = False)')


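# TransformBroadcaster applies its inner transforms once per mapped element or
# key group, and can share random parameters across the broadcast via
# share_random_params=True.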
def test_transform_broadcaster():

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: apply to all keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()], mapping=None, remapping=None)
    results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])

    results = pipeline(results)

    np.testing.assert_equal(results['sum'], [5, 7, 9])

    # Case 5: inconsistent sequence length
    with pytest.raises(ValueError):
        pipeline = TransformBroadcaster(
            transforms=[SumTwoValues()],
            mapping=dict(num_1='list_1', num_2='list_2'),
            auto_remap=False)

        results = dict(list_1=[1, 2], list_2=[1, 2, 3])
        _ = pipeline(results)

    # Case 6: share random parameters
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Case 7: partial broadcasting
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 3)

    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', ...]),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    assert 'b' not in results

    # Test repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' + '    SumTwoValues' +
        '\n), mapping = {\'num_1\': [\'a_1\', \'b_1\'], ' +
        '\'num_2\': [\'a_2\', \'b_2\']}, ' +
        'remapping = {\'sum\': [\'a\', Ellipsis]}, auto_remap = False, ' +
        'allow_nonexist_keys = False, share_random_params = False)')


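# RandomChoice picks one of the candidate sub-pipelines per call, optionally
# weighted by the given `prob` list.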
def test_random_choice():

    # Case 1: given probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]])

    _ = pipeline(dict(value=1))

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]], ),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # with share_random_params=True, all broadcast values should be identical
    values = results['values']
    assert all(map(lambda x: x == values[0], values))

    # repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' +
        '    RandomChoice(transforms = [Compose(\n' +
        '    AddToValueaddend = 1.0' + '\n), Compose(\n' +
        '    AddToValueaddend = 2.0' + '\n)]prob = None)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')


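# RandomApply runs its sub-pipeline with probability `prob` and otherwise
# leaves the input results unchanged.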
def test_random_apply():

    # Case 1: simple use
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # with share_random_params=True, all broadcast values should be identical
    values = results['values']
    assert all(map(lambda x: x == values[0], values))

    # __iter__
    for _ in pipeline:
        pass

    # repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' +
        '    RandomApply(transforms = Compose(\n' +
        '    AddToValueaddend = 1' + '\n), prob = 0.5)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')


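# cache_randomness only accepts ordinary instance methods (not staticmethods,
# and the first argument must be `self`); avoid_cache_randomness marks a
# transform class as incompatible with cache_random_params, and this marking
# is not inherited by subclasses.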
def test_utils():
    # Test cache_randomness: normal case
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # Test cache_randomness: invalid function argument list
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: invalid mixture with cache_randomness
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: raise error in cache_random_params
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            def transform(self, results):
                return results

        transform = DummyTransform()
        with cache_random_params(transform):
            pass

    # Test avoid_cache_randomness: non-inheritable
    @avoid_cache_randomness
    class DummyBaseTransform(BaseTransform):

        def transform(self, results):
            return results

    class DummyTransform(DummyBaseTransform):
        pass

    transform = DummyTransform()
    with cache_random_params(transform):
        pass