[Feature] Add decorator avoid_cache_randomness (#1864)

* add prohibit_cache_randomness

* rename as avoid_cache_randomness and ensure it is non-inheritable

* fix lint

* update docs
This commit is contained in:
Yining Li 2022-04-27 23:02:46 +08:00 committed by zhouzaida
parent ea84b67449
commit f59aec8ffb
3 changed files with 157 additions and 21 deletions

View File

@ -229,7 +229,7 @@ pipeline = [
auto_remap=True, auto_remap=True,
# 是否在对各目标的变换中共享随机变量 # 是否在对各目标的变换中共享随机变量
# 更多介绍参加后续章节(随机变量共享) # 更多介绍参加后续章节(随机变量共享)
share_random_param=True, share_random_params=True,
transforms=[ transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可 # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'), dict(type='RandomFlip'),
@ -249,7 +249,7 @@ pipeline = [
# 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中 # 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
auto_remap=True, auto_remap=True,
# 是否在对各目标的变换中共享随机变量 # 是否在对各目标的变换中共享随机变量
share_random_param=True, share_random_params=True,
transforms=[ transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可 # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'), dict(type='RandomFlip'),
@ -257,7 +257,10 @@ pipeline = [
] ]
``` ```
`TransformBroadcaster` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
#### 装饰器 `cache_randomness`
`TransformBroadcaster` 中,我们提供了 `share_random_params` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,需要在类中标注哪些随机变量是支持共享的。这可以通过装饰器 `cache_randomness` 来实现。
以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转: 以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转:
@ -283,4 +286,29 @@ class MyRandomFlip(BaseTransform):
return results return results
``` ```
通过 `cache_randomness` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。 在上面的例子中,我们用`cache_randomness` 装饰 `do_flip`方法,即将该方法返回值 `flip` 标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
#### 装饰器 `avoid_cache_randomness`
在一些情况下,我们无法将数据变换中产生随机变量的过程单独放在类方法中。例如,数据变换中使用了来自第三方库的模块,这些模块将随机变量相关的部分封装在了内部,导致无法将其抽出为数据变换的类方法。这样的数据变换无法通过装饰器 `cache_randomness` 标注支持共享的随机变量,进而无法在多目标扩展时共享随机变量。
为了避免在多目标扩展中误用此类数据变换,我们提供了另一个装饰器 `avoid_cache_randomness`,用来对此类数据变换进行标记:
```python
from mmcv.transforms.utils import avoid_cache_randomness
@TRANSFORMS.register_module()
@avoid_cache_randomness
class MyRandomTransform(BaseTransform):
def transform(self, results: dict) -> dict:
...
```
`avoid_cache_randomness` 标记的数据变换类,当其实例被 `TransformBroadcaster` 包装且将参数 `share_random_params` 设置为 True 时,会抛出异常,以此提醒用户不能这样使用。
在使用 `avoid_cache_randomness` 时需要注意以下几点:
1. `avoid_cache_randomness` 只用于装饰数据变换类BaseTransfrom 的子类),而不能用与装饰其他一般的类、类方法或函数
2. 被 `avoid_cache_randomness` 修饰的数据变换作为基类时,其子类将**不会继承**这一特性。如果子类仍无法共享随机变量,则应再次使用 `avoid_cache_randomness` 修饰
3. 只有当一个数据变换具有随机性,且无法共享随机参数时,才需要以 `avoid_cache_randomness` 修饰。无随机性的数据变换不需要修饰

View File

@ -21,7 +21,7 @@ class cache_randomness:
return the cached values when being invoked again. return the cached values when being invoked again.
.. note:: .. note::
Only an instance method can be decorated by ``cache_randomness``. Only an instance method can be decorated with ``cache_randomness``.
""" """
def __init__(self, func): def __init__(self, func):
@ -44,7 +44,11 @@ class cache_randomness:
# Maintain a record of decorated methods in the class # Maintain a record of decorated methods in the class
if not hasattr(owner, '_methods_with_randomness'): if not hasattr(owner, '_methods_with_randomness'):
setattr(owner, '_methods_with_randomness', []) setattr(owner, '_methods_with_randomness', [])
owner._methods_with_randomness.append(self.__name__)
# Here `name` equals to `self.__name__`, i.e., the name of the
# decorated function, due to the invocation of `update_wrapper` in
# `self.__init__()`
owner._methods_with_randomness.append(name)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated # Get the transform instance whose method is decorated
@ -79,10 +83,55 @@ class cache_randomness:
return self return self
def avoid_cache_randomness(cls):
"""Decorator that marks a data transform class (subclass of
:class:`BaseTransform`) prohibited from caching randomness. With this
decorator, errors will be raised in following cases:
1. A method is defined in the class with the decorate
`cache_randomness`;
2. An instance of the class is invoked with the context
`cache_random_params`.
A typical usage of `avoid_cache_randomness` is to decorate the data
transforms with non-cacheable random behaviors (e.g., the random behavior
can not be defined in a method, thus can not be decorated with
`cache_randomness`). This is for preventing unintentinoal use of such data
transforms within the context of caching randomness, which may lead to
unexpected results.
"""
# Check that cls is a data transform class
assert issubclass(cls, BaseTransform)
# Check that no method is decorated with `cache_randomness` in cls
if getattr(cls, '_methods_with_randomness', None):
raise RuntimeError(
f'Class {cls.__name__} decorated with '
'``avoid_cache_randomness`` should not have methods decorated '
'with ``cache_randomness`` (invalid methods: '
f'{cls._methods_with_randomness})')
class AvoidCacheRandomness:
def __get__(self, obj, objtype=None):
# Here we check the value in `objtype.__dict__` instead of
# directly checking the attribute
# `objtype._avoid_cache_randomness`. So if the base class is
# decorated with :func:`avoid_cache_randomness`, it will not be
# inherited by subclasses.
return objtype.__dict__.get('_avoid_cache_randomness', False)
cls.avoid_cache_randomness = AvoidCacheRandomness()
cls._avoid_cache_randomness = True
return cls
@contextmanager @contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]): def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of return values of methods """Context-manager that enables the cache of return values of methods
decorated by ``cache_randomness`` in transforms. decorated with ``cache_randomness`` in transforms.
In this mode, decorated methods will cache their return values on the In this mode, decorated methods will cache their return values on the
first invoking, and always return the cached value afterward. This allow first invoking, and always return the cached value afterward. This allow
@ -136,15 +185,27 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key = f'{id(obj)}.{name}' key = f'{id(obj)}.{name}'
if key2counter[key] > 1: if key2counter[key] > 1:
raise RuntimeError( raise RuntimeError(
'The method decorated by ``cache_randomness`` should ' 'The method decorated with ``cache_randomness`` '
'be invoked at most once during processing one data ' 'should be invoked at most once during processing '
f'sample. The method {name} of {obj} has been invoked' f'one data sample. The method {name} of {obj} has '
f' {key2counter[key]} times.') f'been invoked {key2counter[key]} times.')
return output return output
return wrapped return wrapped
def _start_cache(t: BaseTransform): def _start_cache(t: BaseTransform):
# Check if cache is allowed for `t`
if getattr(t, 'avoid_cache_randomness', False):
raise RuntimeError(
f'Class {t.__class__.__name__} decorated with '
'``avoid_cache_randomness`` is not allowed to be used with'
' ``cache_random_params`` (e.g. wrapped by '
'``ApplyToMultiple`` with ``share_random_params==True``).')
# Skip transforms w/o random method
if not hasattr(t, '_methods_with_randomness'):
return
# Set cache enabled flag # Set cache enabled flag
setattr(t, '_cache_enabled', True) setattr(t, '_cache_enabled', True)
@ -155,6 +216,10 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
setattr(t, name, _add_invoke_counter(t, name)) setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform): def _end_cache(t: BaseTransform):
# Skip transforms w/o random method
if not hasattr(t, '_methods_with_randomness'):
return
# Remove cache enabled flag # Remove cache enabled flag
del t._cache_enabled del t._cache_enabled
if hasattr(t, '_cache'): if hasattr(t, '_cache'):
@ -169,12 +234,8 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key_transform = f'{id(t)}.transform' key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform]) setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable], def _apply(t: BaseTransform, func: Callable[[BaseTransform], None]):
func: Callable[[BaseTransform], None]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
if isinstance(t, BaseTransform): if isinstance(t, BaseTransform):
if hasattr(t, '_methods_with_randomness'):
func(t) func(t)
if isinstance(t, Iterable): if isinstance(t, Iterable):
for _t in t: for _t in t:

View File

@ -6,7 +6,8 @@ import pytest
from mmcv.transforms.base import BaseTransform from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cache_randomness from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply, from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
RandomChoice, TransformBroadcaster) RandomChoice, TransformBroadcaster)
@ -274,7 +275,7 @@ def test_key_mapper():
_ = str(pipeline) _ = str(pipeline)
def test_apply_to_multiple(): def test_transform_broadcaster():
# Case 1: apply to list in results # Case 1: apply to list in results
pipeline = TransformBroadcaster( pipeline = TransformBroadcaster(
@ -421,18 +422,64 @@ def test_utils():
# Test cache_randomness: invalid function type # Test cache_randomness: invalid function type
with pytest.raises(TypeError): with pytest.raises(TypeError):
class DummyTransform(): class DummyTransform(BaseTransform):
@cache_randomness @cache_randomness
@staticmethod @staticmethod
def func(): def func():
return np.random.rand() return np.random.rand()
def transform(self, results):
return results
# Test cache_randomness: invalid function argument list # Test cache_randomness: invalid function argument list
with pytest.raises(TypeError): with pytest.raises(TypeError):
class DummyTransform(): class DummyTransform(BaseTransform):
@cache_randomness @cache_randomness
def func(cls): def func(cls):
return np.random.rand() return np.random.rand()
def transform(self, results):
return results
# Test avoid_cache_randomness: invalid mixture with cache_randomness
with pytest.raises(RuntimeError):
@avoid_cache_randomness
class DummyTransform(BaseTransform):
@cache_randomness
def func(self):
pass
def transform(self, results):
return results
# Test avoid_cache_randomness: raise error in cache_random_params
with pytest.raises(RuntimeError):
@avoid_cache_randomness
class DummyTransform(BaseTransform):
def transform(self, results):
return results
transform = DummyTransform()
with cache_random_params(transform):
pass
# Test avoid_cache_randomness: non-inheritable
@avoid_cache_randomness
class DummyBaseTransform(BaseTransform):
def transform(self, results):
return results
class DummyTransform(DummyBaseTransform):
pass
transform = DummyTransform()
with cache_random_params(transform):
pass