mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Registry supports import modules automatically. (#643)
* [Feature] Support registry auto import modules. * update * rebase and fix ut * add docstring * remove count_registered_modules * update docstring * resolve comments * resolve comments * rename ut * fix warning * avoid BC breaking * update doc * update doc * resolve comments
This commit is contained in:
parent
60492f4df7
commit
e83ac944b6
@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code
|
||||
|
||||
Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.
|
||||
|
||||
Let's create a regitry first.
|
||||
Let's create a registry first.
|
||||
|
||||
```python
|
||||
from mmengine import Registry
|
||||
# scope represents the domain of the registry. If not set, the default value is the package name.
|
||||
# `scope` represents the domain of the registry. If not set, the default value is the package name.
|
||||
# e.g. in mmdetection, the scope is mmdet
|
||||
ACTIVATION = Registry('activation', scope='mmengine')
|
||||
# `locations` indicates the location where the modules in this registry are defined.
|
||||
# The Registry will automatically import the modules when building them according to these predefined locations.
|
||||
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
|
||||
```
|
||||
|
||||
Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`.
|
||||
The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`.
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
@ -75,7 +77,14 @@ print(ACTIVATION.module_dict)
|
||||
```
|
||||
|
||||
```{note}
|
||||
The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
|
||||
The key to trigger the registry mechanism is to make the module imported.
|
||||
There are three ways to register a module into the registry
|
||||
|
||||
1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``.
|
||||
|
||||
2. Import the file manually. This is common when developers implement a new module in/out side the algorithm library.
|
||||
|
||||
3. Use ``custom_imports`` field in config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
|
||||
```
|
||||
|
||||
Once the implemented module is successfully registered, we can use the activation module in the configuration file.
|
||||
@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs):
|
||||
Pass the `buid_activation` to `build_func`.
|
||||
|
||||
```python
|
||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
|
||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
|
||||
|
||||
@ACTIVATION.register_module()
|
||||
class Tanh(nn.Module):
|
||||
@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a
|
||||
```python
|
||||
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
|
||||
```
|
||||
|
||||
The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.
|
||||
|
@ -24,3 +24,4 @@ mmengine.registry
|
||||
build_scheduler_from_cfg
|
||||
count_registered_modules
|
||||
traverse_registry_tree
|
||||
init_default_scope
|
||||
|
@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映
|
||||
```python
|
||||
from mmengine import Registry
|
||||
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
|
||||
ACTIVATION = Registry('activation', scope='mmengine')
|
||||
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
|
||||
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
|
||||
```
|
||||
|
||||
然后我们可以实现不同的激活模块,例如 `Sigmoid`,`ReLU` 和 `Softmax`。
|
||||
`locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid`,`ReLU` 和 `Softmax`。
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
@ -74,7 +75,13 @@ print(ACTIVATION.module_dict)
|
||||
```
|
||||
|
||||
```{note}
|
||||
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。
|
||||
只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
|
||||
|
||||
1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。
|
||||
|
||||
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
|
||||
|
||||
3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。
|
||||
```
|
||||
|
||||
模块成功注册后,我们可以通过配置文件使用这个激活模块。
|
||||
@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs):
|
||||
并将 `build_activation` 传递给 `build_func` 参数
|
||||
|
||||
```python
|
||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
|
||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
|
||||
|
||||
@ACTIVATION.register_module()
|
||||
class Tanh(nn.Module):
|
||||
@ -206,7 +213,7 @@ class RReLU(nn.Module):
|
||||
```python
|
||||
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
|
||||
```
|
||||
|
||||
下图是 `MMEngine` 和 `MMAlpha` 的注册器层级结构。
|
||||
|
@ -24,3 +24,4 @@ mmengine.registry
|
||||
build_scheduler_from_cfg
|
||||
count_registered_modules
|
||||
traverse_registry_tree
|
||||
init_default_scope
|
||||
|
@ -8,7 +8,8 @@ from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
|
||||
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
|
||||
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
|
||||
from .utils import count_registered_modules, traverse_registry_tree
|
||||
from .utils import (count_registered_modules, init_default_scope,
|
||||
traverse_registry_tree)
|
||||
|
||||
__all__ = [
|
||||
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
|
||||
@ -18,5 +19,5 @@ __all__ = [
|
||||
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
|
||||
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
|
||||
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
|
||||
'build_scheduler_from_cfg'
|
||||
'build_scheduler_from_cfg', 'init_default_scope'
|
||||
]
|
||||
|
@ -32,6 +32,9 @@ class Registry:
|
||||
for children registry. If not specified, scope will be the name of
|
||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||
Defaults to None.
|
||||
locations (list): The locations to import the modules registered
|
||||
in this registry. Defaults to [].
|
||||
New in version 0.4.0.
|
||||
|
||||
Examples:
|
||||
>>> # define a registry
|
||||
@ -54,6 +57,16 @@ class Registry:
|
||||
>>> pass
|
||||
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
|
||||
|
||||
>>> # add locations to enable auto import
|
||||
>>> DETECTORS = Registry('detectors', parent=MODELS,
|
||||
>>> scope='det', locations=['det.models.detectors'])
|
||||
>>> # define this class in 'det.models.detectors'
|
||||
>>> @DETECTORS.register_module()
|
||||
>>> class MaskRCNN:
|
||||
>>> pass
|
||||
>>> # The registry will auto import det.models.detectors.MaskRCNN
|
||||
>>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
|
||||
|
||||
More advanced usages can be found at
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
"""
|
||||
@ -62,11 +75,14 @@ class Registry:
|
||||
name: str,
|
||||
build_func: Optional[Callable] = None,
|
||||
parent: Optional['Registry'] = None,
|
||||
scope: Optional[str] = None):
|
||||
scope: Optional[str] = None,
|
||||
locations: List = []):
|
||||
from .build_functions import build_from_cfg
|
||||
self._name = name
|
||||
self._module_dict: Dict[str, Type] = dict()
|
||||
self._children: Dict[str, 'Registry'] = dict()
|
||||
self._locations = locations
|
||||
self._imported = False
|
||||
|
||||
if scope is not None:
|
||||
assert isinstance(scope, str)
|
||||
@ -240,27 +256,25 @@ class Registry:
|
||||
# Get registry by scope
|
||||
if default_scope is not None:
|
||||
scope_name = default_scope.scope_name
|
||||
if scope_name in PKG2PROJECT:
|
||||
try:
|
||||
module = import_module(f'{scope_name}.utils')
|
||||
module.register_all_modules(False) # type: ignore
|
||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||
if scope in PKG2PROJECT:
|
||||
print_log(
|
||||
f'{scope} is not installed and its '
|
||||
'modules will not be registered. If you '
|
||||
'want to use modules defined in '
|
||||
f'{scope}, Please install {scope} by '
|
||||
f'`pip install {PKG2PROJECT[scope]}.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
print_log(
|
||||
f'Failed to import {scope} and register '
|
||||
'its modules, please make sure you '
|
||||
'have registered the module manually.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
try:
|
||||
import_module(f'{scope_name}.registry')
|
||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||
if scope in PKG2PROJECT:
|
||||
print_log(
|
||||
f'{scope} is not installed and its '
|
||||
'modules will not be registered. If you '
|
||||
'want to use modules defined in '
|
||||
f'{scope}, Please install {scope} by '
|
||||
f'`pip install {PKG2PROJECT[scope]}.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
print_log(
|
||||
f'Failed to import `{scope}.registry` '
|
||||
f'make sure the registry.py exists in `{scope}` '
|
||||
'package.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
root = self._get_root_registry()
|
||||
registry = root._search_child(scope_name)
|
||||
if registry is None:
|
||||
@ -290,6 +304,59 @@ class Registry:
|
||||
root = root.parent
|
||||
return root
|
||||
|
||||
def import_from_location(self) -> None:
|
||||
"""import modules from the pre-defined locations in self._location."""
|
||||
if not self._imported:
|
||||
# Avoid circular import
|
||||
from ..logging import print_log
|
||||
|
||||
# avoid BC breaking
|
||||
if len(self._locations) == 0 and self.scope in PKG2PROJECT:
|
||||
print_log(
|
||||
f'The "{self.name}" registry in {self.scope} did not '
|
||||
'set import location. Fallback to call '
|
||||
f'`{self.scope}.utils.register_all_modules` '
|
||||
'instead.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
try:
|
||||
module = import_module(f'{self.scope}.utils')
|
||||
module.register_all_modules(False) # type: ignore
|
||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||
if self.scope in PKG2PROJECT:
|
||||
print_log(
|
||||
f'{self.scope} is not installed and its '
|
||||
'modules will not be registered. If you '
|
||||
'want to use modules defined in '
|
||||
f'{self.scope}, Please install {self.scope} by '
|
||||
f'`pip install {PKG2PROJECT[self.scope]}.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
print_log(
|
||||
f'Failed to import {self.scope} and register '
|
||||
'its modules, please make sure you '
|
||||
'have registered the module manually.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
|
||||
for loc in self._locations:
|
||||
try:
|
||||
import_module(loc)
|
||||
print_log(
|
||||
f"Modules of {self.scope}'s {self.name} registry have "
|
||||
f'been automatically imported from {loc}',
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||
print_log(
|
||||
f'Failed to import {loc}, please check the '
|
||||
f'location of the registry {self.name} is '
|
||||
'correct.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self._imported = True
|
||||
|
||||
def get(self, key: str) -> Optional[Type]:
|
||||
"""Get the registry record.
|
||||
|
||||
@ -346,11 +413,14 @@ class Registry:
|
||||
obj_cls = None
|
||||
registry_name = self.name
|
||||
scope_name = self.scope
|
||||
|
||||
# lazy import the modules to register them into the registry
|
||||
self.import_from_location()
|
||||
|
||||
if scope is None or scope == self._scope:
|
||||
# get from self
|
||||
if real_key in self._module_dict:
|
||||
obj_cls = self._module_dict[real_key]
|
||||
|
||||
elif scope is None:
|
||||
# try to get the target from its parent or ancestors
|
||||
parent = self.parent
|
||||
@ -362,24 +432,21 @@ class Registry:
|
||||
break
|
||||
parent = parent.parent
|
||||
else:
|
||||
# import the registry to add the nodes into the registry tree
|
||||
try:
|
||||
module = import_module(f'{scope}.utils')
|
||||
module.register_all_modules(False) # type: ignore
|
||||
import_module(f'{scope}.registry')
|
||||
print_log(
|
||||
f'Registry node of {scope} has been automatically '
|
||||
'imported.',
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||
if scope in PKG2PROJECT:
|
||||
print_log(
|
||||
f'{scope} is not installed and its modules '
|
||||
'will not be registered. If you want to use '
|
||||
f'modules defined in {scope}, Please install '
|
||||
f'{scope} by `pip install {PKG2PROJECT[scope]} ',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
print_log(
|
||||
f'Failed to import "{scope}", and register its '
|
||||
f'modules. Please register {real_key} manually.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
print_log(
|
||||
f'Cannot auto import {scope}.registry, please check '
|
||||
f'whether the package "{scope}" is installed correctly '
|
||||
'or import the registry manually.',
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
# get from self._children
|
||||
if scope in self._children:
|
||||
obj_cls = self._children[scope].get(real_key)
|
||||
|
@ -1,11 +1,13 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.logging import print_log
|
||||
from . import root
|
||||
from .default_scope import DefaultScope
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None,
|
||||
dump(scan_data, json_path, indent=2)
|
||||
print_log(f'Result has been saved to {json_path}', logger='current')
|
||||
return scan_data
|
||||
|
||||
|
||||
def init_default_scope(scope: str) -> None:
|
||||
"""Initialize the given default scope.
|
||||
|
||||
Args:
|
||||
scope (str): The name of the default scope.
|
||||
"""
|
||||
never_created = DefaultScope.get_current_instance(
|
||||
) is None or not DefaultScope.check_instance_created(scope)
|
||||
if never_created:
|
||||
DefaultScope.get_instance(scope, scope_name=scope)
|
||||
return
|
||||
current_scope = DefaultScope.get_current_instance() # type: ignore
|
||||
if current_scope.scope_name != scope: # type: ignore
|
||||
warnings.warn('The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||
'`init_default_scope` will force set the current'
|
||||
f'default scope to "{scope}".')
|
||||
# avoid name conflict
|
||||
new_instance_name = f'{scope}-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(new_instance_name, scope_name=scope)
|
||||
|
@ -34,8 +34,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
|
||||
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
|
||||
VISUALIZERS, DefaultScope,
|
||||
count_registered_modules)
|
||||
VISUALIZERS, DefaultScope)
|
||||
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
||||
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
||||
set_multi_processing)
|
||||
@ -372,11 +371,6 @@ class Runner:
|
||||
# Collect and log environment information.
|
||||
self._log_env(env_cfg)
|
||||
|
||||
# collect information of all modules registered in the registries
|
||||
registries_info = count_registered_modules(
|
||||
self.work_dir if self.rank == 0 else None, verbose=False)
|
||||
self.logger.debug(registries_info)
|
||||
|
||||
# Build `message_hub` for communication among components.
|
||||
# `message_hub` can store log scalars (loss, learning rate) and
|
||||
# runtime information (iter and epoch). Those components that do not
|
||||
|
@ -1,10 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import os.path as osp
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import TestCase, skipIf
|
||||
|
||||
from mmengine.registry import (Registry, count_registered_modules, root,
|
||||
traverse_registry_tree)
|
||||
from mmengine.registry import (DefaultScope, Registry,
|
||||
count_registered_modules, init_default_scope,
|
||||
root, traverse_registry_tree)
|
||||
from mmengine.utils import is_installed
|
||||
|
||||
|
||||
@ -62,3 +64,17 @@ class TestUtils(TestCase):
|
||||
self.assertFalse(
|
||||
osp.exists(
|
||||
osp.join(temp_dir.name, 'modules_statistic_results.json')))
|
||||
|
||||
@skipIf(not is_installed('torch'), 'tests requires torch')
|
||||
def test_init_default_scope(self):
|
||||
# init default scope
|
||||
init_default_scope('mmdet')
|
||||
self.assertEqual(DefaultScope.get_current_instance().scope_name,
|
||||
'mmdet')
|
||||
|
||||
# init default scope when another scope is init
|
||||
name = f'test-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(name, scope_name='test')
|
||||
with self.assertWarnsRegex(
|
||||
Warning, 'The current default scope "test" is not "mmdet"'):
|
||||
init_default_scope('mmdet')
|
||||
|
Loading…
x
Reference in New Issue
Block a user