diff --git a/docs/en/api.rst b/docs/en/api.rst new file mode 100644 index 00000000..744f2348 --- /dev/null +++ b/docs/en/api.rst @@ -0,0 +1,4 @@ +Registry +-------- +.. automodule:: mmengine.registry + :members: diff --git a/docs/en/index.rst b/docs/en/index.rst index bb906f76..5d55908c 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -8,6 +8,12 @@ You can switch between Chinese and English documents in the lower-left corner of tutorials/registry.md +.. toctree:: + :maxdepth: 2 + :caption: API Reference + + api.rst + .. toctree:: :caption: Switch Language diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst new file mode 100644 index 00000000..744f2348 --- /dev/null +++ b/docs/zh_cn/api.rst @@ -0,0 +1,4 @@ +Registry +-------- +.. automodule:: mmengine.registry + :members: diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index cc79a272..547c4e54 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -9,6 +9,12 @@ tutorials/registry.md tutorials/config.md +.. toctree:: + :maxdepth: 2 + :caption: API 文档 + + api.rst + .. toctree:: :caption: 语言切换 diff --git a/mmengine/__init__.py b/mmengine/__init__.py index 1e4ec2f9..ad8b6429 100644 --- a/mmengine/__init__.py +++ b/mmengine/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. # flake8: noqa from .fileio import * +from .registry import * from .utils import * diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py new file mode 100644 index 00000000..d7f8ee30 --- /dev/null +++ b/mmengine/registry/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import Registry, build_from_cfg +from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS, + OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, RUNNER_CONSTRUCTORS, + RUNNERS, TASK_UTILS, TRANSFORMS, WEIGHT_INITIALIZERS) + +__all__ = [ + 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', + 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', + 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS' +] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py new file mode 100644 index 00000000..32c7d722 --- /dev/null +++ b/mmengine/registry/registry.py @@ -0,0 +1,491 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import sys +from collections.abc import Callable +from typing import Dict, List, Optional, Tuple, Type, Union + +from ..config import Config, ConfigDict +from ..utils import is_seq_of + + +def build_from_cfg( + cfg: Union[dict, ConfigDict, Config], + registry: 'Registry', + default_args: Optional[Union[dict, ConfigDict, + Config]] = None) -> object: + """Build a module from config dict. + + At least one of the ``cfg`` and ``default_args`` contains the key "type" + which type should be either str or class. If they all contain it, the key + in ``cfg`` will be used because ``cfg`` has a high priority than + ``default_args`` that means if a key exist in both of them, the value of + the key will be ``cfg[key]``. They will be merged first and the key "type" + will be popped up and the remaining keys will be used as initialization + arguments. + + Examples: + >>> from mmengine import Registry, build_from_cfg + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> def __init__(self, depth, stages=4): + >>> self.depth = depth + >>> self.stages = stages + >>> cfg = dict(type='ResNet', depth=50) + >>> model = build_from_cfg(cfg, MODELS) + + Args: + cfg (dict or ConfigDict or Config): Config dict. It should at least + contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict or ConfigDict or Config, optional): Default + initialization arguments. Defaults to None. + + Returns: + object: The constructed object. + """ + if not isinstance(cfg, (dict, ConfigDict, Config)): + raise TypeError( + f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}') + + if 'type' not in cfg: + if default_args is None or 'type' not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f'but got {cfg}\n{default_args}') + + if not isinstance(registry, Registry): + raise TypeError('registry must be a mmengine.Registry object, ' + f'but got {type(registry)}') + + if not (isinstance(default_args, + (dict, ConfigDict, Config)) or default_args is None): + raise TypeError( + 'default_args should be a dict, ConfigDict, Config or None, ' + f'but got {type(default_args)}') + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop('type') + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry. ' + f'Please check whether the value of `{obj_type}` is correct or' + ' it was registered as expected. More details can be found at' + ' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 + ) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + try: + return obj_cls(**args) # type: ignore + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore + + +class Registry: + """A registry to map strings to classes. + + Registered objects could be built from registry. + + Args: + name (str): Registry name. + build_func (callable, optional): A function to construct instance + from Registry. :func:`build_from_cfg` is used if neither ``parent`` + or ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Defaults to None. + parent (:obj:`Registry`, optional): Parent registry. The class + registered in children registry could be built from parent. + Defaults to None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Defaults to None. + + Examples: + >>> # define a registry + >>> MODELS = Registry('models') + >>> # registry the `ResNet` to `MODELS` + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> # build model from `MODELS` + >>> resnet = MODELS.build(dict(type='ResNet')) + + >>> # hierarchical registry + >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det') + >>> @DETECTORS.register_module() + >>> class FasterRCNN: + >>> pass + >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN')) + + More advanced usages can be found at + https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. + """ + + def __init__(self, + name: str, + build_func: Optional[Callable] = None, + parent: Optional['Registry'] = None, + scope: Optional[str] = None): + self._name = name + self._module_dict: Dict[str, Type] = dict() + self._children: Dict[str, 'Registry'] = dict() + + if scope is not None: + assert isinstance(scope, str) + self._scope = scope + else: + self._scope = self.infer_scope() if scope is None else scope + + # See https://mypy.readthedocs.io/en/stable/common_issues.html# + # variables-vs-type-aliases for the use + self.parent: Optional['Registry'] + if parent is not None: + assert isinstance(parent, Registry) + parent._add_child(self) + self.parent = parent + else: + self.parent = None + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + self.build_func: Callable + if build_func is None: + if self.parent is not None: + self.build_func = self.parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = self.__class__.__name__ + \ + f'(name={self._name}, ' \ + f'items={self._module_dict})' + return format_str + + @staticmethod + def infer_scope() -> str: + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Returns: + str: The inferred scope name. + + Examples: + >>> # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> # The scope of ``ResNet`` will be ``mmdet``. + """ + # `sys._getframe` returns the frame object that many calls below the + # top of the stack. The call stack for `infer_scope` can be listed as + # follow: + # frame-0: `infer_scope` itself + # frame-1: `__init__` of `Registry` which calls the `infer_scope` + # frame-2: Where the `Registry(...)` is called + filename = inspect.getmodule(sys._getframe(2)).__name__ # type: ignore + split_filename = filename.split('.') + return split_filename[0] + + @staticmethod + def split_scope_key(key: str) -> Tuple[Optional[str], str]: + """Split scope and key. + + The first scope will be split from key. + + Return: + tuple[str | None, str]: The former element is the first scope of + the key, which can be ``None``. The latter is the remaining key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + """ + split_index = key.find('.') + if split_index != -1: + return key[:split_index], key[split_index + 1:] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def _get_root_registry(self) -> 'Registry': + """Return the root registry.""" + root = self + while root.parent is not None: + root = root.parent + return root + + def get(self, key: str) -> Optional[Type]: + """Get the registry record. + + The method will first parse :attr:`key` and check whether it contains + a scope name. The logic to search for :attr:`key`: + + - ``key`` does not contain a scope name, i.e., it is purely a module + name like "ResNet": :meth:`get` will search for ``ResNet`` from the + current registry to its parent or ancestors until finding it. + + - ``key`` contains a scope name and it is equal to the scope of the + current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get` + will only search for ``ResNet`` in the current registry. + + - ``key`` contains a scope name and it is not equal to the scope of + the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the + scope exists in its children, :meth:`get`will get "FCNet" from + them. If not, :meth:`get` will first get the root registry and root + registry call its own :meth:`get` method. + + Args: + key (str): Name of the registered item, e.g., the class name in + string format. + + Returns: + Type or None: Return the corresponding class if ``key`` exists, + otherwise return None. + + Examples: + >>> # define a registry + >>> MODELS = Registry('models') + >>> # register `ResNet` to `MODELS` + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet_cls = MODELS.get('ResNet') + + >>> # hierarchical registry + >>> DETECTORS = Registry('detector', parent=MODELS, scope='det') + >>> # `ResNet` does not exist in `DETECTORS` but `get` method + >>> # will try to search from its parenet or ancestors + >>> resnet_cls = DETECTORS.get('ResNet') + >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls') + >>> @CLASSIFIER.register_module() + >>> class MobileNet: + >>> pass + >>> # `get` from its sibling registries + >>> mobilenet_cls = DETECTORS.get('cls.MobileNet') + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + + if scope is None: + # try to get the target from its parent or ancestors + parent = self.parent + while parent is not None: + if real_key in parent._module_dict: + return parent._module_dict[real_key] + parent = parent.parent + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + root = self._get_root_registry() + return root.get(key) + + return None + + def _search_child(self, scope: str) -> Optional['Registry']: + """Depth-first search for the corresponding registry in its children. + + Note that the method only search for the corresponding registry from + the current registry. Therefore, if we want to search from the root + registry, :meth:`_get_root_registry` should be called to get the + root registry first. + + Args: + scope (str): The scope name used for searching for its + corresponding registry. + + Returns: + Registry or None: Return the corresponding registry if ``scope`` + exists, otherwise return None. + """ + if self._scope == scope: + return self + + for child in self._children.values(): + registry = child._search_child(scope) + if registry is not None: + return registry + + return None + + def build(self, + *args, + default_scope: Optional[str] = None, + **kwargs) -> None: + """Build an instance. + + Build an instance by calling :attr:`build_func`. If + :attr:`default_scope` is given, :meth:`build` will firstly get the + responding registry and then call its own :meth:`build`. + + Args: + default_scope (str, optional): The ``default_scope`` is used to + reset the current registry. Defaults to None. + + Examples: + >>> from mmengine import Registry + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> def __init__(self, depth, stages=4): + >>> self.depth = depth + >>> self.stages = stages + >>> cfg = dict(type='ResNet', depth=50) + >>> model = MODELS.build(cfg) + """ + if default_scope is not None: + root = self._get_root_registry() + registry = root._search_child(default_scope) + if registry is None: + raise KeyError( + f'{default_scope} does not exist in the registry tree.') + else: + registry = self + + return registry.build_func(*args, **kwargs, registry=registry) + + def _add_child(self, registry: 'Registry') -> None: + """Add a child for a registry. + + Args: + registry (:obj:`Registry`): The ``registry`` will be added as a + child of the ``self``. + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert registry.scope not in self.children, \ + f'scope {registry.scope} exists in {self.name} registry' + self.children[registry.scope] = registry + + def _register_module(self, + module_class: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = False) -> None: + """Register a module. + + Args: + module_class (type): Module class to be registered. + module_name (str or list of str, optional): The module name to be + registered. If not specified, the class name will be used. + Defaults to None. + force (bool): Whether to override an existing class with the same + name. Defaults to False. + """ + if not inspect.isclass(module_class): + raise TypeError('module must be a class, ' + f'but got {type(module_class)}') + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f'{name} is already registered ' + f'in {self.name}') + self._module_dict[name] = module_class + + def register_module( + self, + name: Optional[Union[str, List[str]]] = None, + force: bool = False, + module: Optional[Type] = None) -> Union[type, Callable]: + """Register a module. + + A record will be added to ``self._module_dict``, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Args: + name (str or list of str, optional): The module name to be + registered. If not specified, the class name will be used. + force (bool): Whether to override an existing class with the same + name. Default to False. + module (type, optional): Module class to be registered. Defaults to + None. + + Examples: + >>> backbones = Registry('backbone') + >>> # as a decorator + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> # as a normal function + >>> class ResNet: + >>> pass + >>> backbones.register_module(module=ResNet) + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module( + module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module( + module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py new file mode 100644 index 00000000..71ff9dd6 --- /dev/null +++ b/mmengine/registry/root.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMEngine provides 11 root registries to support using modules across +projects. + +More datails can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. +""" + +from .registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner') +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry('runner constructor') +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry('hook') + +# manage data-related modules +DATASETS = Registry('dataset') +DATA_SAMPLERS = Registry('data sampler') +TRANSFORMS = Registry('transform') + +# mangage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model') +# mangage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry('weight initializer') + +# mangage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry('optimizer') +# manage constructors that customize the optimization hyperparameters. +OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor') + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry('task util') diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 7957ea89..41823a7f 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -8,6 +8,7 @@ from collections import abc from importlib import import_module from inspect import getfullargspec from itertools import repeat +from typing import Sequence, Type # From PyTorch internals @@ -125,16 +126,26 @@ def tuple_cast(inputs, dst_type): return iter_cast(inputs, dst_type, return_type=tuple) -def is_seq_of(seq, expected_type, seq_type=None): +def is_seq_of(seq: Sequence, + expected_type: Type, + seq_type: Type = None) -> bool: """Check whether it is a sequence of some type. Args: seq (Sequence): The sequence to be checked. expected_type (type): Expected type of sequence items. - seq_type (type, optional): Expected sequence type. + seq_type (type, optional): Expected sequence type. Defaults to None. Returns: - bool: Whether the sequence is valid. + bool: Return True if ``seq`` is valid else False. + + Examples: + >>> from mmengine.utils import is_seq_of + >>> seq = ['a', 'b', 'c'] + >>> is_seq_of(seq, str) + True + >>> is_seq_of(seq, int) + False """ if seq_type is None: exp_seq_type = abc.Sequence diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 0f058cf6..61d55c9a 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest -from mmengine import Config, ConfigDict, Registry, build_from_cfg +from mmengine.config import Config, ConfigDict # type: ignore +from mmengine.registry import Registry, build_from_cfg class TestRegistry: @@ -91,8 +92,8 @@ class TestRegistry: # `name` is an invalid type with pytest.raises( TypeError, - match=('name must be either of None, an instance of str or a ' - "sequence of str, but got ")): + match=('name must be None, an instance of str, or a sequence ' + "of str, but got ")): @CATS.register_module(name=7474741) class SiameseCat: @@ -149,8 +150,8 @@ class TestRegistry: assert len(CATS) == 8 def _build_registry(self): - """A helper function to build a hierarchy registry.""" - # Hierarchy Registry + """A helper function to build a Hierarchical Registry.""" + # Hierarchical Registry # DOGS # _______|_______ # | | @@ -177,7 +178,7 @@ class TestRegistry: return registries def test_get_root_registry(self): - # Hierarchy Registry + # Hierarchical Registry # DOGS # _______|_______ # | | @@ -186,7 +187,8 @@ class TestRegistry: # | | | # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # (little_hound) (mid_hound) (little_samoyed) - DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry() + registries = self._build_registry() + DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4] assert DOGS._get_root_registry() is DOGS assert HOUNDS._get_root_registry() is DOGS @@ -194,7 +196,7 @@ class TestRegistry: assert MID_HOUNDS._get_root_registry() is DOGS def test_get(self): - # Hierarchy Registry + # Hierarchical Registry # DOGS # _______|_______ # | | @@ -277,7 +279,7 @@ class TestRegistry: ) is LittlePedigreeSamoyed def test_search_child(self): - # Hierarchy Registry + # Hierarchical Registry # DOGS # _______|_______ # | | @@ -286,7 +288,8 @@ class TestRegistry: # | | | # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # (little_hound) (mid_hound) (little_samoyed) - DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry() + registries = self._build_registry() + DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] assert DOGS._search_child('hound') is HOUNDS assert DOGS._search_child('not a child') is None @@ -294,8 +297,9 @@ class TestRegistry: assert LITTLE_HOUNDS._search_child('hound') is None assert LITTLE_HOUNDS._search_child('mid_hound') is None - def test_build(self): - # Hierarchy Registry + @pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) + def test_build(self, cfg_type): + # Hierarchical Registry # DOGS # _______|_______ # | | @@ -311,14 +315,14 @@ class TestRegistry: class GoldenRetriever: pass - gr_cfg = dict(type='GoldenRetriever') + gr_cfg = cfg_type(dict(type='GoldenRetriever')) assert isinstance(DOGS.build(gr_cfg), GoldenRetriever) @HOUNDS.register_module() class BloodHound: pass - bh_cfg = dict(type='BloodHound') + bh_cfg = cfg_type(dict(type='BloodHound')) assert isinstance(HOUNDS.build(bh_cfg), BloodHound) assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever) @@ -326,14 +330,14 @@ class TestRegistry: class Dachshund: pass - d_cfg = dict(type='Dachshund') + d_cfg = cfg_type(dict(type='Dachshund')) assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund) @MID_HOUNDS.register_module() class Beagle: pass - b_cfg = dict(type='Beagle') + b_cfg = cfg_type(dict(type='Beagle')) assert isinstance(MID_HOUNDS.build(b_cfg), Beagle) # test `default_scope` @@ -367,7 +371,8 @@ class TestRegistry: assert repr(CATS) == repr_str -def test_build_from_cfg(): +@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) +def test_build_from_cfg(cfg_type): BACKBONES = Registry('backbone') @BACKBONES.register_module() @@ -387,24 +392,14 @@ def test_build_from_cfg(): # test `cfg` parameter # `cfg` should be a dict, ConfigDict or Config object with pytest.raises( - TypeError, match="cfg must be a dict, but got "): + TypeError, + match=('cfg should be a dict, ConfigDict or Config, but got ' + "")): cfg = 'ResNet' model = build_from_cfg(cfg, BACKBONES) - # `cfg` is a dict - cfg = dict(type='ResNet', depth=50) - model = build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # `cfg` is a ConfigDict object - cfg = ConfigDict(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # `cfg` is a Config object - cfg = Config(dict(type='ResNet', depth=50)) + # `cfg` is a dict, ConfigDict or Config object + cfg = cfg_type(dict(type='ResNet', depth=50)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 @@ -412,6 +407,7 @@ def test_build_from_cfg(): # `cfg` is a dict but it does not contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): cfg = dict(depth=50, stages=4) + cfg = cfg_type(cfg) model = build_from_cfg(cfg, BACKBONES) # cfg['type'] should be a str or class @@ -419,52 +415,56 @@ def test_build_from_cfg(): TypeError, match="type must be a str or valid type, but got "): cfg = dict(type=1000) + cfg = cfg_type(cfg) model = build_from_cfg(cfg, BACKBONES) - cfg = dict(type='ResNeXt', depth=50, stages=3) + cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNeXt) assert model.depth == 50 and model.stages == 3 - cfg = dict(type=ResNet, depth=50) + cfg = cfg_type(dict(type=ResNet, depth=50)) model = build_from_cfg(cfg, BACKBONES) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 # non-registered class with pytest.raises(KeyError, match='VGG is not in the backbone registry'): - cfg = dict(type='VGG') + cfg = cfg_type(dict(type='VGG')) model = build_from_cfg(cfg, BACKBONES) # `cfg` contains unexpected arguments with pytest.raises(TypeError): - cfg = dict(type='ResNet', non_existing_arg=50) + cfg = cfg_type(dict(type='ResNet', non_existing_arg=50)) model = build_from_cfg(cfg, BACKBONES) # test `default_args` parameter - cfg = dict(type='ResNet', depth=50) - model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3}) + cfg = cfg_type(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 3 # default_args must be a dict or None with pytest.raises(TypeError): - cfg = dict(type='ResNet', depth=50) + cfg = cfg_type(dict(type='ResNet', depth=50)) model = build_from_cfg(cfg, BACKBONES, default_args=1) # cfg or default_args should contain the key "type" with pytest.raises(KeyError, match='must contain the key "type"'): - cfg = dict(depth=50) - model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4)) + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) # "type" defined using default_args - cfg = dict(depth=50) - model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet')) + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 - cfg = dict(depth=50) - model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet)) + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) assert isinstance(model, ResNet) assert model.depth == 50 and model.stages == 4 @@ -474,5 +474,5 @@ def test_build_from_cfg(): TypeError, match=('registry must be a mmengine.Registry object, but got ' "")): - cfg = dict(type='ResNet', depth=50) + cfg = cfg_type(dict(type='ResNet', depth=50)) model = build_from_cfg(cfg, 'BACKBONES')