mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Registry supports import modules automatically. (#643)
* [Feature] Support registry auto import modules. * update * rebase and fix ut * add docstring * remove count_registered_modules * update docstring * resolve comments * resolve comments * rename ut * fix warning * avoid BC breaking * update doc * update doc * resolve comments
This commit is contained in:
parent
60492f4df7
commit
e83ac944b6
@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code
|
|||||||
|
|
||||||
Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.
|
Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.
|
||||||
|
|
||||||
Let's create a regitry first.
|
Let's create a registry first.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine import Registry
|
from mmengine import Registry
|
||||||
# scope represents the domain of the registry. If not set, the default value is the package name.
|
# `scope` represents the domain of the registry. If not set, the default value is the package name.
|
||||||
# e.g. in mmdetection, the scope is mmdet
|
# e.g. in mmdetection, the scope is mmdet
|
||||||
ACTIVATION = Registry('activation', scope='mmengine')
|
# `locations` indicates the location where the modules in this registry are defined.
|
||||||
|
# The Registry will automatically import the modules when building them according to these predefined locations.
|
||||||
|
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
|
||||||
```
|
```
|
||||||
|
|
||||||
Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`.
|
The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -75,7 +77,14 @@ print(ACTIVATION.module_dict)
|
|||||||
```
|
```
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
|
The key to trigger the registry mechanism is to make the module imported.
|
||||||
|
There are three ways to register a module into the registry
|
||||||
|
|
||||||
|
1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``.
|
||||||
|
|
||||||
|
2. Import the file manually. This is common when developers implement a new module in/out side the algorithm library.
|
||||||
|
|
||||||
|
3. Use ``custom_imports`` field in config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
|
||||||
```
|
```
|
||||||
|
|
||||||
Once the implemented module is successfully registered, we can use the activation module in the configuration file.
|
Once the implemented module is successfully registered, we can use the activation module in the configuration file.
|
||||||
@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs):
|
|||||||
Pass the `buid_activation` to `build_func`.
|
Pass the `buid_activation` to `build_func`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
|
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
|
||||||
|
|
||||||
@ACTIVATION.register_module()
|
@ACTIVATION.register_module()
|
||||||
class Tanh(nn.Module):
|
class Tanh(nn.Module):
|
||||||
@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a
|
|||||||
```python
|
```python
|
||||||
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||||
|
|
||||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
|
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
|
||||||
```
|
```
|
||||||
|
|
||||||
The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.
|
The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.
|
||||||
|
@ -24,3 +24,4 @@ mmengine.registry
|
|||||||
build_scheduler_from_cfg
|
build_scheduler_from_cfg
|
||||||
count_registered_modules
|
count_registered_modules
|
||||||
traverse_registry_tree
|
traverse_registry_tree
|
||||||
|
init_default_scope
|
||||||
|
@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映
|
|||||||
```python
|
```python
|
||||||
from mmengine import Registry
|
from mmengine import Registry
|
||||||
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
|
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
|
||||||
ACTIVATION = Registry('activation', scope='mmengine')
|
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
|
||||||
|
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
|
||||||
```
|
```
|
||||||
|
|
||||||
然后我们可以实现不同的激活模块,例如 `Sigmoid`,`ReLU` 和 `Softmax`。
|
`locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid`,`ReLU` 和 `Softmax`。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -74,7 +75,13 @@ print(ACTIVATION.module_dict)
|
|||||||
```
|
```
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。
|
只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
|
||||||
|
|
||||||
|
1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。
|
||||||
|
|
||||||
|
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
|
||||||
|
|
||||||
|
3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。
|
||||||
```
|
```
|
||||||
|
|
||||||
模块成功注册后,我们可以通过配置文件使用这个激活模块。
|
模块成功注册后,我们可以通过配置文件使用这个激活模块。
|
||||||
@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs):
|
|||||||
并将 `build_activation` 传递给 `build_func` 参数
|
并将 `build_activation` 传递给 `build_func` 参数
|
||||||
|
|
||||||
```python
|
```python
|
||||||
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
|
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])
|
||||||
|
|
||||||
@ACTIVATION.register_module()
|
@ACTIVATION.register_module()
|
||||||
class Tanh(nn.Module):
|
class Tanh(nn.Module):
|
||||||
@ -206,7 +213,7 @@ class RReLU(nn.Module):
|
|||||||
```python
|
```python
|
||||||
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
from mmengine import Registry, MODELS as MMENGINE_MODELS
|
||||||
|
|
||||||
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
|
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
|
||||||
```
|
```
|
||||||
|
|
||||||
下图是 `MMEngine` 和 `MMAlpha` 的注册器层级结构。
|
下图是 `MMEngine` 和 `MMAlpha` 的注册器层级结构。
|
||||||
|
@ -24,3 +24,4 @@ mmengine.registry
|
|||||||
build_scheduler_from_cfg
|
build_scheduler_from_cfg
|
||||||
count_registered_modules
|
count_registered_modules
|
||||||
traverse_registry_tree
|
traverse_registry_tree
|
||||||
|
init_default_scope
|
||||||
|
@ -8,7 +8,8 @@ from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
|||||||
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
|
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
|
||||||
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
|
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
|
||||||
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
|
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
|
||||||
from .utils import count_registered_modules, traverse_registry_tree
|
from .utils import (count_registered_modules, init_default_scope,
|
||||||
|
traverse_registry_tree)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
|
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
|
||||||
@ -18,5 +19,5 @@ __all__ = [
|
|||||||
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
|
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
|
||||||
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
|
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
|
||||||
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
|
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
|
||||||
'build_scheduler_from_cfg'
|
'build_scheduler_from_cfg', 'init_default_scope'
|
||||||
]
|
]
|
||||||
|
@ -32,6 +32,9 @@ class Registry:
|
|||||||
for children registry. If not specified, scope will be the name of
|
for children registry. If not specified, scope will be the name of
|
||||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
locations (list): The locations to import the modules registered
|
||||||
|
in this registry. Defaults to [].
|
||||||
|
New in version 0.4.0.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # define a registry
|
>>> # define a registry
|
||||||
@ -54,6 +57,16 @@ class Registry:
|
|||||||
>>> pass
|
>>> pass
|
||||||
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
|
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
|
||||||
|
|
||||||
|
>>> # add locations to enable auto import
|
||||||
|
>>> DETECTORS = Registry('detectors', parent=MODELS,
|
||||||
|
>>> scope='det', locations=['det.models.detectors'])
|
||||||
|
>>> # define this class in 'det.models.detectors'
|
||||||
|
>>> @DETECTORS.register_module()
|
||||||
|
>>> class MaskRCNN:
|
||||||
|
>>> pass
|
||||||
|
>>> # The registry will auto import det.models.detectors.MaskRCNN
|
||||||
|
>>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
|
||||||
|
|
||||||
More advanced usages can be found at
|
More advanced usages can be found at
|
||||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||||
"""
|
"""
|
||||||
@ -62,11 +75,14 @@ class Registry:
|
|||||||
name: str,
|
name: str,
|
||||||
build_func: Optional[Callable] = None,
|
build_func: Optional[Callable] = None,
|
||||||
parent: Optional['Registry'] = None,
|
parent: Optional['Registry'] = None,
|
||||||
scope: Optional[str] = None):
|
scope: Optional[str] = None,
|
||||||
|
locations: List = []):
|
||||||
from .build_functions import build_from_cfg
|
from .build_functions import build_from_cfg
|
||||||
self._name = name
|
self._name = name
|
||||||
self._module_dict: Dict[str, Type] = dict()
|
self._module_dict: Dict[str, Type] = dict()
|
||||||
self._children: Dict[str, 'Registry'] = dict()
|
self._children: Dict[str, 'Registry'] = dict()
|
||||||
|
self._locations = locations
|
||||||
|
self._imported = False
|
||||||
|
|
||||||
if scope is not None:
|
if scope is not None:
|
||||||
assert isinstance(scope, str)
|
assert isinstance(scope, str)
|
||||||
@ -240,10 +256,8 @@ class Registry:
|
|||||||
# Get registry by scope
|
# Get registry by scope
|
||||||
if default_scope is not None:
|
if default_scope is not None:
|
||||||
scope_name = default_scope.scope_name
|
scope_name = default_scope.scope_name
|
||||||
if scope_name in PKG2PROJECT:
|
|
||||||
try:
|
try:
|
||||||
module = import_module(f'{scope_name}.utils')
|
import_module(f'{scope_name}.registry')
|
||||||
module.register_all_modules(False) # type: ignore
|
|
||||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||||
if scope in PKG2PROJECT:
|
if scope in PKG2PROJECT:
|
||||||
print_log(
|
print_log(
|
||||||
@ -256,9 +270,9 @@ class Registry:
|
|||||||
level=logging.WARNING)
|
level=logging.WARNING)
|
||||||
else:
|
else:
|
||||||
print_log(
|
print_log(
|
||||||
f'Failed to import {scope} and register '
|
f'Failed to import `{scope}.registry` '
|
||||||
'its modules, please make sure you '
|
f'make sure the registry.py exists in `{scope}` '
|
||||||
'have registered the module manually.',
|
'package.',
|
||||||
logger='current',
|
logger='current',
|
||||||
level=logging.WARNING)
|
level=logging.WARNING)
|
||||||
root = self._get_root_registry()
|
root = self._get_root_registry()
|
||||||
@ -290,6 +304,59 @@ class Registry:
|
|||||||
root = root.parent
|
root = root.parent
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
def import_from_location(self) -> None:
|
||||||
|
"""import modules from the pre-defined locations in self._location."""
|
||||||
|
if not self._imported:
|
||||||
|
# Avoid circular import
|
||||||
|
from ..logging import print_log
|
||||||
|
|
||||||
|
# avoid BC breaking
|
||||||
|
if len(self._locations) == 0 and self.scope in PKG2PROJECT:
|
||||||
|
print_log(
|
||||||
|
f'The "{self.name}" registry in {self.scope} did not '
|
||||||
|
'set import location. Fallback to call '
|
||||||
|
f'`{self.scope}.utils.register_all_modules` '
|
||||||
|
'instead.',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
|
try:
|
||||||
|
module = import_module(f'{self.scope}.utils')
|
||||||
|
module.register_all_modules(False) # type: ignore
|
||||||
|
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||||
|
if self.scope in PKG2PROJECT:
|
||||||
|
print_log(
|
||||||
|
f'{self.scope} is not installed and its '
|
||||||
|
'modules will not be registered. If you '
|
||||||
|
'want to use modules defined in '
|
||||||
|
f'{self.scope}, Please install {self.scope} by '
|
||||||
|
f'`pip install {PKG2PROJECT[self.scope]}.',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
|
else:
|
||||||
|
print_log(
|
||||||
|
f'Failed to import {self.scope} and register '
|
||||||
|
'its modules, please make sure you '
|
||||||
|
'have registered the module manually.',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
|
|
||||||
|
for loc in self._locations:
|
||||||
|
try:
|
||||||
|
import_module(loc)
|
||||||
|
print_log(
|
||||||
|
f"Modules of {self.scope}'s {self.name} registry have "
|
||||||
|
f'been automatically imported from {loc}',
|
||||||
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
|
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||||
|
print_log(
|
||||||
|
f'Failed to import {loc}, please check the '
|
||||||
|
f'location of the registry {self.name} is '
|
||||||
|
'correct.',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
|
self._imported = True
|
||||||
|
|
||||||
def get(self, key: str) -> Optional[Type]:
|
def get(self, key: str) -> Optional[Type]:
|
||||||
"""Get the registry record.
|
"""Get the registry record.
|
||||||
|
|
||||||
@ -346,11 +413,14 @@ class Registry:
|
|||||||
obj_cls = None
|
obj_cls = None
|
||||||
registry_name = self.name
|
registry_name = self.name
|
||||||
scope_name = self.scope
|
scope_name = self.scope
|
||||||
|
|
||||||
|
# lazy import the modules to register them into the registry
|
||||||
|
self.import_from_location()
|
||||||
|
|
||||||
if scope is None or scope == self._scope:
|
if scope is None or scope == self._scope:
|
||||||
# get from self
|
# get from self
|
||||||
if real_key in self._module_dict:
|
if real_key in self._module_dict:
|
||||||
obj_cls = self._module_dict[real_key]
|
obj_cls = self._module_dict[real_key]
|
||||||
|
|
||||||
elif scope is None:
|
elif scope is None:
|
||||||
# try to get the target from its parent or ancestors
|
# try to get the target from its parent or ancestors
|
||||||
parent = self.parent
|
parent = self.parent
|
||||||
@ -362,24 +432,21 @@ class Registry:
|
|||||||
break
|
break
|
||||||
parent = parent.parent
|
parent = parent.parent
|
||||||
else:
|
else:
|
||||||
|
# import the registry to add the nodes into the registry tree
|
||||||
try:
|
try:
|
||||||
module = import_module(f'{scope}.utils')
|
import_module(f'{scope}.registry')
|
||||||
module.register_all_modules(False) # type: ignore
|
print_log(
|
||||||
|
f'Registry node of {scope} has been automatically '
|
||||||
|
'imported.',
|
||||||
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
except (ImportError, AttributeError, ModuleNotFoundError):
|
except (ImportError, AttributeError, ModuleNotFoundError):
|
||||||
if scope in PKG2PROJECT:
|
|
||||||
print_log(
|
print_log(
|
||||||
f'{scope} is not installed and its modules '
|
f'Cannot auto import {scope}.registry, please check '
|
||||||
'will not be registered. If you want to use '
|
f'whether the package "{scope}" is installed correctly '
|
||||||
f'modules defined in {scope}, Please install '
|
'or import the registry manually.',
|
||||||
f'{scope} by `pip install {PKG2PROJECT[scope]} ',
|
|
||||||
logger='current',
|
logger='current',
|
||||||
level=logging.WARNING)
|
level=logging.DEBUG)
|
||||||
else:
|
|
||||||
print_log(
|
|
||||||
f'Failed to import "{scope}", and register its '
|
|
||||||
f'modules. Please register {real_key} manually.',
|
|
||||||
logger='current',
|
|
||||||
level=logging.WARNING)
|
|
||||||
# get from self._children
|
# get from self._children
|
||||||
if scope in self._children:
|
if scope in self._children:
|
||||||
obj_cls = self._children[scope].get(real_key)
|
obj_cls = self._children[scope].get(real_key)
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import datetime
|
import datetime
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mmengine.fileio import dump
|
from mmengine.fileio import dump
|
||||||
from mmengine.logging import print_log
|
from mmengine.logging import print_log
|
||||||
from . import root
|
from . import root
|
||||||
|
from .default_scope import DefaultScope
|
||||||
from .registry import Registry
|
from .registry import Registry
|
||||||
|
|
||||||
|
|
||||||
@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None,
|
|||||||
dump(scan_data, json_path, indent=2)
|
dump(scan_data, json_path, indent=2)
|
||||||
print_log(f'Result has been saved to {json_path}', logger='current')
|
print_log(f'Result has been saved to {json_path}', logger='current')
|
||||||
return scan_data
|
return scan_data
|
||||||
|
|
||||||
|
|
||||||
|
def init_default_scope(scope: str) -> None:
|
||||||
|
"""Initialize the given default scope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): The name of the default scope.
|
||||||
|
"""
|
||||||
|
never_created = DefaultScope.get_current_instance(
|
||||||
|
) is None or not DefaultScope.check_instance_created(scope)
|
||||||
|
if never_created:
|
||||||
|
DefaultScope.get_instance(scope, scope_name=scope)
|
||||||
|
return
|
||||||
|
current_scope = DefaultScope.get_current_instance() # type: ignore
|
||||||
|
if current_scope.scope_name != scope: # type: ignore
|
||||||
|
warnings.warn('The current default scope ' # type: ignore
|
||||||
|
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||||
|
'`init_default_scope` will force set the current'
|
||||||
|
f'default scope to "{scope}".')
|
||||||
|
# avoid name conflict
|
||||||
|
new_instance_name = f'{scope}-{datetime.datetime.now()}'
|
||||||
|
DefaultScope.get_instance(new_instance_name, scope_name=scope)
|
||||||
|
@ -34,8 +34,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
|||||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||||
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
|
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
|
||||||
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
|
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
|
||||||
VISUALIZERS, DefaultScope,
|
VISUALIZERS, DefaultScope)
|
||||||
count_registered_modules)
|
|
||||||
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
||||||
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
||||||
set_multi_processing)
|
set_multi_processing)
|
||||||
@ -372,11 +371,6 @@ class Runner:
|
|||||||
# Collect and log environment information.
|
# Collect and log environment information.
|
||||||
self._log_env(env_cfg)
|
self._log_env(env_cfg)
|
||||||
|
|
||||||
# collect information of all modules registered in the registries
|
|
||||||
registries_info = count_registered_modules(
|
|
||||||
self.work_dir if self.rank == 0 else None, verbose=False)
|
|
||||||
self.logger.debug(registries_info)
|
|
||||||
|
|
||||||
# Build `message_hub` for communication among components.
|
# Build `message_hub` for communication among components.
|
||||||
# `message_hub` can store log scalars (loss, learning rate) and
|
# `message_hub` can store log scalars (loss, learning rate) and
|
||||||
# runtime information (iter and epoch). Those components that do not
|
# runtime information (iter and epoch). Those components that do not
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import datetime
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from unittest import TestCase, skipIf
|
from unittest import TestCase, skipIf
|
||||||
|
|
||||||
from mmengine.registry import (Registry, count_registered_modules, root,
|
from mmengine.registry import (DefaultScope, Registry,
|
||||||
traverse_registry_tree)
|
count_registered_modules, init_default_scope,
|
||||||
|
root, traverse_registry_tree)
|
||||||
from mmengine.utils import is_installed
|
from mmengine.utils import is_installed
|
||||||
|
|
||||||
|
|
||||||
@ -62,3 +64,17 @@ class TestUtils(TestCase):
|
|||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
osp.exists(
|
osp.exists(
|
||||||
osp.join(temp_dir.name, 'modules_statistic_results.json')))
|
osp.join(temp_dir.name, 'modules_statistic_results.json')))
|
||||||
|
|
||||||
|
@skipIf(not is_installed('torch'), 'tests requires torch')
|
||||||
|
def test_init_default_scope(self):
|
||||||
|
# init default scope
|
||||||
|
init_default_scope('mmdet')
|
||||||
|
self.assertEqual(DefaultScope.get_current_instance().scope_name,
|
||||||
|
'mmdet')
|
||||||
|
|
||||||
|
# init default scope when another scope is init
|
||||||
|
name = f'test-{datetime.datetime.now()}'
|
||||||
|
DefaultScope.get_instance(name, scope_name='test')
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
Warning, 'The current default scope "test" is not "mmdet"'):
|
||||||
|
init_default_scope('mmdet')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user