mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add Registry (#11)
This commit is contained in:
parent
166ca02363
commit
cccd20a636
4
docs/en/api.rst
Normal file
4
docs/en/api.rst
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
Registry
|
||||||
|
--------
|
||||||
|
.. automodule:: mmengine.registry
|
||||||
|
:members:
|
@ -8,6 +8,12 @@ You can switch between Chinese and English documents in the lower-left corner of
|
|||||||
|
|
||||||
tutorials/registry.md
|
tutorials/registry.md
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: API Reference
|
||||||
|
|
||||||
|
api.rst
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Switch Language
|
:caption: Switch Language
|
||||||
|
|
||||||
|
4
docs/zh_cn/api.rst
Normal file
4
docs/zh_cn/api.rst
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
Registry
|
||||||
|
--------
|
||||||
|
.. automodule:: mmengine.registry
|
||||||
|
:members:
|
@ -9,6 +9,12 @@
|
|||||||
tutorials/registry.md
|
tutorials/registry.md
|
||||||
tutorials/config.md
|
tutorials/config.md
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
:caption: API 文档
|
||||||
|
|
||||||
|
api.rst
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: 语言切换
|
:caption: 语言切换
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
from .fileio import *
|
from .fileio import *
|
||||||
|
from .registry import *
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
11
mmengine/registry/__init__.py
Normal file
11
mmengine/registry/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .registry import Registry, build_from_cfg
|
||||||
|
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
|
||||||
|
OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, RUNNER_CONSTRUCTORS,
|
||||||
|
RUNNERS, TASK_UTILS, TRANSFORMS, WEIGHT_INITIALIZERS)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
|
||||||
|
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
|
||||||
|
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS'
|
||||||
|
]
|
491
mmengine/registry/registry.py
Normal file
491
mmengine/registry/registry.py
Normal file
@ -0,0 +1,491 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from ..config import Config, ConfigDict
|
||||||
|
from ..utils import is_seq_of
|
||||||
|
|
||||||
|
|
||||||
|
def build_from_cfg(
|
||||||
|
cfg: Union[dict, ConfigDict, Config],
|
||||||
|
registry: 'Registry',
|
||||||
|
default_args: Optional[Union[dict, ConfigDict,
|
||||||
|
Config]] = None) -> object:
|
||||||
|
"""Build a module from config dict.
|
||||||
|
|
||||||
|
At least one of the ``cfg`` and ``default_args`` contains the key "type"
|
||||||
|
which type should be either str or class. If they all contain it, the key
|
||||||
|
in ``cfg`` will be used because ``cfg`` has a high priority than
|
||||||
|
``default_args`` that means if a key exist in both of them, the value of
|
||||||
|
the key will be ``cfg[key]``. They will be merged first and the key "type"
|
||||||
|
will be popped up and the remaining keys will be used as initialization
|
||||||
|
arguments.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mmengine import Registry, build_from_cfg
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> def __init__(self, depth, stages=4):
|
||||||
|
>>> self.depth = depth
|
||||||
|
>>> self.stages = stages
|
||||||
|
>>> cfg = dict(type='ResNet', depth=50)
|
||||||
|
>>> model = build_from_cfg(cfg, MODELS)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict or ConfigDict or Config): Config dict. It should at least
|
||||||
|
contain the key "type".
|
||||||
|
registry (:obj:`Registry`): The registry to search the type from.
|
||||||
|
default_args (dict or ConfigDict or Config, optional): Default
|
||||||
|
initialization arguments. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The constructed object.
|
||||||
|
"""
|
||||||
|
if not isinstance(cfg, (dict, ConfigDict, Config)):
|
||||||
|
raise TypeError(
|
||||||
|
f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')
|
||||||
|
|
||||||
|
if 'type' not in cfg:
|
||||||
|
if default_args is None or 'type' not in default_args:
|
||||||
|
raise KeyError(
|
||||||
|
'`cfg` or `default_args` must contain the key "type", '
|
||||||
|
f'but got {cfg}\n{default_args}')
|
||||||
|
|
||||||
|
if not isinstance(registry, Registry):
|
||||||
|
raise TypeError('registry must be a mmengine.Registry object, '
|
||||||
|
f'but got {type(registry)}')
|
||||||
|
|
||||||
|
if not (isinstance(default_args,
|
||||||
|
(dict, ConfigDict, Config)) or default_args is None):
|
||||||
|
raise TypeError(
|
||||||
|
'default_args should be a dict, ConfigDict, Config or None, '
|
||||||
|
f'but got {type(default_args)}')
|
||||||
|
|
||||||
|
args = cfg.copy()
|
||||||
|
|
||||||
|
if default_args is not None:
|
||||||
|
for name, value in default_args.items():
|
||||||
|
args.setdefault(name, value)
|
||||||
|
|
||||||
|
obj_type = args.pop('type')
|
||||||
|
if isinstance(obj_type, str):
|
||||||
|
obj_cls = registry.get(obj_type)
|
||||||
|
if obj_cls is None:
|
||||||
|
raise KeyError(
|
||||||
|
f'{obj_type} is not in the {registry.name} registry. '
|
||||||
|
f'Please check whether the value of `{obj_type}` is correct or'
|
||||||
|
' it was registered as expected. More details can be found at'
|
||||||
|
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||||
|
)
|
||||||
|
elif inspect.isclass(obj_type):
|
||||||
|
obj_cls = obj_type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return obj_cls(**args) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
# Normal TypeError does not print class name.
|
||||||
|
raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Registry:
|
||||||
|
"""A registry to map strings to classes.
|
||||||
|
|
||||||
|
Registered objects could be built from registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Registry name.
|
||||||
|
build_func (callable, optional): A function to construct instance
|
||||||
|
from Registry. :func:`build_from_cfg` is used if neither ``parent``
|
||||||
|
or ``build_func`` is specified. If ``parent`` is specified and
|
||||||
|
``build_func`` is not given, ``build_func`` will be inherited
|
||||||
|
from ``parent``. Defaults to None.
|
||||||
|
parent (:obj:`Registry`, optional): Parent registry. The class
|
||||||
|
registered in children registry could be built from parent.
|
||||||
|
Defaults to None.
|
||||||
|
scope (str, optional): The scope of registry. It is the key to search
|
||||||
|
for children registry. If not specified, scope will be the name of
|
||||||
|
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # define a registry
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> # registry the `ResNet` to `MODELS`
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> # build model from `MODELS`
|
||||||
|
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||||
|
|
||||||
|
>>> # hierarchical registry
|
||||||
|
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
|
||||||
|
>>> @DETECTORS.register_module()
|
||||||
|
>>> class FasterRCNN:
|
||||||
|
>>> pass
|
||||||
|
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
|
||||||
|
|
||||||
|
More advanced usages can be found at
|
||||||
|
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
name: str,
|
||||||
|
build_func: Optional[Callable] = None,
|
||||||
|
parent: Optional['Registry'] = None,
|
||||||
|
scope: Optional[str] = None):
|
||||||
|
self._name = name
|
||||||
|
self._module_dict: Dict[str, Type] = dict()
|
||||||
|
self._children: Dict[str, 'Registry'] = dict()
|
||||||
|
|
||||||
|
if scope is not None:
|
||||||
|
assert isinstance(scope, str)
|
||||||
|
self._scope = scope
|
||||||
|
else:
|
||||||
|
self._scope = self.infer_scope() if scope is None else scope
|
||||||
|
|
||||||
|
# See https://mypy.readthedocs.io/en/stable/common_issues.html#
|
||||||
|
# variables-vs-type-aliases for the use
|
||||||
|
self.parent: Optional['Registry']
|
||||||
|
if parent is not None:
|
||||||
|
assert isinstance(parent, Registry)
|
||||||
|
parent._add_child(self)
|
||||||
|
self.parent = parent
|
||||||
|
else:
|
||||||
|
self.parent = None
|
||||||
|
|
||||||
|
# self.build_func will be set with the following priority:
|
||||||
|
# 1. build_func
|
||||||
|
# 2. parent.build_func
|
||||||
|
# 3. build_from_cfg
|
||||||
|
self.build_func: Callable
|
||||||
|
if build_func is None:
|
||||||
|
if self.parent is not None:
|
||||||
|
self.build_func = self.parent.build_func
|
||||||
|
else:
|
||||||
|
self.build_func = build_from_cfg
|
||||||
|
else:
|
||||||
|
self.build_func = build_func
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._module_dict)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return self.get(key) is not None
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
format_str = self.__class__.__name__ + \
|
||||||
|
f'(name={self._name}, ' \
|
||||||
|
f'items={self._module_dict})'
|
||||||
|
return format_str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer_scope() -> str:
|
||||||
|
"""Infer the scope of registry.
|
||||||
|
|
||||||
|
The name of the package where registry is defined will be returned.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The inferred scope name.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # in mmdet/models/backbone/resnet.py
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> # The scope of ``ResNet`` will be ``mmdet``.
|
||||||
|
"""
|
||||||
|
# `sys._getframe` returns the frame object that many calls below the
|
||||||
|
# top of the stack. The call stack for `infer_scope` can be listed as
|
||||||
|
# follow:
|
||||||
|
# frame-0: `infer_scope` itself
|
||||||
|
# frame-1: `__init__` of `Registry` which calls the `infer_scope`
|
||||||
|
# frame-2: Where the `Registry(...)` is called
|
||||||
|
filename = inspect.getmodule(sys._getframe(2)).__name__ # type: ignore
|
||||||
|
split_filename = filename.split('.')
|
||||||
|
return split_filename[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_scope_key(key: str) -> Tuple[Optional[str], str]:
|
||||||
|
"""Split scope and key.
|
||||||
|
|
||||||
|
The first scope will be split from key.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
tuple[str | None, str]: The former element is the first scope of
|
||||||
|
the key, which can be ``None``. The latter is the remaining key.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> Registry.split_scope_key('mmdet.ResNet')
|
||||||
|
'mmdet', 'ResNet'
|
||||||
|
>>> Registry.split_scope_key('ResNet')
|
||||||
|
None, 'ResNet'
|
||||||
|
"""
|
||||||
|
split_index = key.find('.')
|
||||||
|
if split_index != -1:
|
||||||
|
return key[:split_index], key[split_index + 1:]
|
||||||
|
else:
|
||||||
|
return None, key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scope(self):
|
||||||
|
return self._scope
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module_dict(self):
|
||||||
|
return self._module_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def children(self):
|
||||||
|
return self._children
|
||||||
|
|
||||||
|
def _get_root_registry(self) -> 'Registry':
|
||||||
|
"""Return the root registry."""
|
||||||
|
root = self
|
||||||
|
while root.parent is not None:
|
||||||
|
root = root.parent
|
||||||
|
return root
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[Type]:
|
||||||
|
"""Get the registry record.
|
||||||
|
|
||||||
|
The method will first parse :attr:`key` and check whether it contains
|
||||||
|
a scope name. The logic to search for :attr:`key`:
|
||||||
|
|
||||||
|
- ``key`` does not contain a scope name, i.e., it is purely a module
|
||||||
|
name like "ResNet": :meth:`get` will search for ``ResNet`` from the
|
||||||
|
current registry to its parent or ancestors until finding it.
|
||||||
|
|
||||||
|
- ``key`` contains a scope name and it is equal to the scope of the
|
||||||
|
current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
|
||||||
|
will only search for ``ResNet`` in the current registry.
|
||||||
|
|
||||||
|
- ``key`` contains a scope name and it is not equal to the scope of
|
||||||
|
the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
|
||||||
|
scope exists in its children, :meth:`get`will get "FCNet" from
|
||||||
|
them. If not, :meth:`get` will first get the root registry and root
|
||||||
|
registry call its own :meth:`get` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (str): Name of the registered item, e.g., the class name in
|
||||||
|
string format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type or None: Return the corresponding class if ``key`` exists,
|
||||||
|
otherwise return None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # define a registry
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> # register `ResNet` to `MODELS`
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> resnet_cls = MODELS.get('ResNet')
|
||||||
|
|
||||||
|
>>> # hierarchical registry
|
||||||
|
>>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
|
||||||
|
>>> # `ResNet` does not exist in `DETECTORS` but `get` method
|
||||||
|
>>> # will try to search from its parenet or ancestors
|
||||||
|
>>> resnet_cls = DETECTORS.get('ResNet')
|
||||||
|
>>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
|
||||||
|
>>> @CLASSIFIER.register_module()
|
||||||
|
>>> class MobileNet:
|
||||||
|
>>> pass
|
||||||
|
>>> # `get` from its sibling registries
|
||||||
|
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
|
||||||
|
"""
|
||||||
|
scope, real_key = self.split_scope_key(key)
|
||||||
|
if scope is None or scope == self._scope:
|
||||||
|
# get from self
|
||||||
|
if real_key in self._module_dict:
|
||||||
|
return self._module_dict[real_key]
|
||||||
|
|
||||||
|
if scope is None:
|
||||||
|
# try to get the target from its parent or ancestors
|
||||||
|
parent = self.parent
|
||||||
|
while parent is not None:
|
||||||
|
if real_key in parent._module_dict:
|
||||||
|
return parent._module_dict[real_key]
|
||||||
|
parent = parent.parent
|
||||||
|
else:
|
||||||
|
# get from self._children
|
||||||
|
if scope in self._children:
|
||||||
|
return self._children[scope].get(real_key)
|
||||||
|
else:
|
||||||
|
root = self._get_root_registry()
|
||||||
|
return root.get(key)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _search_child(self, scope: str) -> Optional['Registry']:
|
||||||
|
"""Depth-first search for the corresponding registry in its children.
|
||||||
|
|
||||||
|
Note that the method only search for the corresponding registry from
|
||||||
|
the current registry. Therefore, if we want to search from the root
|
||||||
|
registry, :meth:`_get_root_registry` should be called to get the
|
||||||
|
root registry first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): The scope name used for searching for its
|
||||||
|
corresponding registry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Registry or None: Return the corresponding registry if ``scope``
|
||||||
|
exists, otherwise return None.
|
||||||
|
"""
|
||||||
|
if self._scope == scope:
|
||||||
|
return self
|
||||||
|
|
||||||
|
for child in self._children.values():
|
||||||
|
registry = child._search_child(scope)
|
||||||
|
if registry is not None:
|
||||||
|
return registry
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def build(self,
|
||||||
|
*args,
|
||||||
|
default_scope: Optional[str] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Build an instance.
|
||||||
|
|
||||||
|
Build an instance by calling :attr:`build_func`. If
|
||||||
|
:attr:`default_scope` is given, :meth:`build` will firstly get the
|
||||||
|
responding registry and then call its own :meth:`build`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_scope (str, optional): The ``default_scope`` is used to
|
||||||
|
reset the current registry. Defaults to None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mmengine import Registry
|
||||||
|
>>> MODELS = Registry('models')
|
||||||
|
>>> @MODELS.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> def __init__(self, depth, stages=4):
|
||||||
|
>>> self.depth = depth
|
||||||
|
>>> self.stages = stages
|
||||||
|
>>> cfg = dict(type='ResNet', depth=50)
|
||||||
|
>>> model = MODELS.build(cfg)
|
||||||
|
"""
|
||||||
|
if default_scope is not None:
|
||||||
|
root = self._get_root_registry()
|
||||||
|
registry = root._search_child(default_scope)
|
||||||
|
if registry is None:
|
||||||
|
raise KeyError(
|
||||||
|
f'{default_scope} does not exist in the registry tree.')
|
||||||
|
else:
|
||||||
|
registry = self
|
||||||
|
|
||||||
|
return registry.build_func(*args, **kwargs, registry=registry)
|
||||||
|
|
||||||
|
def _add_child(self, registry: 'Registry') -> None:
|
||||||
|
"""Add a child for a registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry (:obj:`Registry`): The ``registry`` will be added as a
|
||||||
|
child of the ``self``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert isinstance(registry, Registry)
|
||||||
|
assert registry.scope is not None
|
||||||
|
assert registry.scope not in self.children, \
|
||||||
|
f'scope {registry.scope} exists in {self.name} registry'
|
||||||
|
self.children[registry.scope] = registry
|
||||||
|
|
||||||
|
def _register_module(self,
|
||||||
|
module_class: Type,
|
||||||
|
module_name: Optional[Union[str, List[str]]] = None,
|
||||||
|
force: bool = False) -> None:
|
||||||
|
"""Register a module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_class (type): Module class to be registered.
|
||||||
|
module_name (str or list of str, optional): The module name to be
|
||||||
|
registered. If not specified, the class name will be used.
|
||||||
|
Defaults to None.
|
||||||
|
force (bool): Whether to override an existing class with the same
|
||||||
|
name. Defaults to False.
|
||||||
|
"""
|
||||||
|
if not inspect.isclass(module_class):
|
||||||
|
raise TypeError('module must be a class, '
|
||||||
|
f'but got {type(module_class)}')
|
||||||
|
|
||||||
|
if module_name is None:
|
||||||
|
module_name = module_class.__name__
|
||||||
|
if isinstance(module_name, str):
|
||||||
|
module_name = [module_name]
|
||||||
|
for name in module_name:
|
||||||
|
if not force and name in self._module_dict:
|
||||||
|
raise KeyError(f'{name} is already registered '
|
||||||
|
f'in {self.name}')
|
||||||
|
self._module_dict[name] = module_class
|
||||||
|
|
||||||
|
def register_module(
|
||||||
|
self,
|
||||||
|
name: Optional[Union[str, List[str]]] = None,
|
||||||
|
force: bool = False,
|
||||||
|
module: Optional[Type] = None) -> Union[type, Callable]:
|
||||||
|
"""Register a module.
|
||||||
|
|
||||||
|
A record will be added to ``self._module_dict``, whose key is the class
|
||||||
|
name or the specified name, and value is the class itself.
|
||||||
|
It can be used as a decorator or a normal function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str or list of str, optional): The module name to be
|
||||||
|
registered. If not specified, the class name will be used.
|
||||||
|
force (bool): Whether to override an existing class with the same
|
||||||
|
name. Default to False.
|
||||||
|
module (type, optional): Module class to be registered. Defaults to
|
||||||
|
None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> backbones = Registry('backbone')
|
||||||
|
>>> # as a decorator
|
||||||
|
>>> @backbones.register_module()
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> backbones = Registry('backbone')
|
||||||
|
>>> @backbones.register_module(name='mnet')
|
||||||
|
>>> class MobileNet:
|
||||||
|
>>> pass
|
||||||
|
|
||||||
|
>>> # as a normal function
|
||||||
|
>>> class ResNet:
|
||||||
|
>>> pass
|
||||||
|
>>> backbones.register_module(module=ResNet)
|
||||||
|
"""
|
||||||
|
if not isinstance(force, bool):
|
||||||
|
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||||
|
|
||||||
|
# raise the error ahead of time
|
||||||
|
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
|
||||||
|
raise TypeError(
|
||||||
|
'name must be None, an instance of str, or a sequence of str, '
|
||||||
|
f'but got {type(name)}')
|
||||||
|
|
||||||
|
# use it as a normal method: x.register_module(module=SomeClass)
|
||||||
|
if module is not None:
|
||||||
|
self._register_module(
|
||||||
|
module_class=module, module_name=name, force=force)
|
||||||
|
return module
|
||||||
|
|
||||||
|
# use it as a decorator: @x.register_module()
|
||||||
|
def _register(cls):
|
||||||
|
self._register_module(
|
||||||
|
module_class=cls, module_name=name, force=force)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return _register
|
34
mmengine/registry/root.py
Normal file
34
mmengine/registry/root.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
"""MMEngine provides 11 root registries to support using modules across
|
||||||
|
projects.
|
||||||
|
|
||||||
|
More datails can be found at
|
||||||
|
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .registry import Registry
|
||||||
|
|
||||||
|
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||||
|
RUNNERS = Registry('runner')
|
||||||
|
# manage runner constructors that define how to initialize runners
|
||||||
|
RUNNER_CONSTRUCTORS = Registry('runner constructor')
|
||||||
|
# manage all kinds of hooks like `CheckpointHook`
|
||||||
|
HOOKS = Registry('hook')
|
||||||
|
|
||||||
|
# manage data-related modules
|
||||||
|
DATASETS = Registry('dataset')
|
||||||
|
DATA_SAMPLERS = Registry('data sampler')
|
||||||
|
TRANSFORMS = Registry('transform')
|
||||||
|
|
||||||
|
# mangage all kinds of modules inheriting `nn.Module`
|
||||||
|
MODELS = Registry('model')
|
||||||
|
# mangage all kinds of weight initialization modules like `Uniform`
|
||||||
|
WEIGHT_INITIALIZERS = Registry('weight initializer')
|
||||||
|
|
||||||
|
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||||
|
OPTIMIZERS = Registry('optimizer')
|
||||||
|
# manage constructors that customize the optimization hyperparameters.
|
||||||
|
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
|
||||||
|
|
||||||
|
# manage task-specific modules like anchor generators and box coders
|
||||||
|
TASK_UTILS = Registry('task util')
|
@ -8,6 +8,7 @@ from collections import abc
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from inspect import getfullargspec
|
from inspect import getfullargspec
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
|
from typing import Sequence, Type
|
||||||
|
|
||||||
|
|
||||||
# From PyTorch internals
|
# From PyTorch internals
|
||||||
@ -125,16 +126,26 @@ def tuple_cast(inputs, dst_type):
|
|||||||
return iter_cast(inputs, dst_type, return_type=tuple)
|
return iter_cast(inputs, dst_type, return_type=tuple)
|
||||||
|
|
||||||
|
|
||||||
def is_seq_of(seq, expected_type, seq_type=None):
|
def is_seq_of(seq: Sequence,
|
||||||
|
expected_type: Type,
|
||||||
|
seq_type: Type = None) -> bool:
|
||||||
"""Check whether it is a sequence of some type.
|
"""Check whether it is a sequence of some type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seq (Sequence): The sequence to be checked.
|
seq (Sequence): The sequence to be checked.
|
||||||
expected_type (type): Expected type of sequence items.
|
expected_type (type): Expected type of sequence items.
|
||||||
seq_type (type, optional): Expected sequence type.
|
seq_type (type, optional): Expected sequence type. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the sequence is valid.
|
bool: Return True if ``seq`` is valid else False.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mmengine.utils import is_seq_of
|
||||||
|
>>> seq = ['a', 'b', 'c']
|
||||||
|
>>> is_seq_of(seq, str)
|
||||||
|
True
|
||||||
|
>>> is_seq_of(seq, int)
|
||||||
|
False
|
||||||
"""
|
"""
|
||||||
if seq_type is None:
|
if seq_type is None:
|
||||||
exp_seq_type = abc.Sequence
|
exp_seq_type = abc.Sequence
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmengine import Config, ConfigDict, Registry, build_from_cfg
|
from mmengine.config import Config, ConfigDict # type: ignore
|
||||||
|
from mmengine.registry import Registry, build_from_cfg
|
||||||
|
|
||||||
|
|
||||||
class TestRegistry:
|
class TestRegistry:
|
||||||
@ -91,8 +92,8 @@ class TestRegistry:
|
|||||||
# `name` is an invalid type
|
# `name` is an invalid type
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match=('name must be either of None, an instance of str or a '
|
match=('name must be None, an instance of str, or a sequence '
|
||||||
"sequence of str, but got <class 'int'>")):
|
"of str, but got <class 'int'>")):
|
||||||
|
|
||||||
@CATS.register_module(name=7474741)
|
@CATS.register_module(name=7474741)
|
||||||
class SiameseCat:
|
class SiameseCat:
|
||||||
@ -149,8 +150,8 @@ class TestRegistry:
|
|||||||
assert len(CATS) == 8
|
assert len(CATS) == 8
|
||||||
|
|
||||||
def _build_registry(self):
|
def _build_registry(self):
|
||||||
"""A helper function to build a hierarchy registry."""
|
"""A helper function to build a Hierarchical Registry."""
|
||||||
# Hierarchy Registry
|
# Hierarchical Registry
|
||||||
# DOGS
|
# DOGS
|
||||||
# _______|_______
|
# _______|_______
|
||||||
# | |
|
# | |
|
||||||
@ -177,7 +178,7 @@ class TestRegistry:
|
|||||||
return registries
|
return registries
|
||||||
|
|
||||||
def test_get_root_registry(self):
|
def test_get_root_registry(self):
|
||||||
# Hierarchy Registry
|
# Hierarchical Registry
|
||||||
# DOGS
|
# DOGS
|
||||||
# _______|_______
|
# _______|_______
|
||||||
# | |
|
# | |
|
||||||
@ -186,7 +187,8 @@ class TestRegistry:
|
|||||||
# | | |
|
# | | |
|
||||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||||
# (little_hound) (mid_hound) (little_samoyed)
|
# (little_hound) (mid_hound) (little_samoyed)
|
||||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
registries = self._build_registry()
|
||||||
|
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
||||||
|
|
||||||
assert DOGS._get_root_registry() is DOGS
|
assert DOGS._get_root_registry() is DOGS
|
||||||
assert HOUNDS._get_root_registry() is DOGS
|
assert HOUNDS._get_root_registry() is DOGS
|
||||||
@ -194,7 +196,7 @@ class TestRegistry:
|
|||||||
assert MID_HOUNDS._get_root_registry() is DOGS
|
assert MID_HOUNDS._get_root_registry() is DOGS
|
||||||
|
|
||||||
def test_get(self):
|
def test_get(self):
|
||||||
# Hierarchy Registry
|
# Hierarchical Registry
|
||||||
# DOGS
|
# DOGS
|
||||||
# _______|_______
|
# _______|_______
|
||||||
# | |
|
# | |
|
||||||
@ -277,7 +279,7 @@ class TestRegistry:
|
|||||||
) is LittlePedigreeSamoyed
|
) is LittlePedigreeSamoyed
|
||||||
|
|
||||||
def test_search_child(self):
|
def test_search_child(self):
|
||||||
# Hierarchy Registry
|
# Hierarchical Registry
|
||||||
# DOGS
|
# DOGS
|
||||||
# _______|_______
|
# _______|_______
|
||||||
# | |
|
# | |
|
||||||
@ -286,7 +288,8 @@ class TestRegistry:
|
|||||||
# | | |
|
# | | |
|
||||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||||
# (little_hound) (mid_hound) (little_samoyed)
|
# (little_hound) (mid_hound) (little_samoyed)
|
||||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
registries = self._build_registry()
|
||||||
|
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
||||||
|
|
||||||
assert DOGS._search_child('hound') is HOUNDS
|
assert DOGS._search_child('hound') is HOUNDS
|
||||||
assert DOGS._search_child('not a child') is None
|
assert DOGS._search_child('not a child') is None
|
||||||
@ -294,8 +297,9 @@ class TestRegistry:
|
|||||||
assert LITTLE_HOUNDS._search_child('hound') is None
|
assert LITTLE_HOUNDS._search_child('hound') is None
|
||||||
assert LITTLE_HOUNDS._search_child('mid_hound') is None
|
assert LITTLE_HOUNDS._search_child('mid_hound') is None
|
||||||
|
|
||||||
def test_build(self):
|
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||||
# Hierarchy Registry
|
def test_build(self, cfg_type):
|
||||||
|
# Hierarchical Registry
|
||||||
# DOGS
|
# DOGS
|
||||||
# _______|_______
|
# _______|_______
|
||||||
# | |
|
# | |
|
||||||
@ -311,14 +315,14 @@ class TestRegistry:
|
|||||||
class GoldenRetriever:
|
class GoldenRetriever:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
gr_cfg = dict(type='GoldenRetriever')
|
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
||||||
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
||||||
|
|
||||||
@HOUNDS.register_module()
|
@HOUNDS.register_module()
|
||||||
class BloodHound:
|
class BloodHound:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
bh_cfg = dict(type='BloodHound')
|
bh_cfg = cfg_type(dict(type='BloodHound'))
|
||||||
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
||||||
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
||||||
|
|
||||||
@ -326,14 +330,14 @@ class TestRegistry:
|
|||||||
class Dachshund:
|
class Dachshund:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
d_cfg = dict(type='Dachshund')
|
d_cfg = cfg_type(dict(type='Dachshund'))
|
||||||
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
|
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
|
||||||
|
|
||||||
@MID_HOUNDS.register_module()
|
@MID_HOUNDS.register_module()
|
||||||
class Beagle:
|
class Beagle:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
b_cfg = dict(type='Beagle')
|
b_cfg = cfg_type(dict(type='Beagle'))
|
||||||
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
|
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
|
||||||
|
|
||||||
# test `default_scope`
|
# test `default_scope`
|
||||||
@ -367,7 +371,8 @@ class TestRegistry:
|
|||||||
assert repr(CATS) == repr_str
|
assert repr(CATS) == repr_str
|
||||||
|
|
||||||
|
|
||||||
def test_build_from_cfg():
|
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||||
|
def test_build_from_cfg(cfg_type):
|
||||||
BACKBONES = Registry('backbone')
|
BACKBONES = Registry('backbone')
|
||||||
|
|
||||||
@BACKBONES.register_module()
|
@BACKBONES.register_module()
|
||||||
@ -387,24 +392,14 @@ def test_build_from_cfg():
|
|||||||
# test `cfg` parameter
|
# test `cfg` parameter
|
||||||
# `cfg` should be a dict, ConfigDict or Config object
|
# `cfg` should be a dict, ConfigDict or Config object
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError, match="cfg must be a dict, but got <class 'str'>"):
|
TypeError,
|
||||||
|
match=('cfg should be a dict, ConfigDict or Config, but got '
|
||||||
|
"<class 'str'>")):
|
||||||
cfg = 'ResNet'
|
cfg = 'ResNet'
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
|
|
||||||
# `cfg` is a dict
|
# `cfg` is a dict, ConfigDict or Config object
|
||||||
cfg = dict(type='ResNet', depth=50)
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
|
||||||
assert isinstance(model, ResNet)
|
|
||||||
assert model.depth == 50 and model.stages == 4
|
|
||||||
|
|
||||||
# `cfg` is a ConfigDict object
|
|
||||||
cfg = ConfigDict(dict(type='ResNet', depth=50))
|
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
|
||||||
assert isinstance(model, ResNet)
|
|
||||||
assert model.depth == 50 and model.stages == 4
|
|
||||||
|
|
||||||
# `cfg` is a Config object
|
|
||||||
cfg = Config(dict(type='ResNet', depth=50))
|
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
assert isinstance(model, ResNet)
|
assert isinstance(model, ResNet)
|
||||||
assert model.depth == 50 and model.stages == 4
|
assert model.depth == 50 and model.stages == 4
|
||||||
@ -412,6 +407,7 @@ def test_build_from_cfg():
|
|||||||
# `cfg` is a dict but it does not contain the key "type"
|
# `cfg` is a dict but it does not contain the key "type"
|
||||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||||
cfg = dict(depth=50, stages=4)
|
cfg = dict(depth=50, stages=4)
|
||||||
|
cfg = cfg_type(cfg)
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
|
|
||||||
# cfg['type'] should be a str or class
|
# cfg['type'] should be a str or class
|
||||||
@ -419,52 +415,56 @@ def test_build_from_cfg():
|
|||||||
TypeError,
|
TypeError,
|
||||||
match="type must be a str or valid type, but got <class 'int'>"):
|
match="type must be a str or valid type, but got <class 'int'>"):
|
||||||
cfg = dict(type=1000)
|
cfg = dict(type=1000)
|
||||||
|
cfg = cfg_type(cfg)
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
|
|
||||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
assert isinstance(model, ResNeXt)
|
assert isinstance(model, ResNeXt)
|
||||||
assert model.depth == 50 and model.stages == 3
|
assert model.depth == 50 and model.stages == 3
|
||||||
|
|
||||||
cfg = dict(type=ResNet, depth=50)
|
cfg = cfg_type(dict(type=ResNet, depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
assert isinstance(model, ResNet)
|
assert isinstance(model, ResNet)
|
||||||
assert model.depth == 50 and model.stages == 4
|
assert model.depth == 50 and model.stages == 4
|
||||||
|
|
||||||
# non-registered class
|
# non-registered class
|
||||||
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
||||||
cfg = dict(type='VGG')
|
cfg = cfg_type(dict(type='VGG'))
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
|
|
||||||
# `cfg` contains unexpected arguments
|
# `cfg` contains unexpected arguments
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
cfg = dict(type='ResNet', non_existing_arg=50)
|
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES)
|
model = build_from_cfg(cfg, BACKBONES)
|
||||||
|
|
||||||
# test `default_args` parameter
|
# test `default_args` parameter
|
||||||
cfg = dict(type='ResNet', depth=50)
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
|
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
|
||||||
assert isinstance(model, ResNet)
|
assert isinstance(model, ResNet)
|
||||||
assert model.depth == 50 and model.stages == 3
|
assert model.depth == 50 and model.stages == 3
|
||||||
|
|
||||||
# default_args must be a dict or None
|
# default_args must be a dict or None
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
cfg = dict(type='ResNet', depth=50)
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
||||||
|
|
||||||
# cfg or default_args should contain the key "type"
|
# cfg or default_args should contain the key "type"
|
||||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||||
cfg = dict(depth=50)
|
cfg = cfg_type(dict(depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))
|
model = build_from_cfg(
|
||||||
|
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
|
||||||
|
|
||||||
# "type" defined using default_args
|
# "type" defined using default_args
|
||||||
cfg = dict(depth=50)
|
cfg = cfg_type(dict(depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
|
model = build_from_cfg(
|
||||||
|
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
|
||||||
assert isinstance(model, ResNet)
|
assert isinstance(model, ResNet)
|
||||||
assert model.depth == 50 and model.stages == 4
|
assert model.depth == 50 and model.stages == 4
|
||||||
|
|
||||||
cfg = dict(depth=50)
|
cfg = cfg_type(dict(depth=50))
|
||||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
|
model = build_from_cfg(
|
||||||
|
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
|
||||||
assert isinstance(model, ResNet)
|
assert isinstance(model, ResNet)
|
||||||
assert model.depth == 50 and model.stages == 4
|
assert model.depth == 50 and model.stages == 4
|
||||||
|
|
||||||
@ -474,5 +474,5 @@ def test_build_from_cfg():
|
|||||||
TypeError,
|
TypeError,
|
||||||
match=('registry must be a mmengine.Registry object, but got '
|
match=('registry must be a mmengine.Registry object, but got '
|
||||||
"<class 'str'>")):
|
"<class 'str'>")):
|
||||||
cfg = dict(type='ResNet', depth=50)
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||||
model = build_from_cfg(cfg, 'BACKBONES')
|
model = build_from_cfg(cfg, 'BACKBONES')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user