mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add decorator `avoid_cache_randomness` (#1864)
* add prohibit_cache_randomness * rename as avoid_cache_randomness and ensure it is non-inheritable * fix lint * update docspull/2133/head
parent
ea84b67449
commit
f59aec8ffb
|
@ -229,7 +229,7 @@ pipeline = [
|
|||
auto_remap=True,
|
||||
# 是否在对各目标的变换中共享随机变量
|
||||
# 更多介绍参加后续章节(随机变量共享)
|
||||
share_random_param=True,
|
||||
share_random_params=True,
|
||||
transforms=[
|
||||
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
|
||||
dict(type='RandomFlip'),
|
||||
|
@ -249,7 +249,7 @@ pipeline = [
|
|||
# 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
|
||||
auto_remap=True,
|
||||
# 是否在对各目标的变换中共享随机变量
|
||||
share_random_param=True,
|
||||
share_random_params=True,
|
||||
transforms=[
|
||||
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
|
||||
dict(type='RandomFlip'),
|
||||
|
@ -257,7 +257,10 @@ pipeline = [
|
|||
]
|
||||
```
|
||||
|
||||
在 `TransformBroadcaster` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
|
||||
|
||||
#### 装饰器 `cache_randomness`
|
||||
|
||||
在 `TransformBroadcaster` 中,我们提供了 `share_random_params` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,需要在类中标注哪些随机变量是支持共享的。这可以通过装饰器 `cache_randomness` 来实现。
|
||||
|
||||
以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转:
|
||||
|
||||
|
@ -283,4 +286,29 @@ class MyRandomFlip(BaseTransform):
|
|||
return results
|
||||
```
|
||||
|
||||
通过 `cache_randomness` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
|
||||
在上面的例子中,我们用`cache_randomness` 装饰 `do_flip`方法,即将该方法返回值 `flip` 标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
|
||||
|
||||
#### 装饰器 `avoid_cache_randomness`
|
||||
|
||||
在一些情况下,我们无法将数据变换中产生随机变量的过程单独放在类方法中。例如,数据变换中使用了来自第三方库的模块,这些模块将随机变量相关的部分封装在了内部,导致无法将其抽出为数据变换的类方法。这样的数据变换无法通过装饰器 `cache_randomness` 标注支持共享的随机变量,进而无法在多目标扩展时共享随机变量。
|
||||
|
||||
为了避免在多目标扩展中误用此类数据变换,我们提供了另一个装饰器 `avoid_cache_randomness`,用来对此类数据变换进行标记:
|
||||
|
||||
```python
|
||||
from mmcv.transforms.utils import avoid_cache_randomness
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
@avoid_cache_randomness
|
||||
class MyRandomTransform(BaseTransform):
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
...
|
||||
```
|
||||
|
||||
用 `avoid_cache_randomness` 标记的数据变换类,当其实例被 `TransformBroadcaster` 包装且将参数 `share_random_params` 设置为 True 时,会抛出异常,以此提醒用户不能这样使用。
|
||||
|
||||
在使用 `avoid_cache_randomness` 时需要注意以下几点:
|
||||
|
||||
1. `avoid_cache_randomness` 只用于装饰数据变换类(BaseTransfrom 的子类),而不能用与装饰其他一般的类、类方法或函数
|
||||
2. 被 `avoid_cache_randomness` 修饰的数据变换作为基类时,其子类将**不会继承**这一特性。如果子类仍无法共享随机变量,则应再次使用 `avoid_cache_randomness` 修饰
|
||||
3. 只有当一个数据变换具有随机性,且无法共享随机参数时,才需要以 `avoid_cache_randomness` 修饰。无随机性的数据变换不需要修饰
|
||||
|
|
|
@ -21,7 +21,7 @@ class cache_randomness:
|
|||
return the cached values when being invoked again.
|
||||
|
||||
.. note::
|
||||
Only an instance method can be decorated by ``cache_randomness``.
|
||||
Only an instance method can be decorated with ``cache_randomness``.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
|
@ -44,7 +44,11 @@ class cache_randomness:
|
|||
# Maintain a record of decorated methods in the class
|
||||
if not hasattr(owner, '_methods_with_randomness'):
|
||||
setattr(owner, '_methods_with_randomness', [])
|
||||
owner._methods_with_randomness.append(self.__name__)
|
||||
|
||||
# Here `name` equals to `self.__name__`, i.e., the name of the
|
||||
# decorated function, due to the invocation of `update_wrapper` in
|
||||
# `self.__init__()`
|
||||
owner._methods_with_randomness.append(name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Get the transform instance whose method is decorated
|
||||
|
@ -79,10 +83,55 @@ class cache_randomness:
|
|||
return self
|
||||
|
||||
|
||||
def avoid_cache_randomness(cls):
|
||||
"""Decorator that marks a data transform class (subclass of
|
||||
:class:`BaseTransform`) prohibited from caching randomness. With this
|
||||
decorator, errors will be raised in following cases:
|
||||
|
||||
1. A method is defined in the class with the decorate
|
||||
`cache_randomness`;
|
||||
2. An instance of the class is invoked with the context
|
||||
`cache_random_params`.
|
||||
|
||||
A typical usage of `avoid_cache_randomness` is to decorate the data
|
||||
transforms with non-cacheable random behaviors (e.g., the random behavior
|
||||
can not be defined in a method, thus can not be decorated with
|
||||
`cache_randomness`). This is for preventing unintentinoal use of such data
|
||||
transforms within the context of caching randomness, which may lead to
|
||||
unexpected results.
|
||||
"""
|
||||
|
||||
# Check that cls is a data transform class
|
||||
assert issubclass(cls, BaseTransform)
|
||||
|
||||
# Check that no method is decorated with `cache_randomness` in cls
|
||||
if getattr(cls, '_methods_with_randomness', None):
|
||||
raise RuntimeError(
|
||||
f'Class {cls.__name__} decorated with '
|
||||
'``avoid_cache_randomness`` should not have methods decorated '
|
||||
'with ``cache_randomness`` (invalid methods: '
|
||||
f'{cls._methods_with_randomness})')
|
||||
|
||||
class AvoidCacheRandomness:
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
# Here we check the value in `objtype.__dict__` instead of
|
||||
# directly checking the attribute
|
||||
# `objtype._avoid_cache_randomness`. So if the base class is
|
||||
# decorated with :func:`avoid_cache_randomness`, it will not be
|
||||
# inherited by subclasses.
|
||||
return objtype.__dict__.get('_avoid_cache_randomness', False)
|
||||
|
||||
cls.avoid_cache_randomness = AvoidCacheRandomness()
|
||||
cls._avoid_cache_randomness = True
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
|
||||
"""Context-manager that enables the cache of return values of methods
|
||||
decorated by ``cache_randomness`` in transforms.
|
||||
decorated with ``cache_randomness`` in transforms.
|
||||
|
||||
In this mode, decorated methods will cache their return values on the
|
||||
first invoking, and always return the cached value afterward. This allow
|
||||
|
@ -136,15 +185,27 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
|
|||
key = f'{id(obj)}.{name}'
|
||||
if key2counter[key] > 1:
|
||||
raise RuntimeError(
|
||||
'The method decorated by ``cache_randomness`` should '
|
||||
'be invoked at most once during processing one data '
|
||||
f'sample. The method {name} of {obj} has been invoked'
|
||||
f' {key2counter[key]} times.')
|
||||
'The method decorated with ``cache_randomness`` '
|
||||
'should be invoked at most once during processing '
|
||||
f'one data sample. The method {name} of {obj} has '
|
||||
f'been invoked {key2counter[key]} times.')
|
||||
return output
|
||||
|
||||
return wrapped
|
||||
|
||||
def _start_cache(t: BaseTransform):
|
||||
# Check if cache is allowed for `t`
|
||||
if getattr(t, 'avoid_cache_randomness', False):
|
||||
raise RuntimeError(
|
||||
f'Class {t.__class__.__name__} decorated with '
|
||||
'``avoid_cache_randomness`` is not allowed to be used with'
|
||||
' ``cache_random_params`` (e.g. wrapped by '
|
||||
'``ApplyToMultiple`` with ``share_random_params==True``).')
|
||||
|
||||
# Skip transforms w/o random method
|
||||
if not hasattr(t, '_methods_with_randomness'):
|
||||
return
|
||||
|
||||
# Set cache enabled flag
|
||||
setattr(t, '_cache_enabled', True)
|
||||
|
||||
|
@ -155,6 +216,10 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
|
|||
setattr(t, name, _add_invoke_counter(t, name))
|
||||
|
||||
def _end_cache(t: BaseTransform):
|
||||
# Skip transforms w/o random method
|
||||
if not hasattr(t, '_methods_with_randomness'):
|
||||
return
|
||||
|
||||
# Remove cache enabled flag
|
||||
del t._cache_enabled
|
||||
if hasattr(t, '_cache'):
|
||||
|
@ -169,13 +234,9 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
|
|||
key_transform = f'{id(t)}.transform'
|
||||
setattr(t, 'transform', key2method[key_transform])
|
||||
|
||||
def _apply(t: Union[BaseTransform, Iterable],
|
||||
func: Callable[[BaseTransform], None]):
|
||||
# Note that BaseTransform and Iterable are not mutually exclusive,
|
||||
# e.g. Compose, RandomChoice
|
||||
def _apply(t: BaseTransform, func: Callable[[BaseTransform], None]):
|
||||
if isinstance(t, BaseTransform):
|
||||
if hasattr(t, '_methods_with_randomness'):
|
||||
func(t)
|
||||
func(t)
|
||||
if isinstance(t, Iterable):
|
||||
for _t in t:
|
||||
_apply(_t, func)
|
||||
|
|
|
@ -6,7 +6,8 @@ import pytest
|
|||
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.builder import TRANSFORMS
|
||||
from mmcv.transforms.utils import cache_random_params, cache_randomness
|
||||
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
|
||||
cache_randomness)
|
||||
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
|
||||
RandomChoice, TransformBroadcaster)
|
||||
|
||||
|
@ -274,7 +275,7 @@ def test_key_mapper():
|
|||
_ = str(pipeline)
|
||||
|
||||
|
||||
def test_apply_to_multiple():
|
||||
def test_transform_broadcaster():
|
||||
|
||||
# Case 1: apply to list in results
|
||||
pipeline = TransformBroadcaster(
|
||||
|
@ -421,18 +422,64 @@ def test_utils():
|
|||
# Test cache_randomness: invalid function type
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class DummyTransform():
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
@cache_randomness
|
||||
@staticmethod
|
||||
def func():
|
||||
return np.random.rand()
|
||||
|
||||
def transform(self, results):
|
||||
return results
|
||||
|
||||
# Test cache_randomness: invalid function argument list
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class DummyTransform():
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
@cache_randomness
|
||||
def func(cls):
|
||||
return np.random.rand()
|
||||
|
||||
def transform(self, results):
|
||||
return results
|
||||
|
||||
# Test avoid_cache_randomness: invalid mixture with cache_randomness
|
||||
with pytest.raises(RuntimeError):
|
||||
|
||||
@avoid_cache_randomness
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
@cache_randomness
|
||||
def func(self):
|
||||
pass
|
||||
|
||||
def transform(self, results):
|
||||
return results
|
||||
|
||||
# Test avoid_cache_randomness: raise error in cache_random_params
|
||||
with pytest.raises(RuntimeError):
|
||||
|
||||
@avoid_cache_randomness
|
||||
class DummyTransform(BaseTransform):
|
||||
|
||||
def transform(self, results):
|
||||
return results
|
||||
|
||||
transform = DummyTransform()
|
||||
with cache_random_params(transform):
|
||||
pass
|
||||
|
||||
# Test avoid_cache_randomness: non-inheritable
|
||||
@avoid_cache_randomness
|
||||
class DummyBaseTransform(BaseTransform):
|
||||
|
||||
def transform(self, results):
|
||||
return results
|
||||
|
||||
class DummyTransform(DummyBaseTransform):
|
||||
pass
|
||||
|
||||
transform = DummyTransform()
|
||||
with cache_random_params(transform):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue