[Feature] Registry supports import modules automatically. (#643)

* [Feature] Support registry auto import modules.

* update

* rebase and fix ut

* add docstring

* remove count_registered_modules

* update docstring

* resolve comments

* resolve comments

* rename ut

* fix warning

* avoid BC breaking

* update doc

* update doc

* resolve comments
This commit is contained in:
RangiLyu 2022-12-23 15:46:29 +08:00 committed by GitHub
parent 60492f4df7
commit e83ac944b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 182 additions and 62 deletions

View File

@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code
Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code. Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.
Let's create a regitry first. Let's create a registry first.
```python ```python
from mmengine import Registry from mmengine import Registry
# scope represents the domain of the registry. If not set, the default value is the package name. # `scope` represents the domain of the registry. If not set, the default value is the package name.
# e.g. in mmdetection, the scope is mmdet # e.g. in mmdetection, the scope is mmdet
ACTIVATION = Registry('activation', scope='mmengine') # `locations` indicates the location where the modules in this registry are defined.
# The Registry will automatically import the modules when building them according to these predefined locations.
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
``` ```
Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`. The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`.
```python ```python
import torch.nn as nn import torch.nn as nn
@ -75,7 +77,14 @@ print(ACTIVATION.module_dict)
``` ```
```{note} ```{note}
The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details. The key to trigger the registry mechanism is to make the module imported.
There are three ways to register a module into the registry
1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``.
2. Import the file manually. This is common when developers implement a new module in/out side the algorithm library.
3. Use ``custom_imports`` field in config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
``` ```
Once the implemented module is successfully registered, we can use the activation module in the configuration file. Once the implemented module is successfully registered, we can use the activation module in the configuration file.
@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs):
Pass the `buid_activation` to `build_func`. Pass the `buid_activation` to `build_func`.
```python ```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine') ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
@ACTIVATION.register_module() @ACTIVATION.register_module()
class Tanh(nn.Module): class Tanh(nn.Module):
@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a
```python ```python
from mmengine import Registry, MODELS as MMENGINE_MODELS from mmengine import Registry, MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha') MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
``` ```
The following figure shows the hierarchy of `MMEngine` and `MMAlpha`. The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.

View File

@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg build_scheduler_from_cfg
count_registered_modules count_registered_modules
traverse_registry_tree traverse_registry_tree
init_default_scope

View File

@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映
```python ```python
from mmengine import Registry from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet # scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
ACTIVATION = Registry('activation', scope='mmengine') # locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
``` ```
然后我们可以实现不同的激活模块,例如 `Sigmoid``ReLU``Softmax` `locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid``ReLU``Softmax`
```python ```python
import torch.nn as nn import torch.nn as nn
@ -74,7 +75,13 @@ print(ACTIVATION.module_dict)
``` ```
```{note} ```{note}
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。 只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。
``` ```
模块成功注册后,我们可以通过配置文件使用这个激活模块。 模块成功注册后,我们可以通过配置文件使用这个激活模块。
@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs):
并将 `build_activation` 传递给 `build_func` 参数 并将 `build_activation` 传递给 `build_func` 参数
```python ```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine') ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
@ACTIVATION.register_module() @ACTIVATION.register_module()
class Tanh(nn.Module): class Tanh(nn.Module):
@ -206,7 +213,7 @@ class RReLU(nn.Module):
```python ```python
from mmengine import Registry, MODELS as MMENGINE_MODELS from mmengine import Registry, MODELS as MMENGINE_MODELS
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha') MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
``` ```
下图是 `MMEngine``MMAlpha` 的注册器层级结构。 下图是 `MMEngine``MMAlpha` 的注册器层级结构。

View File

@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg build_scheduler_from_cfg
count_registered_modules count_registered_modules
traverse_registry_tree traverse_registry_tree
init_default_scope

View File

@ -8,7 +8,8 @@ from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .utils import count_registered_modules, traverse_registry_tree from .utils import (count_registered_modules, init_default_scope,
traverse_registry_tree)
__all__ = [ __all__ = [
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
@ -18,5 +19,5 @@ __all__ = [
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules', 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg', 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
'build_scheduler_from_cfg' 'build_scheduler_from_cfg', 'init_default_scope'
] ]

View File

@ -32,6 +32,9 @@ class Registry:
for children registry. If not specified, scope will be the name of for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg. the package where class is defined, e.g. mmdet, mmcls, mmseg.
Defaults to None. Defaults to None.
locations (list): The locations to import the modules registered
in this registry. Defaults to [].
New in version 0.4.0.
Examples: Examples:
>>> # define a registry >>> # define a registry
@ -54,6 +57,16 @@ class Registry:
>>> pass >>> pass
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN')) >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
>>> # add locations to enable auto import
>>> DETECTORS = Registry('detectors', parent=MODELS,
>>> scope='det', locations=['det.models.detectors'])
>>> # define this class in 'det.models.detectors'
>>> @DETECTORS.register_module()
>>> class MaskRCNN:
>>> pass
>>> # The registry will auto import det.models.detectors.MaskRCNN
>>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
More advanced usages can be found at More advanced usages can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
""" """
@ -62,11 +75,14 @@ class Registry:
name: str, name: str,
build_func: Optional[Callable] = None, build_func: Optional[Callable] = None,
parent: Optional['Registry'] = None, parent: Optional['Registry'] = None,
scope: Optional[str] = None): scope: Optional[str] = None,
locations: List = []):
from .build_functions import build_from_cfg from .build_functions import build_from_cfg
self._name = name self._name = name
self._module_dict: Dict[str, Type] = dict() self._module_dict: Dict[str, Type] = dict()
self._children: Dict[str, 'Registry'] = dict() self._children: Dict[str, 'Registry'] = dict()
self._locations = locations
self._imported = False
if scope is not None: if scope is not None:
assert isinstance(scope, str) assert isinstance(scope, str)
@ -240,10 +256,8 @@ class Registry:
# Get registry by scope # Get registry by scope
if default_scope is not None: if default_scope is not None:
scope_name = default_scope.scope_name scope_name = default_scope.scope_name
if scope_name in PKG2PROJECT:
try: try:
module = import_module(f'{scope_name}.utils') import_module(f'{scope_name}.registry')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError): except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT: if scope in PKG2PROJECT:
print_log( print_log(
@ -256,9 +270,9 @@ class Registry:
level=logging.WARNING) level=logging.WARNING)
else: else:
print_log( print_log(
f'Failed to import {scope} and register ' f'Failed to import `{scope}.registry` '
'its modules, please make sure you ' f'make sure the registry.py exists in `{scope}` '
'have registered the module manually.', 'package.',
logger='current', logger='current',
level=logging.WARNING) level=logging.WARNING)
root = self._get_root_registry() root = self._get_root_registry()
@ -290,6 +304,59 @@ class Registry:
root = root.parent root = root.parent
return root return root
def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location."""
if not self._imported:
# Avoid circular import
from ..logging import print_log
# avoid BC breaking
if len(self._locations) == 0 and self.scope in PKG2PROJECT:
print_log(
f'The "{self.name}" registry in {self.scope} did not '
'set import location. Fallback to call '
f'`{self.scope}.utils.register_all_modules` '
'instead.',
logger='current',
level=logging.WARNING)
try:
module = import_module(f'{self.scope}.utils')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError):
if self.scope in PKG2PROJECT:
print_log(
f'{self.scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{self.scope}, Please install {self.scope} by '
f'`pip install {PKG2PROJECT[self.scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import {self.scope} and register '
'its modules, please make sure you '
'have registered the module manually.',
logger='current',
level=logging.WARNING)
for loc in self._locations:
try:
import_module(loc)
print_log(
f"Modules of {self.scope}'s {self.name} registry have "
f'been automatically imported from {loc}',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError):
print_log(
f'Failed to import {loc}, please check the '
f'location of the registry {self.name} is '
'correct.',
logger='current',
level=logging.WARNING)
self._imported = True
def get(self, key: str) -> Optional[Type]: def get(self, key: str) -> Optional[Type]:
"""Get the registry record. """Get the registry record.
@ -346,11 +413,14 @@ class Registry:
obj_cls = None obj_cls = None
registry_name = self.name registry_name = self.name
scope_name = self.scope scope_name = self.scope
# lazy import the modules to register them into the registry
self.import_from_location()
if scope is None or scope == self._scope: if scope is None or scope == self._scope:
# get from self # get from self
if real_key in self._module_dict: if real_key in self._module_dict:
obj_cls = self._module_dict[real_key] obj_cls = self._module_dict[real_key]
elif scope is None: elif scope is None:
# try to get the target from its parent or ancestors # try to get the target from its parent or ancestors
parent = self.parent parent = self.parent
@ -362,24 +432,21 @@ class Registry:
break break
parent = parent.parent parent = parent.parent
else: else:
# import the registry to add the nodes into the registry tree
try: try:
module = import_module(f'{scope}.utils') import_module(f'{scope}.registry')
module.register_all_modules(False) # type: ignore print_log(
f'Registry node of {scope} has been automatically '
'imported.',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError): except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log( print_log(
f'{scope} is not installed and its modules ' f'Cannot auto import {scope}.registry, please check '
'will not be registered. If you want to use ' f'whether the package "{scope}" is installed correctly '
f'modules defined in {scope}, Please install ' 'or import the registry manually.',
f'{scope} by `pip install {PKG2PROJECT[scope]} ',
logger='current', logger='current',
level=logging.WARNING) level=logging.DEBUG)
else:
print_log(
f'Failed to import "{scope}", and register its '
f'modules. Please register {real_key} manually.',
logger='current',
level=logging.WARNING)
# get from self._children # get from self._children
if scope in self._children: if scope in self._children:
obj_cls = self._children[scope].get(real_key) obj_cls = self._children[scope].get(real_key)

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import datetime import datetime
import os.path as osp import os.path as osp
import warnings
from typing import Optional from typing import Optional
from mmengine.fileio import dump from mmengine.fileio import dump
from mmengine.logging import print_log from mmengine.logging import print_log
from . import root from . import root
from .default_scope import DefaultScope
from .registry import Registry from .registry import Registry
@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None,
dump(scan_data, json_path, indent=2) dump(scan_data, json_path, indent=2)
print_log(f'Result has been saved to {json_path}', logger='current') print_log(f'Result has been saved to {json_path}', logger='current')
return scan_data return scan_data
def init_default_scope(scope: str) -> None:
"""Initialize the given default scope.
Args:
scope (str): The name of the default scope.
"""
never_created = DefaultScope.get_current_instance(
) is None or not DefaultScope.check_instance_created(scope)
if never_created:
DefaultScope.get_instance(scope, scope_name=scope)
return
current_scope = DefaultScope.get_current_instance() # type: ignore
if current_scope.scope_name != scope: # type: ignore
warnings.warn('The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
'`init_default_scope` will force set the current'
f'default scope to "{scope}".')
# avoid name conflict
new_instance_name = f'{scope}-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name=scope)

View File

@ -34,8 +34,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
VISUALIZERS, DefaultScope, VISUALIZERS, DefaultScope)
count_registered_modules)
from mmengine.utils import digit_version, get_git_hash, is_seq_of from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing) set_multi_processing)
@ -372,11 +371,6 @@ class Runner:
# Collect and log environment information. # Collect and log environment information.
self._log_env(env_cfg) self._log_env(env_cfg)
# collect information of all modules registered in the registries
registries_info = count_registered_modules(
self.work_dir if self.rank == 0 else None, verbose=False)
self.logger.debug(registries_info)
# Build `message_hub` for communication among components. # Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and # `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not # runtime information (iter and epoch). Those components that do not

View File

@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os.path as osp import os.path as osp
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf from unittest import TestCase, skipIf
from mmengine.registry import (Registry, count_registered_modules, root, from mmengine.registry import (DefaultScope, Registry,
traverse_registry_tree) count_registered_modules, init_default_scope,
root, traverse_registry_tree)
from mmengine.utils import is_installed from mmengine.utils import is_installed
@ -62,3 +64,17 @@ class TestUtils(TestCase):
self.assertFalse( self.assertFalse(
osp.exists( osp.exists(
osp.join(temp_dir.name, 'modules_statistic_results.json'))) osp.join(temp_dir.name, 'modules_statistic_results.json')))
@skipIf(not is_installed('torch'), 'tests requires torch')
def test_init_default_scope(self):
# init default scope
init_default_scope('mmdet')
self.assertEqual(DefaultScope.get_current_instance().scope_name,
'mmdet')
# init default scope when another scope is init
name = f'test-{datetime.datetime.now()}'
DefaultScope.get_instance(name, scope_name='test')
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmdet"'):
init_default_scope('mmdet')