[Feat] support registering function (#302)

pull/306/head
Alex Yang 2022-06-14 14:50:24 +08:00 committed by GitHub
parent 4cd91ffe15
commit 5016332588
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 39 deletions

View File

@ -6,11 +6,11 @@ OpenMMLab 大多数算法库均使用注册器来管理他们的代码模块,
## 什么是注册器
MMEngine 实现的注册器可以看作一个映射表和模块构建方法build function的组合。映射表维护了一个字符串到类的映射使得用户可以借助字符串查找到相应的类例如维护字符串 `"ResNet"``ResNet` 类的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet`
而模块构建方法则定义了如何根据字符串查找到对应的类,并定义了如何实例化这个类,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。
MMEngine 实现的注册器可以看作一个映射表和模块构建方法build function的组合。映射表维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 `"ResNet"``ResNet`或函数的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet`或函数;
而模块构建方法则定义了如何根据字符串查找到对应的类或函数,并定义了如何实例化这个类或调用这个函数,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。又或者根据规则通过字符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模块。
MMEngine 中的注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) 来查找并实例化字符串对应的类。
一个注册器管理的类通常有相似的接口和功能,因此该注册器可以被视作这些类的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet` `SEResNet``RegNetX` 等分类网络的类。
一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet` `SEResNet``RegNetX` 等分类网络的类以及 `build_ResNet`, `build_SEResNet``build_RegNetX` 等分类网络的构建函数
使用注册器管理功能相似的模块可以显著提高代码的扩展性和灵活性。用户可以跳至`使用注册器提高代码的扩展性`章节了解注册器是如何提高代码拓展性的。
## 入门用法
@ -32,10 +32,10 @@ from mmengine import Registry
CONVERTERS = Registry('converter')
```
然后我们可以实现不同的转换器。
然后我们可以实现不同的转换器。例如,在 `converters/converter_cls.py` 中实现 `Converter1``Converter2`,在 `converters/converter_func.py` 中实现 `converter3`
```python
# converters/converter.py
# converters/converter_cls.py
from .builder import CONVERTERS
# 使用注册器管理模块
@ -53,12 +53,23 @@ class Converter2(object):
self.c = c
```
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
```python
# converters/converter_func.py
from .builder import CONVERTERS
from .converter_cls import Converter1
@CONVERTERS.register_module()
def converter3(a, b)
return Converter1(a, b)
```
通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类之间的映射,
使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,
```python
'Converter1' -> <class 'Converter1'>
'Converter2' -> <class 'Converter2'>
'Converter3' -> <function 'Converter3'>
```
```{note}
@ -72,6 +83,9 @@ class Converter2(object):
# 注意converter_cfg 可以通过解析配置文件得到
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = CONVERTERS.build(converter_cfg)
converter3_cfg = dict(type='converter3', a=a_value, b=b_value)
# returns the calling result
converter3 = CONVERTERS.build(converter3_cfg)
```
如果我们想使用 `Converter2`,仅需修改配置。

View File

@ -191,7 +191,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
f'RGB or gray image, but got {len(mean)}')
assert len(std) == 3 or len(std) == 1, ( # type: ignore
'The length of std should be 1 or 3 to be compatible with RGB ' # type: ignore # noqa: E501
f'or gray image, but got {len(std)}')
f'or gray image, but got {len(std)}') # type: ignore
self._enable_normalize = True
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)

View File

@ -81,7 +81,8 @@ def build_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: 'Registry',
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
"""Build a module from config dict.
"""Build a module from config dict when it is a class configuration, or
call a function from config dict when it is a function configuration.
At least one of the ``cfg`` and ``default_args`` contains the key "type"
which type should be either str or class. If they all contain it, the key
@ -101,6 +102,12 @@ def build_from_cfg(
>>> self.stages = stages
>>> cfg = dict(type='ResNet', depth=50)
>>> model = build_from_cfg(cfg, MODELS)
>>> # Returns an instantiated object
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
>>> # Return a result of the calling function
Args:
cfg (dict or ConfigDict or Config): Config dict. It should at least
@ -151,7 +158,7 @@ def build_from_cfg(
' it was registered as expected. More details can be found at'
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
)
elif inspect.isclass(obj_type):
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
obj_cls = obj_type
else:
raise TypeError(
@ -182,9 +189,10 @@ def build_from_cfg(
class Registry:
"""A registry to map strings to classes.
"""A registry to map strings to classes or functions.
Registered objects could be built from registry.
Registered object could be built from registry. Meanwhile, registered
functions could be called from registry.
Args:
name (str): Registry name.
@ -210,6 +218,10 @@ class Registry:
>>> pass
>>> # build model from `MODELS`
>>> resnet = MODELS.build(dict(type='ResNet'))
>>> @MODELS.register_module()
>>> def resnet50():
>>> pass
>>> resnet = MODELS.build(dict(type='resnet50'))
>>> # hierarchical registry
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
@ -525,25 +537,25 @@ class Registry:
self.children[registry.scope] = registry
def _register_module(self,
module_class: Type,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = False) -> None:
"""Register a module.
Args:
module_class (type): Module class to be registered.
module (type): Module class or function to be registered.
module_name (str or list of str, optional): The module name to be
registered. If not specified, the class name will be used.
Defaults to None.
force (bool): Whether to override an existing class with the same
name. Defaults to False.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if not inspect.isclass(module) and not inspect.isfunction(module):
raise TypeError('module must be a class or a function, '
f'but got {type(module)}')
if module_name is None:
module_name = module_class.__name__
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
@ -551,7 +563,7 @@ class Registry:
existed_module = self.module_dict[name]
raise KeyError(f'{name} is already registered in {self.name} '
f'at {existed_module.__module__}')
self._module_dict[name] = module_class
self._module_dict[name] = module
def register_module(
self,
@ -569,8 +581,8 @@ class Registry:
registered. If not specified, the class name will be used.
force (bool): Whether to override an existing class with the same
name. Default to False.
module (type, optional): Module class to be registered. Defaults to
None.
module (type, optional): Module class or function to be registered.
Defaults to None.
Examples:
>>> backbones = Registry('backbone')
@ -599,14 +611,12 @@ class Registry:
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
self._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
def _register(module):
self._register_module(module=module, module_name=name, force=force)
return module
return _register

View File

@ -57,13 +57,24 @@ class TestRegistry:
def test_register_module(self):
CATS = Registry('cat')
# can only decorate a class
@CATS.register_module()
def muchkin():
pass
assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS
# can only decorate a class or a function
with pytest.raises(TypeError):
@CATS.register_module()
def some_method():
class Demo:
def some_method(self):
pass
method = Demo().some_method
CATS.register_module(name='some_method', module=method)
# test `name` parameter which must be either of None, a string or a
# sequence of string
# `name` is None
@ -71,7 +82,7 @@ class TestRegistry:
class BritishShorthair:
pass
assert len(CATS) == 1
assert len(CATS) == 2
assert CATS.get('BritishShorthair') is BritishShorthair
# `name` is a string
@ -79,7 +90,7 @@ class TestRegistry:
class Munchkin:
pass
assert len(CATS) == 2
assert len(CATS) == 3
assert CATS.get('Munchkin') is Munchkin
assert 'Munchkin' in CATS
@ -90,7 +101,7 @@ class TestRegistry:
assert CATS.get('Siamese') is SiameseCat
assert CATS.get('Siamese2') is SiameseCat
assert len(CATS) == 4
assert len(CATS) == 5
# `name` is an invalid type
with pytest.raises(
@ -127,14 +138,15 @@ class TestRegistry:
class BritishShorthair:
pass
assert len(CATS) == 4
assert len(CATS) == 5
# test `module` parameter, which is either None or a class
# when the `register_module`` is called as a method rather than a
# decorator, which must be a class
with pytest.raises(
TypeError,
match="module must be a class, but got <class 'str'>"):
match='module must be a class or a function,'
" but got <class 'str'>"):
CATS.register_module(module='string')
class SphynxCat:
@ -142,16 +154,16 @@ class TestRegistry:
CATS.register_module(module=SphynxCat)
assert CATS.get('SphynxCat') is SphynxCat
assert len(CATS) == 5
assert len(CATS) == 6
CATS.register_module(name='Sphynx1', module=SphynxCat)
assert CATS.get('Sphynx1') is SphynxCat
assert len(CATS) == 6
assert len(CATS) == 7
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
assert CATS.get('Sphynx2') is SphynxCat
assert CATS.get('Sphynx3') is SphynxCat
assert len(CATS) == 8
assert len(CATS) == 9
def _build_registry(self):
"""A helper function to build a Hierarchical Registry."""