[Feature] Registry supports import modules automatically. (#643)

* [Feature] Support registry auto import modules.

* update

* rebase and fix ut

* add docstring

* remove count_registered_modules

* update docstring

* resolve comments

* resolve comments

* rename ut

* fix warning

* avoid BC breaking

* update doc

* update doc

* resolve comments
This commit is contained in:
RangiLyu 2022-12-23 15:46:29 +08:00 committed by GitHub
parent 60492f4df7
commit e83ac944b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 182 additions and 62 deletions

View File

@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code
Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.
Let's create a regitry first.
Let's create a registry first.
```python
from mmengine import Registry
# scope represents the domain of the registry. If not set, the default value is the package name.
# `scope` represents the domain of the registry. If not set, the default value is the package name.
# e.g. in mmdetection, the scope is mmdet
ACTIVATION = Registry('activation', scope='mmengine')
# `locations` indicates the location where the modules in this registry are defined.
# The Registry will automatically import the modules when building them according to these predefined locations.
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
```
Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`.
The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`.
```python
import torch.nn as nn
@ -75,7 +77,14 @@ print(ACTIVATION.module_dict)
```
```{note}
The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
The key to trigger the registry mechanism is to make the module imported.
There are three ways to register a module into the registry
1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``.
2. Import the file manually. This is common when developers implement a new module in/out side the algorithm library.
3. Use ``custom_imports`` field in config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
```
Once the implemented module is successfully registered, we can use the activation module in the configuration file.
@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs):
Pass the `buid_activation` to `build_func`.
```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
@ACTIVATION.register_module()
class Tanh(nn.Module):
@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a
```python
from mmengine import Registry, MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
```
The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.

View File

@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg
count_registered_modules
traverse_registry_tree
init_default_scope

View File

@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映
```python
from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
ACTIVATION = Registry('activation', scope='mmengine')
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
```
然后我们可以实现不同的激活模块,例如 `Sigmoid``ReLU``Softmax`
`locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid``ReLU``Softmax`
```python
import torch.nn as nn
@ -74,7 +75,13 @@ print(ACTIVATION.module_dict)
```
```{note}
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。
只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。
```
模块成功注册后,我们可以通过配置文件使用这个激活模块。
@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs):
并将 `build_activation` 传递给 `build_func` 参数
```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
@ACTIVATION.register_module()
class Tanh(nn.Module):
@ -206,7 +213,7 @@ class RReLU(nn.Module):
```python
from mmengine import Registry, MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
```
下图是 `MMEngine``MMAlpha` 的注册器层级结构。

View File

@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg
count_registered_modules
traverse_registry_tree
init_default_scope

View File

@ -8,7 +8,8 @@ from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .utils import count_registered_modules, traverse_registry_tree
from .utils import (count_registered_modules, init_default_scope,
traverse_registry_tree)
__all__ = [
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
@ -18,5 +19,5 @@ __all__ = [
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
'build_scheduler_from_cfg'
'build_scheduler_from_cfg', 'init_default_scope'
]

View File

@ -32,6 +32,9 @@ class Registry:
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Defaults to None.
locations (list): The locations to import the modules registered
in this registry. Defaults to [].
New in version 0.4.0.
Examples:
>>> # define a registry
@ -54,6 +57,16 @@ class Registry:
>>> pass
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
>>> # add locations to enable auto import
>>> DETECTORS = Registry('detectors', parent=MODELS,
>>> scope='det', locations=['det.models.detectors'])
>>> # define this class in 'det.models.detectors'
>>> @DETECTORS.register_module()
>>> class MaskRCNN:
>>> pass
>>> # The registry will auto import det.models.detectors.MaskRCNN
>>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
More advanced usages can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
@ -62,11 +75,14 @@ class Registry:
name: str,
build_func: Optional[Callable] = None,
parent: Optional['Registry'] = None,
scope: Optional[str] = None):
scope: Optional[str] = None,
locations: List = []):
from .build_functions import build_from_cfg
self._name = name
self._module_dict: Dict[str, Type] = dict()
self._children: Dict[str, 'Registry'] = dict()
self._locations = locations
self._imported = False
if scope is not None:
assert isinstance(scope, str)
@ -240,27 +256,25 @@ class Registry:
# Get registry by scope
if default_scope is not None:
scope_name = default_scope.scope_name
if scope_name in PKG2PROJECT:
try:
module = import_module(f'{scope_name}.utils')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{scope}, Please install {scope} by '
f'`pip install {PKG2PROJECT[scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import {scope} and register '
'its modules, please make sure you '
'have registered the module manually.',
logger='current',
level=logging.WARNING)
try:
import_module(f'{scope_name}.registry')
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{scope}, Please install {scope} by '
f'`pip install {PKG2PROJECT[scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import `{scope}.registry` '
f'make sure the registry.py exists in `{scope}` '
'package.',
logger='current',
level=logging.WARNING)
root = self._get_root_registry()
registry = root._search_child(scope_name)
if registry is None:
@ -290,6 +304,59 @@ class Registry:
root = root.parent
return root
def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location."""
if not self._imported:
# Avoid circular import
from ..logging import print_log
# avoid BC breaking
if len(self._locations) == 0 and self.scope in PKG2PROJECT:
print_log(
f'The "{self.name}" registry in {self.scope} did not '
'set import location. Fallback to call '
f'`{self.scope}.utils.register_all_modules` '
'instead.',
logger='current',
level=logging.WARNING)
try:
module = import_module(f'{self.scope}.utils')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError):
if self.scope in PKG2PROJECT:
print_log(
f'{self.scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{self.scope}, Please install {self.scope} by '
f'`pip install {PKG2PROJECT[self.scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import {self.scope} and register '
'its modules, please make sure you '
'have registered the module manually.',
logger='current',
level=logging.WARNING)
for loc in self._locations:
try:
import_module(loc)
print_log(
f"Modules of {self.scope}'s {self.name} registry have "
f'been automatically imported from {loc}',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError):
print_log(
f'Failed to import {loc}, please check the '
f'location of the registry {self.name} is '
'correct.',
logger='current',
level=logging.WARNING)
self._imported = True
def get(self, key: str) -> Optional[Type]:
"""Get the registry record.
@ -346,11 +413,14 @@ class Registry:
obj_cls = None
registry_name = self.name
scope_name = self.scope
# lazy import the modules to register them into the registry
self.import_from_location()
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
obj_cls = self._module_dict[real_key]
elif scope is None:
# try to get the target from its parent or ancestors
parent = self.parent
@ -362,24 +432,21 @@ class Registry:
break
parent = parent.parent
else:
# import the registry to add the nodes into the registry tree
try:
module = import_module(f'{scope}.utils')
module.register_all_modules(False) # type: ignore
import_module(f'{scope}.registry')
print_log(
f'Registry node of {scope} has been automatically '
'imported.',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its modules '
'will not be registered. If you want to use '
f'modules defined in {scope}, Please install '
f'{scope} by `pip install {PKG2PROJECT[scope]} ',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import "{scope}", and register its '
f'modules. Please register {real_key} manually.',
logger='current',
level=logging.WARNING)
print_log(
f'Cannot auto import {scope}.registry, please check '
f'whether the package "{scope}" is installed correctly '
'or import the registry manually.',
logger='current',
level=logging.DEBUG)
# get from self._children
if scope in self._children:
obj_cls = self._children[scope].get(real_key)

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os.path as osp
import warnings
from typing import Optional
from mmengine.fileio import dump
from mmengine.logging import print_log
from . import root
from .default_scope import DefaultScope
from .registry import Registry
@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None,
dump(scan_data, json_path, indent=2)
print_log(f'Result has been saved to {json_path}', logger='current')
return scan_data
def init_default_scope(scope: str) -> None:
"""Initialize the given default scope.
Args:
scope (str): The name of the default scope.
"""
never_created = DefaultScope.get_current_instance(
) is None or not DefaultScope.check_instance_created(scope)
if never_created:
DefaultScope.get_instance(scope, scope_name=scope)
return
current_scope = DefaultScope.get_current_instance() # type: ignore
if current_scope.scope_name != scope: # type: ignore
warnings.warn('The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
'`init_default_scope` will force set the current'
f'default scope to "{scope}".')
# avoid name conflict
new_instance_name = f'{scope}-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name=scope)

View File

@ -34,8 +34,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
VISUALIZERS, DefaultScope,
count_registered_modules)
VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
@ -372,11 +371,6 @@ class Runner:
# Collect and log environment information.
self._log_env(env_cfg)
# collect information of all modules registered in the registries
registries_info = count_registered_modules(
self.work_dir if self.rank == 0 else None, verbose=False)
self.logger.debug(registries_info)
# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not

View File

@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf
from mmengine.registry import (Registry, count_registered_modules, root,
traverse_registry_tree)
from mmengine.registry import (DefaultScope, Registry,
count_registered_modules, init_default_scope,
root, traverse_registry_tree)
from mmengine.utils import is_installed
@ -62,3 +64,17 @@ class TestUtils(TestCase):
self.assertFalse(
osp.exists(
osp.join(temp_dir.name, 'modules_statistic_results.json')))
@skipIf(not is_installed('torch'), 'tests requires torch')
def test_init_default_scope(self):
# init default scope
init_default_scope('mmdet')
self.assertEqual(DefaultScope.get_current_instance().scope_name,
'mmdet')
# init default scope when another scope is init
name = f'test-{datetime.datetime.now()}'
DefaultScope.get_instance(name, scope_name='test')
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmdet"'):
init_default_scope('mmdet')