[Feature] Add decorator `avoid_cache_randomness` (#1864)

* add prohibit_cache_randomness

* rename as avoid_cache_randomness and ensure it is non-inheritable

* fix lint

* update docs
pull/2133/head
Yining Li 2022-04-27 23:02:46 +08:00 committed by zhouzaida
parent ea84b67449
commit f59aec8ffb
3 changed files with 157 additions and 21 deletions

View File

@ -229,7 +229,7 @@ pipeline = [
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参加后续章节(随机变量共享)
share_random_param=True,
share_random_params=True,
transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'),
@ -249,7 +249,7 @@ pipeline = [
# 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
share_random_param=True,
share_random_params=True,
transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'),
@ -257,7 +257,10 @@ pipeline = [
]
```
`TransformBroadcaster` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
#### 装饰器 `cache_randomness`
`TransformBroadcaster` 中,我们提供了 `share_random_params` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,需要在类中标注哪些随机变量是支持共享的。这可以通过装饰器 `cache_randomness` 来实现。
以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转:
@ -283,4 +286,29 @@ class MyRandomFlip(BaseTransform):
return results
```
通过 `cache_randomness` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
在上面的例子中,我们用`cache_randomness` 装饰 `do_flip`方法,即将该方法返回值 `flip` 标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
#### 装饰器 `avoid_cache_randomness`
在一些情况下,我们无法将数据变换中产生随机变量的过程单独放在类方法中。例如,数据变换中使用了来自第三方库的模块,这些模块将随机变量相关的部分封装在了内部,导致无法将其抽出为数据变换的类方法。这样的数据变换无法通过装饰器 `cache_randomness` 标注支持共享的随机变量,进而无法在多目标扩展时共享随机变量。
为了避免在多目标扩展中误用此类数据变换,我们提供了另一个装饰器 `avoid_cache_randomness`,用来对此类数据变换进行标记:
```python
from mmcv.transforms.utils import avoid_cache_randomness
@TRANSFORMS.register_module()
@avoid_cache_randomness
class MyRandomTransform(BaseTransform):
def transform(self, results: dict) -> dict:
...
```
`avoid_cache_randomness` 标记的数据变换类,当其实例被 `TransformBroadcaster` 包装且将参数 `share_random_params` 设置为 True 时,会抛出异常,以此提醒用户不能这样使用。
在使用 `avoid_cache_randomness` 时需要注意以下几点:
1. `avoid_cache_randomness` 只用于装饰数据变换类BaseTransfrom 的子类),而不能用与装饰其他一般的类、类方法或函数
2. 被 `avoid_cache_randomness` 修饰的数据变换作为基类时,其子类将**不会继承**这一特性。如果子类仍无法共享随机变量,则应再次使用 `avoid_cache_randomness` 修饰
3. 只有当一个数据变换具有随机性,且无法共享随机参数时,才需要以 `avoid_cache_randomness` 修饰。无随机性的数据变换不需要修饰

View File

@ -21,7 +21,7 @@ class cache_randomness:
return the cached values when being invoked again.
.. note::
Only an instance method can be decorated by ``cache_randomness``.
Only an instance method can be decorated with ``cache_randomness``.
"""
def __init__(self, func):
@ -44,7 +44,11 @@ class cache_randomness:
# Maintain a record of decorated methods in the class
if not hasattr(owner, '_methods_with_randomness'):
setattr(owner, '_methods_with_randomness', [])
owner._methods_with_randomness.append(self.__name__)
# Here `name` equals to `self.__name__`, i.e., the name of the
# decorated function, due to the invocation of `update_wrapper` in
# `self.__init__()`
owner._methods_with_randomness.append(name)
def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated
@ -79,10 +83,55 @@ class cache_randomness:
return self
def avoid_cache_randomness(cls):
"""Decorator that marks a data transform class (subclass of
:class:`BaseTransform`) prohibited from caching randomness. With this
decorator, errors will be raised in following cases:
1. A method is defined in the class with the decorate
`cache_randomness`;
2. An instance of the class is invoked with the context
`cache_random_params`.
A typical usage of `avoid_cache_randomness` is to decorate the data
transforms with non-cacheable random behaviors (e.g., the random behavior
can not be defined in a method, thus can not be decorated with
`cache_randomness`). This is for preventing unintentinoal use of such data
transforms within the context of caching randomness, which may lead to
unexpected results.
"""
# Check that cls is a data transform class
assert issubclass(cls, BaseTransform)
# Check that no method is decorated with `cache_randomness` in cls
if getattr(cls, '_methods_with_randomness', None):
raise RuntimeError(
f'Class {cls.__name__} decorated with '
'``avoid_cache_randomness`` should not have methods decorated '
'with ``cache_randomness`` (invalid methods: '
f'{cls._methods_with_randomness})')
class AvoidCacheRandomness:
def __get__(self, obj, objtype=None):
# Here we check the value in `objtype.__dict__` instead of
# directly checking the attribute
# `objtype._avoid_cache_randomness`. So if the base class is
# decorated with :func:`avoid_cache_randomness`, it will not be
# inherited by subclasses.
return objtype.__dict__.get('_avoid_cache_randomness', False)
cls.avoid_cache_randomness = AvoidCacheRandomness()
cls._avoid_cache_randomness = True
return cls
@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of return values of methods
decorated by ``cache_randomness`` in transforms.
decorated with ``cache_randomness`` in transforms.
In this mode, decorated methods will cache their return values on the
first invoking, and always return the cached value afterward. This allow
@ -136,15 +185,27 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key = f'{id(obj)}.{name}'
if key2counter[key] > 1:
raise RuntimeError(
'The method decorated by ``cache_randomness`` should '
'be invoked at most once during processing one data '
f'sample. The method {name} of {obj} has been invoked'
f' {key2counter[key]} times.')
'The method decorated with ``cache_randomness`` '
'should be invoked at most once during processing '
f'one data sample. The method {name} of {obj} has '
f'been invoked {key2counter[key]} times.')
return output
return wrapped
def _start_cache(t: BaseTransform):
# Check if cache is allowed for `t`
if getattr(t, 'avoid_cache_randomness', False):
raise RuntimeError(
f'Class {t.__class__.__name__} decorated with '
'``avoid_cache_randomness`` is not allowed to be used with'
' ``cache_random_params`` (e.g. wrapped by '
'``ApplyToMultiple`` with ``share_random_params==True``).')
# Skip transforms w/o random method
if not hasattr(t, '_methods_with_randomness'):
return
# Set cache enabled flag
setattr(t, '_cache_enabled', True)
@ -155,6 +216,10 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform):
# Skip transforms w/o random method
if not hasattr(t, '_methods_with_randomness'):
return
# Remove cache enabled flag
del t._cache_enabled
if hasattr(t, '_cache'):
@ -169,13 +234,9 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
def _apply(t: BaseTransform, func: Callable[[BaseTransform], None]):
if isinstance(t, BaseTransform):
if hasattr(t, '_methods_with_randomness'):
func(t)
func(t)
if isinstance(t, Iterable):
for _t in t:
_apply(_t, func)

View File

@ -6,7 +6,8 @@ import pytest
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
RandomChoice, TransformBroadcaster)
@ -274,7 +275,7 @@ def test_key_mapper():
_ = str(pipeline)
def test_apply_to_multiple():
def test_transform_broadcaster():
# Case 1: apply to list in results
pipeline = TransformBroadcaster(
@ -421,18 +422,64 @@ def test_utils():
# Test cache_randomness: invalid function type
with pytest.raises(TypeError):
class DummyTransform():
class DummyTransform(BaseTransform):
@cache_randomness
@staticmethod
def func():
return np.random.rand()
def transform(self, results):
return results
# Test cache_randomness: invalid function argument list
with pytest.raises(TypeError):
class DummyTransform():
class DummyTransform(BaseTransform):
@cache_randomness
def func(cls):
return np.random.rand()
def transform(self, results):
return results
# Test avoid_cache_randomness: invalid mixture with cache_randomness
with pytest.raises(RuntimeError):
@avoid_cache_randomness
class DummyTransform(BaseTransform):
@cache_randomness
def func(self):
pass
def transform(self, results):
return results
# Test avoid_cache_randomness: raise error in cache_random_params
with pytest.raises(RuntimeError):
@avoid_cache_randomness
class DummyTransform(BaseTransform):
def transform(self, results):
return results
transform = DummyTransform()
with cache_random_params(transform):
pass
# Test avoid_cache_randomness: non-inheritable
@avoid_cache_randomness
class DummyBaseTransform(BaseTransform):
def transform(self, results):
return results
class DummyTransform(DummyBaseTransform):
pass
transform = DummyTransform()
with cache_random_params(transform):
pass