mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Support registering function (#1858)
* [Enhance] Support register function. * fix unittest error * add docs and unittest of register function * modify the docs * fix version to 1.5.1 * Update docs/zh_cn/understand_mmcv/registry.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/understand_mmcv/registry.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * refine the docs * modify module_class to module Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2041/head
parent
9f5a03dc2c
commit
a3a078eefb
|
@ -3,11 +3,15 @@
|
||||||
MMCV implements [registry](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) to manage different modules that share similar functionalities, e.g., backbones, head, and necks, in detectors.
|
MMCV implements [registry](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) to manage different modules that share similar functionalities, e.g., backbones, head, and necks, in detectors.
|
||||||
Most projects in OpenMMLab use registry to manage modules of datasets and models, such as [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting), etc.
|
Most projects in OpenMMLab use registry to manage modules of datasets and models, such as [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting), etc.
|
||||||
|
|
||||||
|
```{note}
|
||||||
|
In v1.5.1 and later, the Registry supports registering functions and calling them.
|
||||||
|
```
|
||||||
|
|
||||||
### What is registry
|
### What is registry
|
||||||
|
|
||||||
In MMCV, registry can be regarded as a mapping that maps a class to a string.
|
In MMCV, registry can be regarded as a mapping that maps a class or function to a string.
|
||||||
These classes contained by a single registry usually have similar APIs but implement different algorithms or support different datasets.
|
These classes or functions contained by a single registry usually have similar APIs but implement different algorithms or support different datasets.
|
||||||
With the registry, users can find and instantiate the class through its corresponding string, and use the instantiated module as they want.
|
With the registry, users can find the class or function through its corresponding string, and instantiate the corresponding module or call the function to obtain the result according to needs.
|
||||||
One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs.
|
One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs.
|
||||||
The API reference could be found [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).
|
The API reference could be found [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).
|
||||||
|
|
||||||
|
@ -17,7 +21,7 @@ To manage your modules in the codebase by `Registry`, there are three steps as b
|
||||||
2. Create a registry.
|
2. Create a registry.
|
||||||
3. Use this registry to manage the modules.
|
3. Use this registry to manage the modules.
|
||||||
|
|
||||||
`build_func` argument of `Registry` is to customize how to instantiate the class instance, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).
|
`build_func` argument of `Registry` is to customize how to instantiate the class instance or how to call the function to obtain the result, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).
|
||||||
|
|
||||||
### A Simple Example
|
### A Simple Example
|
||||||
|
|
||||||
|
@ -34,7 +38,7 @@ from mmcv.utils import Registry
|
||||||
CONVERTERS = Registry('converters')
|
CONVERTERS = Registry('converters')
|
||||||
```
|
```
|
||||||
|
|
||||||
Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`
|
Then we can implement different converters that is class or function in the package. For example, implement `Converter1` in `converters/converter1.py`, and `converter2` in `converters/converter2.py`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
|
@ -47,11 +51,22 @@ class Converter1(object):
|
||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
```
|
```
|
||||||
|
```python
|
||||||
|
# converter2.py
|
||||||
|
from .builder import CONVERTERS
|
||||||
|
from .converter1 import Converter1
|
||||||
|
|
||||||
|
# 使用注册器管理模块
|
||||||
|
@CONVERTERS.register_module()
|
||||||
|
def converter2(a, b)
|
||||||
|
return Converter1(a, b)
|
||||||
|
```
|
||||||
The key step to use registry for managing the modules is to register the implemented module into the registry `CONVERTERS` through
|
The key step to use registry for managing the modules is to register the implemented module into the registry `CONVERTERS` through
|
||||||
`@CONVERTERS.register_module()` when you are creating the module. By this way, a mapping between a string and the class is built and maintained by `CONVERTERS` as below
|
`@CONVERTERS.register_module()` when you are creating the module. By this way, a mapping between a string and the class (function) is built and maintained by `CONVERTERS` as below
|
||||||
|
|
||||||
```python
|
```python
|
||||||
'Converter1' -> <class 'Converter1'>
|
'Converter1' -> <class 'Converter1'>
|
||||||
|
'converter2' -> <function 'converter2'>
|
||||||
```
|
```
|
||||||
```{note}
|
```{note}
|
||||||
The registry mechanism will be triggered only when the file where the module is located is imported.
|
The registry mechanism will be triggered only when the file where the module is located is imported.
|
||||||
|
@ -61,8 +76,11 @@ So you need to import that file somewhere. More details can be found at https://
|
||||||
If the module is successfully registered, you can use this converter through configs as
|
If the module is successfully registered, you can use this converter through configs as
|
||||||
|
|
||||||
```python
|
```python
|
||||||
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
converter1_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
||||||
converter = CONVERTERS.build(converter_cfg)
|
converter2_cfg = dict(type='converter2', a=a_value, b=b_value)
|
||||||
|
converter1 = CONVERTERS.build(converter1_cfg)
|
||||||
|
# returns the calling result
|
||||||
|
result = CONVERTERS.build(converter2_cfg)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Customize Build Function
|
### Customize Build Function
|
||||||
|
|
|
@ -2,10 +2,14 @@
|
||||||
MMCV 使用 [注册器](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) 来管理具有相似功能的不同模块, 例如, 检测器中的主干网络、头部、和模型颈部。
|
MMCV 使用 [注册器](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) 来管理具有相似功能的不同模块, 例如, 检测器中的主干网络、头部、和模型颈部。
|
||||||
在 OpenMMLab 家族中的绝大部分开源项目使用注册器去管理数据集和模型的模块,例如 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting) 等。
|
在 OpenMMLab 家族中的绝大部分开源项目使用注册器去管理数据集和模型的模块,例如 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting) 等。
|
||||||
|
|
||||||
|
```{note}
|
||||||
|
在 v1.5.1 版本开始支持注册函数的功能。
|
||||||
|
```
|
||||||
|
|
||||||
### 什么是注册器
|
### 什么是注册器
|
||||||
在MMCV中,注册器可以看作类到字符串的映射。
|
在MMCV中,注册器可以看作类或函数到字符串的映射。
|
||||||
一个注册器中的类通常有相似的接口,但是可以实现不同的算法或支持不同的数据集。
|
一个注册器中的类或函数通常有相似的接口,但是可以实现不同的算法或支持不同的数据集。
|
||||||
借助注册器,用户可以通过使用相应的字符串查找并实例化该类,并根据他们的需要实例化对应模块。
|
借助注册器,用户可以通过使用相应的字符串查找类或函数,并根据他们的需要实例化对应模块或调用函数获取结果。
|
||||||
一个典型的案例是,OpenMMLab 中的大部分开源项目的配置系统,这些系统通过配置文件来使用注册器创建钩子、执行器、模型和数据集。
|
一个典型的案例是,OpenMMLab 中的大部分开源项目的配置系统,这些系统通过配置文件来使用注册器创建钩子、执行器、模型和数据集。
|
||||||
可以在[这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry)找到注册器接口使用文档。
|
可以在[这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry)找到注册器接口使用文档。
|
||||||
|
|
||||||
|
@ -15,7 +19,7 @@ MMCV 使用 [注册器](https://github.com/open-mmlab/mmcv/blob/master/mmcv/util
|
||||||
2. 创建注册器
|
2. 创建注册器
|
||||||
3. 使用此注册器来管理模块
|
3. 使用此注册器来管理模块
|
||||||
|
|
||||||
`Registry`(注册器)的参数 `build_func`(构建函数) 用来自定以如何实例化类的实例,默认使用 [这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg)实现的`build_from_cfg`。
|
`Registry`(注册器)的参数 `build_func`(构建函数) 用来自定义如何实例化类的实例或如何调用函数获取结果,默认使用 [这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg) 实现的`build_from_cfg`。
|
||||||
|
|
||||||
### 一个简单的例子
|
### 一个简单的例子
|
||||||
|
|
||||||
|
@ -29,9 +33,10 @@ from mmcv.utils import Registry
|
||||||
CONVERTERS = Registry('converter')
|
CONVERTERS = Registry('converter')
|
||||||
```
|
```
|
||||||
|
|
||||||
然后我们在包中可以实现不同的转换器(converter)。例如,在 `converters/converter1.py` 中实现 `Converter1`。
|
然后我们在包中可以实现不同的转换器(converter),其可以为类或函数。例如,在 `converters/converter1.py` 中实现 `Converter1`,在 `converters/converter2.py` 中实现 `converter2`。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
# converter1.py
|
||||||
from .builder import CONVERTERS
|
from .builder import CONVERTERS
|
||||||
|
|
||||||
# 使用注册器管理模块
|
# 使用注册器管理模块
|
||||||
|
@ -41,12 +46,23 @@ class Converter1(object):
|
||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
```
|
```
|
||||||
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,如下所示:
|
```python
|
||||||
|
# converter2.py
|
||||||
|
from .builder import CONVERTERS
|
||||||
|
from .converter1 import Converter1
|
||||||
|
|
||||||
通过这种方式,就可以通过 `CONVERTERS` 建立字符串与类之间的映射,如下所示:
|
# 使用注册器管理模块
|
||||||
|
@CONVERTERS.register_module()
|
||||||
|
def converter2(a, b)
|
||||||
|
return Converter1(a, b)
|
||||||
|
```
|
||||||
|
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串到类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,如下所示:
|
||||||
|
|
||||||
|
通过这种方式,就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,如下所示:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
'Converter1' -> <class 'Converter1'>
|
'Converter1' -> <class 'Converter1'>
|
||||||
|
'converter2' -> <function 'converter2'>
|
||||||
```
|
```
|
||||||
```{note}
|
```{note}
|
||||||
只有模块所在的文件被导入时,注册机制才会被触发,所以您需要在某处导入该文件。更多详情请查看 https://github.com/open-mmlab/mmdetection/issues/5974。
|
只有模块所在的文件被导入时,注册机制才会被触发,所以您需要在某处导入该文件。更多详情请查看 https://github.com/open-mmlab/mmdetection/issues/5974。
|
||||||
|
@ -54,8 +70,11 @@ class Converter1(object):
|
||||||
如果模块被成功注册了,你可以通过配置文件使用这个转换器(converter),如下所示:
|
如果模块被成功注册了,你可以通过配置文件使用这个转换器(converter),如下所示:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
converter1_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
||||||
converter = CONVERTERS.build(converter_cfg)
|
converter2_cfg = dict(type='converter2', a=a_value, b=b_value)
|
||||||
|
converter1 = CONVERTERS.build(converter1_cfg)
|
||||||
|
# returns the calling result
|
||||||
|
result = CONVERTERS.build(converter2_cfg)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 自定义构建函数
|
### 自定义构建函数
|
||||||
|
|
|
@ -3,11 +3,25 @@ import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from .misc import is_seq_of
|
from .misc import deprecated_api_warning, is_seq_of
|
||||||
|
|
||||||
|
|
||||||
def build_from_cfg(cfg, registry, default_args=None):
|
def build_from_cfg(cfg, registry, default_args=None):
|
||||||
"""Build a module from config dict.
|
"""Build a module from config dict when it is a class configuration, or
|
||||||
|
call a function from config dict when it is a function configuration.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
|
||||||
|
>>> # Returns an instantiated object
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> def resnet50():
|
||||||
|
>>> pass
|
||||||
|
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
|
||||||
|
>>> # Return a result of the calling function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (dict): Config dict. It should at least contain the key "type".
|
cfg (dict): Config dict. It should at least contain the key "type".
|
||||||
|
@ -43,7 +57,7 @@ def build_from_cfg(cfg, registry, default_args=None):
|
||||||
if obj_cls is None:
|
if obj_cls is None:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
f'{obj_type} is not in the {registry.name} registry')
|
f'{obj_type} is not in the {registry.name} registry')
|
||||||
elif inspect.isclass(obj_type):
|
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
||||||
obj_cls = obj_type
|
obj_cls = obj_type
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
@ -56,9 +70,10 @@ def build_from_cfg(cfg, registry, default_args=None):
|
||||||
|
|
||||||
|
|
||||||
class Registry:
|
class Registry:
|
||||||
"""A registry to map strings to classes.
|
"""A registry to map strings to classes or functions.
|
||||||
|
|
||||||
Registered object could be built from registry.
|
Registered object could be built from registry. Meanwhile, registered
|
||||||
|
functions could be called from registry.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> MODELS = Registry('models')
|
>>> MODELS = Registry('models')
|
||||||
|
@ -66,6 +81,10 @@ class Registry:
|
||||||
>>> class ResNet:
|
>>> class ResNet:
|
||||||
>>> pass
|
>>> pass
|
||||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> def resnet50():
|
||||||
|
>>> pass
|
||||||
|
>>> resnet = MODELS.build(dict(type='resnet50'))
|
||||||
|
|
||||||
Please refer to
|
Please refer to
|
||||||
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
||||||
|
@ -235,20 +254,21 @@ class Registry:
|
||||||
f'scope {registry.scope} exists in {self.name} registry'
|
f'scope {registry.scope} exists in {self.name} registry'
|
||||||
self.children[registry.scope] = registry
|
self.children[registry.scope] = registry
|
||||||
|
|
||||||
def _register_module(self, module_class, module_name=None, force=False):
|
@deprecated_api_warning(name_dict=dict(module_class='module'))
|
||||||
if not inspect.isclass(module_class):
|
def _register_module(self, module, module_name=None, force=False):
|
||||||
raise TypeError('module must be a class, '
|
if not inspect.isclass(module) and not inspect.isfunction(module):
|
||||||
f'but got {type(module_class)}')
|
raise TypeError('module must be a class or a function, '
|
||||||
|
f'but got {type(module)}')
|
||||||
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module_name = module_class.__name__
|
module_name = module.__name__
|
||||||
if isinstance(module_name, str):
|
if isinstance(module_name, str):
|
||||||
module_name = [module_name]
|
module_name = [module_name]
|
||||||
for name in module_name:
|
for name in module_name:
|
||||||
if not force and name in self._module_dict:
|
if not force and name in self._module_dict:
|
||||||
raise KeyError(f'{name} is already registered '
|
raise KeyError(f'{name} is already registered '
|
||||||
f'in {self.name}')
|
f'in {self.name}')
|
||||||
self._module_dict[name] = module_class
|
self._module_dict[name] = module
|
||||||
|
|
||||||
def deprecated_register_module(self, cls=None, force=False):
|
def deprecated_register_module(self, cls=None, force=False):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -289,7 +309,7 @@ class Registry:
|
||||||
specified, the class name will be used.
|
specified, the class name will be used.
|
||||||
force (bool, optional): Whether to override an existing class with
|
force (bool, optional): Whether to override an existing class with
|
||||||
the same name. Default: False.
|
the same name. Default: False.
|
||||||
module (type): Module class to be registered.
|
module (type): Module class or function to be registered.
|
||||||
"""
|
"""
|
||||||
if not isinstance(force, bool):
|
if not isinstance(force, bool):
|
||||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||||
|
@ -306,14 +326,12 @@ class Registry:
|
||||||
|
|
||||||
# use it as a normal method: x.register_module(module=SomeClass)
|
# use it as a normal method: x.register_module(module=SomeClass)
|
||||||
if module is not None:
|
if module is not None:
|
||||||
self._register_module(
|
self._register_module(module=module, module_name=name, force=force)
|
||||||
module_class=module, module_name=name, force=force)
|
|
||||||
return module
|
return module
|
||||||
|
|
||||||
# use it as a decorator: @x.register_module()
|
# use it as a decorator: @x.register_module()
|
||||||
def _register(cls):
|
def _register(module):
|
||||||
self._register_module(
|
self._register_module(module=module, module_name=name, force=force)
|
||||||
module_class=cls, module_name=name, force=force)
|
return module
|
||||||
return cls
|
|
||||||
|
|
||||||
return _register
|
return _register
|
||||||
|
|
|
@ -89,12 +89,23 @@ def test_registry():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
CATS.register_module(0)
|
CATS.register_module(0)
|
||||||
|
|
||||||
# can only decorate a class
|
@CATS.register_module()
|
||||||
|
def muchkin():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert CATS.get('muchkin') is muchkin
|
||||||
|
assert 'muchkin' in CATS
|
||||||
|
|
||||||
|
# can only decorate a class or a function
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
|
||||||
@CATS.register_module()
|
class Demo:
|
||||||
def some_method():
|
|
||||||
pass
|
def some_method(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
method = Demo().some_method
|
||||||
|
CATS.register_module(name='some_method', module=method)
|
||||||
|
|
||||||
# begin: test old APIs
|
# begin: test old APIs
|
||||||
with pytest.warns(DeprecationWarning):
|
with pytest.warns(DeprecationWarning):
|
||||||
|
|
Loading…
Reference in New Issue