mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feat] support registering function (#302)
This commit is contained in:
parent
4cd91ffe15
commit
5016332588
@ -6,11 +6,11 @@ OpenMMLab 大多数算法库均使用注册器来管理他们的代码模块,
|
|||||||
|
|
||||||
## 什么是注册器
|
## 什么是注册器
|
||||||
|
|
||||||
MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类的映射,使得用户可以借助字符串查找到相应的类,例如维护字符串 `"ResNet"` 到 `ResNet` 类的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类。
|
MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 `"ResNet"` 到 `ResNet` 类或函数的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类或函数;
|
||||||
而模块构建方法则定义了如何根据字符串查找到对应的类,并定义了如何实例化这个类,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。
|
而模块构建方法则定义了如何根据字符串查找到对应的类或函数,并定义了如何实例化这个类或调用这个函数,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。又或者根据规则通过字符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模块。
|
||||||
MMEngine 中的注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) 来查找并实例化字符串对应的类。
|
MMEngine 中的注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) 来查找并实例化字符串对应的类。
|
||||||
|
|
||||||
一个注册器管理的类通常有相似的接口和功能,因此该注册器可以被视作这些类的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类。
|
一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类以及 `build_ResNet`, `build_SEResNet` 和 `build_RegNetX` 等分类网络的构建函数。
|
||||||
使用注册器管理功能相似的模块可以显著提高代码的扩展性和灵活性。用户可以跳至`使用注册器提高代码的扩展性`章节了解注册器是如何提高代码拓展性的。
|
使用注册器管理功能相似的模块可以显著提高代码的扩展性和灵活性。用户可以跳至`使用注册器提高代码的扩展性`章节了解注册器是如何提高代码拓展性的。
|
||||||
|
|
||||||
## 入门用法
|
## 入门用法
|
||||||
@ -32,10 +32,10 @@ from mmengine import Registry
|
|||||||
CONVERTERS = Registry('converter')
|
CONVERTERS = Registry('converter')
|
||||||
```
|
```
|
||||||
|
|
||||||
然后我们可以实现不同的转换器。
|
然后我们可以实现不同的转换器。例如,在 `converters/converter_cls.py` 中实现 `Converter1` 和 `Converter2`,在 `converters/converter_func.py` 中实现 `converter3`。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# converters/converter.py
|
# converters/converter_cls.py
|
||||||
from .builder import CONVERTERS
|
from .builder import CONVERTERS
|
||||||
|
|
||||||
# 使用注册器管理模块
|
# 使用注册器管理模块
|
||||||
@ -53,12 +53,23 @@ class Converter2(object):
|
|||||||
self.c = c
|
self.c = c
|
||||||
```
|
```
|
||||||
|
|
||||||
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
|
```python
|
||||||
|
# converters/converter_func.py
|
||||||
|
from .builder import CONVERTERS
|
||||||
|
from .converter_cls import Converter1
|
||||||
|
@CONVERTERS.register_module()
|
||||||
|
def converter3(a, b)
|
||||||
|
return Converter1(a, b)
|
||||||
|
```
|
||||||
|
|
||||||
通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类之间的映射,
|
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
|
||||||
|
|
||||||
|
通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
'Converter1' -> <class 'Converter1'>
|
'Converter1' -> <class 'Converter1'>
|
||||||
|
'Converter2' -> <class 'Converter2'>
|
||||||
|
'Converter3' -> <function 'Converter3'>
|
||||||
```
|
```
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
@ -72,6 +83,9 @@ class Converter2(object):
|
|||||||
# 注意,converter_cfg 可以通过解析配置文件得到
|
# 注意,converter_cfg 可以通过解析配置文件得到
|
||||||
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
||||||
converter = CONVERTERS.build(converter_cfg)
|
converter = CONVERTERS.build(converter_cfg)
|
||||||
|
converter3_cfg = dict(type='converter3', a=a_value, b=b_value)
|
||||||
|
# returns the calling result
|
||||||
|
converter3 = CONVERTERS.build(converter3_cfg)
|
||||||
```
|
```
|
||||||
|
|
||||||
如果我们想使用 `Converter2`,仅需修改配置。
|
如果我们想使用 `Converter2`,仅需修改配置。
|
||||||
|
@ -191,7 +191,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||||||
f'RGB or gray image, but got {len(mean)}')
|
f'RGB or gray image, but got {len(mean)}')
|
||||||
assert len(std) == 3 or len(std) == 1, ( # type: ignore
|
assert len(std) == 3 or len(std) == 1, ( # type: ignore
|
||||||
'The length of std should be 1 or 3 to be compatible with RGB ' # type: ignore # noqa: E501
|
'The length of std should be 1 or 3 to be compatible with RGB ' # type: ignore # noqa: E501
|
||||||
f'or gray image, but got {len(std)}')
|
f'or gray image, but got {len(std)}') # type: ignore
|
||||||
self._enable_normalize = True
|
self._enable_normalize = True
|
||||||
self.register_buffer('mean',
|
self.register_buffer('mean',
|
||||||
torch.tensor(mean).view(-1, 1, 1), False)
|
torch.tensor(mean).view(-1, 1, 1), False)
|
||||||
|
@ -81,7 +81,8 @@ def build_from_cfg(
|
|||||||
cfg: Union[dict, ConfigDict, Config],
|
cfg: Union[dict, ConfigDict, Config],
|
||||||
registry: 'Registry',
|
registry: 'Registry',
|
||||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
|
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
|
||||||
"""Build a module from config dict.
|
"""Build a module from config dict when it is a class configuration, or
|
||||||
|
call a function from config dict when it is a function configuration.
|
||||||
|
|
||||||
At least one of the ``cfg`` and ``default_args`` contains the key "type"
|
At least one of the ``cfg`` and ``default_args`` contains the key "type"
|
||||||
which type should be either str or class. If they all contain it, the key
|
which type should be either str or class. If they all contain it, the key
|
||||||
@ -101,6 +102,12 @@ def build_from_cfg(
|
|||||||
>>> self.stages = stages
|
>>> self.stages = stages
|
||||||
>>> cfg = dict(type='ResNet', depth=50)
|
>>> cfg = dict(type='ResNet', depth=50)
|
||||||
>>> model = build_from_cfg(cfg, MODELS)
|
>>> model = build_from_cfg(cfg, MODELS)
|
||||||
|
>>> # Returns an instantiated object
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> def resnet50():
|
||||||
|
>>> pass
|
||||||
|
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
|
||||||
|
>>> # Return a result of the calling function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (dict or ConfigDict or Config): Config dict. It should at least
|
cfg (dict or ConfigDict or Config): Config dict. It should at least
|
||||||
@ -151,7 +158,7 @@ def build_from_cfg(
|
|||||||
' it was registered as expected. More details can be found at'
|
' it was registered as expected. More details can be found at'
|
||||||
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||||
)
|
)
|
||||||
elif inspect.isclass(obj_type):
|
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
||||||
obj_cls = obj_type
|
obj_cls = obj_type
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -182,9 +189,10 @@ def build_from_cfg(
|
|||||||
|
|
||||||
|
|
||||||
class Registry:
|
class Registry:
|
||||||
"""A registry to map strings to classes.
|
"""A registry to map strings to classes or functions.
|
||||||
|
|
||||||
Registered objects could be built from registry.
|
Registered object could be built from registry. Meanwhile, registered
|
||||||
|
functions could be called from registry.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): Registry name.
|
name (str): Registry name.
|
||||||
@ -210,6 +218,10 @@ class Registry:
|
|||||||
>>> pass
|
>>> pass
|
||||||
>>> # build model from `MODELS`
|
>>> # build model from `MODELS`
|
||||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> def resnet50():
|
||||||
|
>>> pass
|
||||||
|
>>> resnet = MODELS.build(dict(type='resnet50'))
|
||||||
|
|
||||||
>>> # hierarchical registry
|
>>> # hierarchical registry
|
||||||
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
|
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
|
||||||
@ -525,25 +537,25 @@ class Registry:
|
|||||||
self.children[registry.scope] = registry
|
self.children[registry.scope] = registry
|
||||||
|
|
||||||
def _register_module(self,
|
def _register_module(self,
|
||||||
module_class: Type,
|
module: Type,
|
||||||
module_name: Optional[Union[str, List[str]]] = None,
|
module_name: Optional[Union[str, List[str]]] = None,
|
||||||
force: bool = False) -> None:
|
force: bool = False) -> None:
|
||||||
"""Register a module.
|
"""Register a module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_class (type): Module class to be registered.
|
module (type): Module class or function to be registered.
|
||||||
module_name (str or list of str, optional): The module name to be
|
module_name (str or list of str, optional): The module name to be
|
||||||
registered. If not specified, the class name will be used.
|
registered. If not specified, the class name will be used.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
force (bool): Whether to override an existing class with the same
|
force (bool): Whether to override an existing class with the same
|
||||||
name. Defaults to False.
|
name. Defaults to False.
|
||||||
"""
|
"""
|
||||||
if not inspect.isclass(module_class):
|
if not inspect.isclass(module) and not inspect.isfunction(module):
|
||||||
raise TypeError('module must be a class, '
|
raise TypeError('module must be a class or a function, '
|
||||||
f'but got {type(module_class)}')
|
f'but got {type(module)}')
|
||||||
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module_name = module_class.__name__
|
module_name = module.__name__
|
||||||
if isinstance(module_name, str):
|
if isinstance(module_name, str):
|
||||||
module_name = [module_name]
|
module_name = [module_name]
|
||||||
for name in module_name:
|
for name in module_name:
|
||||||
@ -551,7 +563,7 @@ class Registry:
|
|||||||
existed_module = self.module_dict[name]
|
existed_module = self.module_dict[name]
|
||||||
raise KeyError(f'{name} is already registered in {self.name} '
|
raise KeyError(f'{name} is already registered in {self.name} '
|
||||||
f'at {existed_module.__module__}')
|
f'at {existed_module.__module__}')
|
||||||
self._module_dict[name] = module_class
|
self._module_dict[name] = module
|
||||||
|
|
||||||
def register_module(
|
def register_module(
|
||||||
self,
|
self,
|
||||||
@ -569,8 +581,8 @@ class Registry:
|
|||||||
registered. If not specified, the class name will be used.
|
registered. If not specified, the class name will be used.
|
||||||
force (bool): Whether to override an existing class with the same
|
force (bool): Whether to override an existing class with the same
|
||||||
name. Default to False.
|
name. Default to False.
|
||||||
module (type, optional): Module class to be registered. Defaults to
|
module (type, optional): Module class or function to be registered.
|
||||||
None.
|
Defaults to None.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> backbones = Registry('backbone')
|
>>> backbones = Registry('backbone')
|
||||||
@ -599,14 +611,12 @@ class Registry:
|
|||||||
|
|
||||||
# use it as a normal method: x.register_module(module=SomeClass)
|
# use it as a normal method: x.register_module(module=SomeClass)
|
||||||
if module is not None:
|
if module is not None:
|
||||||
self._register_module(
|
self._register_module(module=module, module_name=name, force=force)
|
||||||
module_class=module, module_name=name, force=force)
|
|
||||||
return module
|
return module
|
||||||
|
|
||||||
# use it as a decorator: @x.register_module()
|
# use it as a decorator: @x.register_module()
|
||||||
def _register(cls):
|
def _register(module):
|
||||||
self._register_module(
|
self._register_module(module=module, module_name=name, force=force)
|
||||||
module_class=cls, module_name=name, force=force)
|
return module
|
||||||
return cls
|
|
||||||
|
|
||||||
return _register
|
return _register
|
||||||
|
@ -57,13 +57,24 @@ class TestRegistry:
|
|||||||
def test_register_module(self):
|
def test_register_module(self):
|
||||||
CATS = Registry('cat')
|
CATS = Registry('cat')
|
||||||
|
|
||||||
# can only decorate a class
|
@CATS.register_module()
|
||||||
|
def muchkin():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert CATS.get('muchkin') is muchkin
|
||||||
|
assert 'muchkin' in CATS
|
||||||
|
|
||||||
|
# can only decorate a class or a function
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
|
||||||
@CATS.register_module()
|
class Demo:
|
||||||
def some_method():
|
|
||||||
|
def some_method(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
method = Demo().some_method
|
||||||
|
CATS.register_module(name='some_method', module=method)
|
||||||
|
|
||||||
# test `name` parameter which must be either of None, a string or a
|
# test `name` parameter which must be either of None, a string or a
|
||||||
# sequence of string
|
# sequence of string
|
||||||
# `name` is None
|
# `name` is None
|
||||||
@ -71,7 +82,7 @@ class TestRegistry:
|
|||||||
class BritishShorthair:
|
class BritishShorthair:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert len(CATS) == 1
|
assert len(CATS) == 2
|
||||||
assert CATS.get('BritishShorthair') is BritishShorthair
|
assert CATS.get('BritishShorthair') is BritishShorthair
|
||||||
|
|
||||||
# `name` is a string
|
# `name` is a string
|
||||||
@ -79,7 +90,7 @@ class TestRegistry:
|
|||||||
class Munchkin:
|
class Munchkin:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert len(CATS) == 2
|
assert len(CATS) == 3
|
||||||
assert CATS.get('Munchkin') is Munchkin
|
assert CATS.get('Munchkin') is Munchkin
|
||||||
assert 'Munchkin' in CATS
|
assert 'Munchkin' in CATS
|
||||||
|
|
||||||
@ -90,7 +101,7 @@ class TestRegistry:
|
|||||||
|
|
||||||
assert CATS.get('Siamese') is SiameseCat
|
assert CATS.get('Siamese') is SiameseCat
|
||||||
assert CATS.get('Siamese2') is SiameseCat
|
assert CATS.get('Siamese2') is SiameseCat
|
||||||
assert len(CATS) == 4
|
assert len(CATS) == 5
|
||||||
|
|
||||||
# `name` is an invalid type
|
# `name` is an invalid type
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
@ -127,14 +138,15 @@ class TestRegistry:
|
|||||||
class BritishShorthair:
|
class BritishShorthair:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert len(CATS) == 4
|
assert len(CATS) == 5
|
||||||
|
|
||||||
# test `module` parameter, which is either None or a class
|
# test `module` parameter, which is either None or a class
|
||||||
# when the `register_module`` is called as a method rather than a
|
# when the `register_module`` is called as a method rather than a
|
||||||
# decorator, which must be a class
|
# decorator, which must be a class
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match="module must be a class, but got <class 'str'>"):
|
match='module must be a class or a function,'
|
||||||
|
" but got <class 'str'>"):
|
||||||
CATS.register_module(module='string')
|
CATS.register_module(module='string')
|
||||||
|
|
||||||
class SphynxCat:
|
class SphynxCat:
|
||||||
@ -142,16 +154,16 @@ class TestRegistry:
|
|||||||
|
|
||||||
CATS.register_module(module=SphynxCat)
|
CATS.register_module(module=SphynxCat)
|
||||||
assert CATS.get('SphynxCat') is SphynxCat
|
assert CATS.get('SphynxCat') is SphynxCat
|
||||||
assert len(CATS) == 5
|
assert len(CATS) == 6
|
||||||
|
|
||||||
CATS.register_module(name='Sphynx1', module=SphynxCat)
|
CATS.register_module(name='Sphynx1', module=SphynxCat)
|
||||||
assert CATS.get('Sphynx1') is SphynxCat
|
assert CATS.get('Sphynx1') is SphynxCat
|
||||||
assert len(CATS) == 6
|
assert len(CATS) == 7
|
||||||
|
|
||||||
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
|
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
|
||||||
assert CATS.get('Sphynx2') is SphynxCat
|
assert CATS.get('Sphynx2') is SphynxCat
|
||||||
assert CATS.get('Sphynx3') is SphynxCat
|
assert CATS.get('Sphynx3') is SphynxCat
|
||||||
assert len(CATS) == 8
|
assert len(CATS) == 9
|
||||||
|
|
||||||
def _build_registry(self):
|
def _build_registry(self):
|
||||||
"""A helper function to build a Hierarchical Registry."""
|
"""A helper function to build a Hierarchical Registry."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user