mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add Registry (#11)
This commit is contained in:
parent
166ca02363
commit
cccd20a636
4
docs/en/api.rst
Normal file
4
docs/en/api.rst
Normal file
@ -0,0 +1,4 @@
|
||||
Registry
|
||||
--------
|
||||
.. automodule:: mmengine.registry
|
||||
:members:
|
@ -8,6 +8,12 @@ You can switch between Chinese and English documents in the lower-left corner of
|
||||
|
||||
tutorials/registry.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: API Reference
|
||||
|
||||
api.rst
|
||||
|
||||
.. toctree::
|
||||
:caption: Switch Language
|
||||
|
||||
|
4
docs/zh_cn/api.rst
Normal file
4
docs/zh_cn/api.rst
Normal file
@ -0,0 +1,4 @@
|
||||
Registry
|
||||
--------
|
||||
.. automodule:: mmengine.registry
|
||||
:members:
|
@ -9,6 +9,12 @@
|
||||
tutorials/registry.md
|
||||
tutorials/config.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: API 文档
|
||||
|
||||
api.rst
|
||||
|
||||
.. toctree::
|
||||
:caption: 语言切换
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# flake8: noqa
|
||||
from .fileio import *
|
||||
from .registry import *
|
||||
from .utils import *
|
||||
|
11
mmengine/registry/__init__.py
Normal file
11
mmengine/registry/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
|
||||
OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, RUNNER_CONSTRUCTORS,
|
||||
RUNNERS, TASK_UTILS, TRANSFORMS, WEIGHT_INITIALIZERS)
|
||||
|
||||
__all__ = [
|
||||
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
|
||||
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
|
||||
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS'
|
||||
]
|
491
mmengine/registry/registry.py
Normal file
491
mmengine/registry/registry.py
Normal file
@ -0,0 +1,491 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..config import Config, ConfigDict
|
||||
from ..utils import is_seq_of
|
||||
|
||||
|
||||
def build_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: 'Registry',
|
||||
default_args: Optional[Union[dict, ConfigDict,
|
||||
Config]] = None) -> object:
|
||||
"""Build a module from config dict.
|
||||
|
||||
At least one of the ``cfg`` and ``default_args`` contains the key "type"
|
||||
which type should be either str or class. If they all contain it, the key
|
||||
in ``cfg`` will be used because ``cfg`` has a high priority than
|
||||
``default_args`` that means if a key exist in both of them, the value of
|
||||
the key will be ``cfg[key]``. They will be merged first and the key "type"
|
||||
will be popped up and the remaining keys will be used as initialization
|
||||
arguments.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine import Registry, build_from_cfg
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> def __init__(self, depth, stages=4):
|
||||
>>> self.depth = depth
|
||||
>>> self.stages = stages
|
||||
>>> cfg = dict(type='ResNet', depth=50)
|
||||
>>> model = build_from_cfg(cfg, MODELS)
|
||||
|
||||
Args:
|
||||
cfg (dict or ConfigDict or Config): Config dict. It should at least
|
||||
contain the key "type".
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
default_args (dict or ConfigDict or Config, optional): Default
|
||||
initialization arguments. Defaults to None.
|
||||
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if not isinstance(cfg, (dict, ConfigDict, Config)):
|
||||
raise TypeError(
|
||||
f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')
|
||||
|
||||
if 'type' not in cfg:
|
||||
if default_args is None or 'type' not in default_args:
|
||||
raise KeyError(
|
||||
'`cfg` or `default_args` must contain the key "type", '
|
||||
f'but got {cfg}\n{default_args}')
|
||||
|
||||
if not isinstance(registry, Registry):
|
||||
raise TypeError('registry must be a mmengine.Registry object, '
|
||||
f'but got {type(registry)}')
|
||||
|
||||
if not (isinstance(default_args,
|
||||
(dict, ConfigDict, Config)) or default_args is None):
|
||||
raise TypeError(
|
||||
'default_args should be a dict, ConfigDict, Config or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
args = cfg.copy()
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name} registry. '
|
||||
f'Please check whether the value of `{obj_type}` is correct or'
|
||||
' it was registered as expected. More details can be found at'
|
||||
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||
)
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
|
||||
try:
|
||||
return obj_cls(**args) # type: ignore
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore
|
||||
|
||||
|
||||
class Registry:
|
||||
"""A registry to map strings to classes.
|
||||
|
||||
Registered objects could be built from registry.
|
||||
|
||||
Args:
|
||||
name (str): Registry name.
|
||||
build_func (callable, optional): A function to construct instance
|
||||
from Registry. :func:`build_from_cfg` is used if neither ``parent``
|
||||
or ``build_func`` is specified. If ``parent`` is specified and
|
||||
``build_func`` is not given, ``build_func`` will be inherited
|
||||
from ``parent``. Defaults to None.
|
||||
parent (:obj:`Registry`, optional): Parent registry. The class
|
||||
registered in children registry could be built from parent.
|
||||
Defaults to None.
|
||||
scope (str, optional): The scope of registry. It is the key to search
|
||||
for children registry. If not specified, scope will be the name of
|
||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> # define a registry
|
||||
>>> MODELS = Registry('models')
|
||||
>>> # registry the `ResNet` to `MODELS`
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> # build model from `MODELS`
|
||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||
|
||||
>>> # hierarchical registry
|
||||
>>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
|
||||
>>> @DETECTORS.register_module()
|
||||
>>> class FasterRCNN:
|
||||
>>> pass
|
||||
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
|
||||
|
||||
More advanced usages can be found at
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
build_func: Optional[Callable] = None,
|
||||
parent: Optional['Registry'] = None,
|
||||
scope: Optional[str] = None):
|
||||
self._name = name
|
||||
self._module_dict: Dict[str, Type] = dict()
|
||||
self._children: Dict[str, 'Registry'] = dict()
|
||||
|
||||
if scope is not None:
|
||||
assert isinstance(scope, str)
|
||||
self._scope = scope
|
||||
else:
|
||||
self._scope = self.infer_scope() if scope is None else scope
|
||||
|
||||
# See https://mypy.readthedocs.io/en/stable/common_issues.html#
|
||||
# variables-vs-type-aliases for the use
|
||||
self.parent: Optional['Registry']
|
||||
if parent is not None:
|
||||
assert isinstance(parent, Registry)
|
||||
parent._add_child(self)
|
||||
self.parent = parent
|
||||
else:
|
||||
self.parent = None
|
||||
|
||||
# self.build_func will be set with the following priority:
|
||||
# 1. build_func
|
||||
# 2. parent.build_func
|
||||
# 3. build_from_cfg
|
||||
self.build_func: Callable
|
||||
if build_func is None:
|
||||
if self.parent is not None:
|
||||
self.build_func = self.parent.build_func
|
||||
else:
|
||||
self.build_func = build_from_cfg
|
||||
else:
|
||||
self.build_func = build_func
|
||||
|
||||
def __len__(self):
|
||||
return len(self._module_dict)
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
def __repr__(self):
|
||||
format_str = self.__class__.__name__ + \
|
||||
f'(name={self._name}, ' \
|
||||
f'items={self._module_dict})'
|
||||
return format_str
|
||||
|
||||
@staticmethod
|
||||
def infer_scope() -> str:
|
||||
"""Infer the scope of registry.
|
||||
|
||||
The name of the package where registry is defined will be returned.
|
||||
|
||||
Returns:
|
||||
str: The inferred scope name.
|
||||
|
||||
Examples:
|
||||
>>> # in mmdet/models/backbone/resnet.py
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> # The scope of ``ResNet`` will be ``mmdet``.
|
||||
"""
|
||||
# `sys._getframe` returns the frame object that many calls below the
|
||||
# top of the stack. The call stack for `infer_scope` can be listed as
|
||||
# follow:
|
||||
# frame-0: `infer_scope` itself
|
||||
# frame-1: `__init__` of `Registry` which calls the `infer_scope`
|
||||
# frame-2: Where the `Registry(...)` is called
|
||||
filename = inspect.getmodule(sys._getframe(2)).__name__ # type: ignore
|
||||
split_filename = filename.split('.')
|
||||
return split_filename[0]
|
||||
|
||||
@staticmethod
|
||||
def split_scope_key(key: str) -> Tuple[Optional[str], str]:
|
||||
"""Split scope and key.
|
||||
|
||||
The first scope will be split from key.
|
||||
|
||||
Return:
|
||||
tuple[str | None, str]: The former element is the first scope of
|
||||
the key, which can be ``None``. The latter is the remaining key.
|
||||
|
||||
Examples:
|
||||
>>> Registry.split_scope_key('mmdet.ResNet')
|
||||
'mmdet', 'ResNet'
|
||||
>>> Registry.split_scope_key('ResNet')
|
||||
None, 'ResNet'
|
||||
"""
|
||||
split_index = key.find('.')
|
||||
if split_index != -1:
|
||||
return key[:split_index], key[split_index + 1:]
|
||||
else:
|
||||
return None, key
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def module_dict(self):
|
||||
return self._module_dict
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
return self._children
|
||||
|
||||
def _get_root_registry(self) -> 'Registry':
|
||||
"""Return the root registry."""
|
||||
root = self
|
||||
while root.parent is not None:
|
||||
root = root.parent
|
||||
return root
|
||||
|
||||
def get(self, key: str) -> Optional[Type]:
|
||||
"""Get the registry record.
|
||||
|
||||
The method will first parse :attr:`key` and check whether it contains
|
||||
a scope name. The logic to search for :attr:`key`:
|
||||
|
||||
- ``key`` does not contain a scope name, i.e., it is purely a module
|
||||
name like "ResNet": :meth:`get` will search for ``ResNet`` from the
|
||||
current registry to its parent or ancestors until finding it.
|
||||
|
||||
- ``key`` contains a scope name and it is equal to the scope of the
|
||||
current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
|
||||
will only search for ``ResNet`` in the current registry.
|
||||
|
||||
- ``key`` contains a scope name and it is not equal to the scope of
|
||||
the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
|
||||
scope exists in its children, :meth:`get`will get "FCNet" from
|
||||
them. If not, :meth:`get` will first get the root registry and root
|
||||
registry call its own :meth:`get` method.
|
||||
|
||||
Args:
|
||||
key (str): Name of the registered item, e.g., the class name in
|
||||
string format.
|
||||
|
||||
Returns:
|
||||
Type or None: Return the corresponding class if ``key`` exists,
|
||||
otherwise return None.
|
||||
|
||||
Examples:
|
||||
>>> # define a registry
|
||||
>>> MODELS = Registry('models')
|
||||
>>> # register `ResNet` to `MODELS`
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet_cls = MODELS.get('ResNet')
|
||||
|
||||
>>> # hierarchical registry
|
||||
>>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
|
||||
>>> # `ResNet` does not exist in `DETECTORS` but `get` method
|
||||
>>> # will try to search from its parenet or ancestors
|
||||
>>> resnet_cls = DETECTORS.get('ResNet')
|
||||
>>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
|
||||
>>> @CLASSIFIER.register_module()
|
||||
>>> class MobileNet:
|
||||
>>> pass
|
||||
>>> # `get` from its sibling registries
|
||||
>>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
|
||||
"""
|
||||
scope, real_key = self.split_scope_key(key)
|
||||
if scope is None or scope == self._scope:
|
||||
# get from self
|
||||
if real_key in self._module_dict:
|
||||
return self._module_dict[real_key]
|
||||
|
||||
if scope is None:
|
||||
# try to get the target from its parent or ancestors
|
||||
parent = self.parent
|
||||
while parent is not None:
|
||||
if real_key in parent._module_dict:
|
||||
return parent._module_dict[real_key]
|
||||
parent = parent.parent
|
||||
else:
|
||||
# get from self._children
|
||||
if scope in self._children:
|
||||
return self._children[scope].get(real_key)
|
||||
else:
|
||||
root = self._get_root_registry()
|
||||
return root.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def _search_child(self, scope: str) -> Optional['Registry']:
|
||||
"""Depth-first search for the corresponding registry in its children.
|
||||
|
||||
Note that the method only search for the corresponding registry from
|
||||
the current registry. Therefore, if we want to search from the root
|
||||
registry, :meth:`_get_root_registry` should be called to get the
|
||||
root registry first.
|
||||
|
||||
Args:
|
||||
scope (str): The scope name used for searching for its
|
||||
corresponding registry.
|
||||
|
||||
Returns:
|
||||
Registry or None: Return the corresponding registry if ``scope``
|
||||
exists, otherwise return None.
|
||||
"""
|
||||
if self._scope == scope:
|
||||
return self
|
||||
|
||||
for child in self._children.values():
|
||||
registry = child._search_child(scope)
|
||||
if registry is not None:
|
||||
return registry
|
||||
|
||||
return None
|
||||
|
||||
def build(self,
|
||||
*args,
|
||||
default_scope: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
"""Build an instance.
|
||||
|
||||
Build an instance by calling :attr:`build_func`. If
|
||||
:attr:`default_scope` is given, :meth:`build` will firstly get the
|
||||
responding registry and then call its own :meth:`build`.
|
||||
|
||||
Args:
|
||||
default_scope (str, optional): The ``default_scope`` is used to
|
||||
reset the current registry. Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine import Registry
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> def __init__(self, depth, stages=4):
|
||||
>>> self.depth = depth
|
||||
>>> self.stages = stages
|
||||
>>> cfg = dict(type='ResNet', depth=50)
|
||||
>>> model = MODELS.build(cfg)
|
||||
"""
|
||||
if default_scope is not None:
|
||||
root = self._get_root_registry()
|
||||
registry = root._search_child(default_scope)
|
||||
if registry is None:
|
||||
raise KeyError(
|
||||
f'{default_scope} does not exist in the registry tree.')
|
||||
else:
|
||||
registry = self
|
||||
|
||||
return registry.build_func(*args, **kwargs, registry=registry)
|
||||
|
||||
def _add_child(self, registry: 'Registry') -> None:
|
||||
"""Add a child for a registry.
|
||||
|
||||
Args:
|
||||
registry (:obj:`Registry`): The ``registry`` will be added as a
|
||||
child of the ``self``.
|
||||
"""
|
||||
|
||||
assert isinstance(registry, Registry)
|
||||
assert registry.scope is not None
|
||||
assert registry.scope not in self.children, \
|
||||
f'scope {registry.scope} exists in {self.name} registry'
|
||||
self.children[registry.scope] = registry
|
||||
|
||||
def _register_module(self,
|
||||
module_class: Type,
|
||||
module_name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = False) -> None:
|
||||
"""Register a module.
|
||||
|
||||
Args:
|
||||
module_class (type): Module class to be registered.
|
||||
module_name (str or list of str, optional): The module name to be
|
||||
registered. If not specified, the class name will be used.
|
||||
Defaults to None.
|
||||
force (bool): Whether to override an existing class with the same
|
||||
name. Defaults to False.
|
||||
"""
|
||||
if not inspect.isclass(module_class):
|
||||
raise TypeError('module must be a class, '
|
||||
f'but got {type(module_class)}')
|
||||
|
||||
if module_name is None:
|
||||
module_name = module_class.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in self._module_dict:
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'in {self.name}')
|
||||
self._module_dict[name] = module_class
|
||||
|
||||
def register_module(
|
||||
self,
|
||||
name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = False,
|
||||
module: Optional[Type] = None) -> Union[type, Callable]:
|
||||
"""Register a module.
|
||||
|
||||
A record will be added to ``self._module_dict``, whose key is the class
|
||||
name or the specified name, and value is the class itself.
|
||||
It can be used as a decorator or a normal function.
|
||||
|
||||
Args:
|
||||
name (str or list of str, optional): The module name to be
|
||||
registered. If not specified, the class name will be used.
|
||||
force (bool): Whether to override an existing class with the same
|
||||
name. Default to False.
|
||||
module (type, optional): Module class to be registered. Defaults to
|
||||
None.
|
||||
|
||||
Examples:
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> # as a decorator
|
||||
>>> @backbones.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> backbones = Registry('backbone')
|
||||
>>> @backbones.register_module(name='mnet')
|
||||
>>> class MobileNet:
|
||||
>>> pass
|
||||
|
||||
>>> # as a normal function
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> backbones.register_module(module=ResNet)
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be None, an instance of str, or a sequence of str, '
|
||||
f'but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
self._register_module(
|
||||
module_class=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(cls):
|
||||
self._register_module(
|
||||
module_class=cls, module_name=name, force=force)
|
||||
return cls
|
||||
|
||||
return _register
|
34
mmengine/registry/root.py
Normal file
34
mmengine/registry/root.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""MMEngine provides 11 root registries to support using modules across
|
||||
projects.
|
||||
|
||||
More datails can be found at
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
"""
|
||||
|
||||
from .registry import Registry
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
RUNNERS = Registry('runner')
|
||||
# manage runner constructors that define how to initialize runners
|
||||
RUNNER_CONSTRUCTORS = Registry('runner constructor')
|
||||
# manage all kinds of hooks like `CheckpointHook`
|
||||
HOOKS = Registry('hook')
|
||||
|
||||
# manage data-related modules
|
||||
DATASETS = Registry('dataset')
|
||||
DATA_SAMPLERS = Registry('data sampler')
|
||||
TRANSFORMS = Registry('transform')
|
||||
|
||||
# mangage all kinds of modules inheriting `nn.Module`
|
||||
MODELS = Registry('model')
|
||||
# mangage all kinds of weight initialization modules like `Uniform`
|
||||
WEIGHT_INITIALIZERS = Registry('weight initializer')
|
||||
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer')
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
|
||||
|
||||
# manage task-specific modules like anchor generators and box coders
|
||||
TASK_UTILS = Registry('task util')
|
@ -8,6 +8,7 @@ from collections import abc
|
||||
from importlib import import_module
|
||||
from inspect import getfullargspec
|
||||
from itertools import repeat
|
||||
from typing import Sequence, Type
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
@ -125,16 +126,26 @@ def tuple_cast(inputs, dst_type):
|
||||
return iter_cast(inputs, dst_type, return_type=tuple)
|
||||
|
||||
|
||||
def is_seq_of(seq, expected_type, seq_type=None):
|
||||
def is_seq_of(seq: Sequence,
|
||||
expected_type: Type,
|
||||
seq_type: Type = None) -> bool:
|
||||
"""Check whether it is a sequence of some type.
|
||||
|
||||
Args:
|
||||
seq (Sequence): The sequence to be checked.
|
||||
expected_type (type): Expected type of sequence items.
|
||||
seq_type (type, optional): Expected sequence type.
|
||||
seq_type (type, optional): Expected sequence type. Defaults to None.
|
||||
|
||||
Returns:
|
||||
bool: Whether the sequence is valid.
|
||||
bool: Return True if ``seq`` is valid else False.
|
||||
|
||||
Examples:
|
||||
>>> from mmengine.utils import is_seq_of
|
||||
>>> seq = ['a', 'b', 'c']
|
||||
>>> is_seq_of(seq, str)
|
||||
True
|
||||
>>> is_seq_of(seq, int)
|
||||
False
|
||||
"""
|
||||
if seq_type is None:
|
||||
exp_seq_type = abc.Sequence
|
||||
|
@ -1,7 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
|
||||
from mmengine import Config, ConfigDict, Registry, build_from_cfg
|
||||
from mmengine.config import Config, ConfigDict # type: ignore
|
||||
from mmengine.registry import Registry, build_from_cfg
|
||||
|
||||
|
||||
class TestRegistry:
|
||||
@ -91,8 +92,8 @@ class TestRegistry:
|
||||
# `name` is an invalid type
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('name must be either of None, an instance of str or a '
|
||||
"sequence of str, but got <class 'int'>")):
|
||||
match=('name must be None, an instance of str, or a sequence '
|
||||
"of str, but got <class 'int'>")):
|
||||
|
||||
@CATS.register_module(name=7474741)
|
||||
class SiameseCat:
|
||||
@ -149,8 +150,8 @@ class TestRegistry:
|
||||
assert len(CATS) == 8
|
||||
|
||||
def _build_registry(self):
|
||||
"""A helper function to build a hierarchy registry."""
|
||||
# Hierarchy Registry
|
||||
"""A helper function to build a Hierarchical Registry."""
|
||||
# Hierarchical Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
@ -177,7 +178,7 @@ class TestRegistry:
|
||||
return registries
|
||||
|
||||
def test_get_root_registry(self):
|
||||
# Hierarchy Registry
|
||||
# Hierarchical Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
@ -186,7 +187,8 @@ class TestRegistry:
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
||||
registries = self._build_registry()
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
||||
|
||||
assert DOGS._get_root_registry() is DOGS
|
||||
assert HOUNDS._get_root_registry() is DOGS
|
||||
@ -194,7 +196,7 @@ class TestRegistry:
|
||||
assert MID_HOUNDS._get_root_registry() is DOGS
|
||||
|
||||
def test_get(self):
|
||||
# Hierarchy Registry
|
||||
# Hierarchical Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
@ -277,7 +279,7 @@ class TestRegistry:
|
||||
) is LittlePedigreeSamoyed
|
||||
|
||||
def test_search_child(self):
|
||||
# Hierarchy Registry
|
||||
# Hierarchical Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
@ -286,7 +288,8 @@ class TestRegistry:
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
||||
registries = self._build_registry()
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
||||
|
||||
assert DOGS._search_child('hound') is HOUNDS
|
||||
assert DOGS._search_child('not a child') is None
|
||||
@ -294,8 +297,9 @@ class TestRegistry:
|
||||
assert LITTLE_HOUNDS._search_child('hound') is None
|
||||
assert LITTLE_HOUNDS._search_child('mid_hound') is None
|
||||
|
||||
def test_build(self):
|
||||
# Hierarchy Registry
|
||||
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||
def test_build(self, cfg_type):
|
||||
# Hierarchical Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
@ -311,14 +315,14 @@ class TestRegistry:
|
||||
class GoldenRetriever:
|
||||
pass
|
||||
|
||||
gr_cfg = dict(type='GoldenRetriever')
|
||||
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
||||
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
pass
|
||||
|
||||
bh_cfg = dict(type='BloodHound')
|
||||
bh_cfg = cfg_type(dict(type='BloodHound'))
|
||||
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
||||
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
||||
|
||||
@ -326,14 +330,14 @@ class TestRegistry:
|
||||
class Dachshund:
|
||||
pass
|
||||
|
||||
d_cfg = dict(type='Dachshund')
|
||||
d_cfg = cfg_type(dict(type='Dachshund'))
|
||||
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
|
||||
|
||||
@MID_HOUNDS.register_module()
|
||||
class Beagle:
|
||||
pass
|
||||
|
||||
b_cfg = dict(type='Beagle')
|
||||
b_cfg = cfg_type(dict(type='Beagle'))
|
||||
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
|
||||
|
||||
# test `default_scope`
|
||||
@ -367,7 +371,8 @@ class TestRegistry:
|
||||
assert repr(CATS) == repr_str
|
||||
|
||||
|
||||
def test_build_from_cfg():
|
||||
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||
def test_build_from_cfg(cfg_type):
|
||||
BACKBONES = Registry('backbone')
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@ -387,24 +392,14 @@ def test_build_from_cfg():
|
||||
# test `cfg` parameter
|
||||
# `cfg` should be a dict, ConfigDict or Config object
|
||||
with pytest.raises(
|
||||
TypeError, match="cfg must be a dict, but got <class 'str'>"):
|
||||
TypeError,
|
||||
match=('cfg should be a dict, ConfigDict or Config, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = 'ResNet'
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` is a dict
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a ConfigDict object
|
||||
cfg = ConfigDict(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a Config object
|
||||
cfg = Config(dict(type='ResNet', depth=50))
|
||||
# `cfg` is a dict, ConfigDict or Config object
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
@ -412,6 +407,7 @@ def test_build_from_cfg():
|
||||
# `cfg` is a dict but it does not contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50, stages=4)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# cfg['type'] should be a str or class
|
||||
@ -419,52 +415,56 @@ def test_build_from_cfg():
|
||||
TypeError,
|
||||
match="type must be a str or valid type, but got <class 'int'>"):
|
||||
cfg = dict(type=1000)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
||||
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = dict(type=ResNet, depth=50)
|
||||
cfg = cfg_type(dict(type=ResNet, depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# non-registered class
|
||||
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
||||
cfg = dict(type='VGG')
|
||||
cfg = cfg_type(dict(type='VGG'))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` contains unexpected arguments
|
||||
with pytest.raises(TypeError):
|
||||
cfg = dict(type='ResNet', non_existing_arg=50)
|
||||
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# test `default_args` parameter
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
# default_args must be a dict or None
|
||||
with pytest.raises(TypeError):
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
||||
|
||||
# cfg or default_args should contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
|
||||
|
||||
# "type" defined using default_args
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
@ -474,5 +474,5 @@ def test_build_from_cfg():
|
||||
TypeError,
|
||||
match=('registry must be a mmengine.Registry object, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, 'BACKBONES')
|
||||
|
Loading…
x
Reference in New Issue
Block a user