mirror of https://github.com/open-mmlab/mmcv.git
add model registry (#760)
* add model registry * fixed infer scoep * fixed build func * add docstring * add md * support multi level * clean comments * add docs * fixed parent * add more doc * add value error, add docstring * fixed docs * change to local/global search * resolve comments * fixed test * update some docstring * update docs (minior) * update docs * update docspull/942/head
parent
47825b194e
commit
375605fba8
118
docs/registry.md
118
docs/registry.md
|
@ -9,12 +9,15 @@ In MMCV, registry can be regarded as a mapping that maps a class to a string.
|
|||
These classes contained by a single registry usually have similar APIs but implement different algorithms or support different datasets.
|
||||
With the registry, users can find and instantiate the class through its corresponding string, and use the instantiated module as they want.
|
||||
One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs.
|
||||
The API reference could be find [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry).
|
||||
|
||||
To manage your modules in the codebase by `Registry`, there are three steps as below.
|
||||
|
||||
1. Create an registry
|
||||
2. Create a build method
|
||||
3. Use this registry to manage the modules
|
||||
1. Create a build method (optional, in most cases you can just use the default one).
|
||||
2. Create a registry.
|
||||
3. Use this registry to manage the modules.
|
||||
|
||||
`build_func` argument of `Registry` is to customize how to instantiate the class instance, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg).
|
||||
|
||||
### A Simple Example
|
||||
|
||||
|
@ -22,27 +25,13 @@ Here we show a simple example of using registry to manage modules in a package.
|
|||
You can find more practical examples in OpenMMLab projects.
|
||||
|
||||
Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format.
|
||||
We create directory as a package named `converters`.
|
||||
We create a directory as a package named `converters`.
|
||||
In the package, we first create a file to implement builders, named `converters/builder.py`, as below
|
||||
|
||||
```python
|
||||
from mmcv.utils import Registry
|
||||
|
||||
# create a registry for converters
|
||||
CONVERTERS = Registry('converter')
|
||||
|
||||
|
||||
# create a build function
|
||||
def build_converter(cfg, *args, **kwargs):
|
||||
cfg_ = cfg.copy()
|
||||
converter_type = cfg_.pop('type')
|
||||
if converter_type not in CONVERTERS:
|
||||
raise KeyError(f'Unrecognized task type {converter_type}')
|
||||
else:
|
||||
converter_cls = CONVERTERS.get(converter_type)
|
||||
|
||||
converter = converter_cls(*args, **kwargs, **cfg_)
|
||||
return converter
|
||||
```
|
||||
|
||||
Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`
|
||||
|
@ -51,7 +40,6 @@ Then we can implement different converters in the package. For example, implemen
|
|||
|
||||
from .builder import CONVERTERS
|
||||
|
||||
|
||||
# use the registry to manage the module
|
||||
@CONVERTERS.register_module()
|
||||
class Converter1(object):
|
||||
|
@ -71,5 +59,95 @@ If the module is successfully registered, you can use this converter through con
|
|||
|
||||
```python
|
||||
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
|
||||
converter = build_converter(converter_cfg)
|
||||
converter = CONVERTERS.build(converter_cfg)
|
||||
```
|
||||
|
||||
## Customize Build Function
|
||||
|
||||
Suppose we would like to customize how `converters` are built, we could implement a customized `build_func` and pass it into the registry.
|
||||
|
||||
```python
|
||||
from mmcv.utils import Registry
|
||||
|
||||
# create a build function
|
||||
def build_converter(cfg, registry, *args, **kwargs):
|
||||
cfg_ = cfg.copy()
|
||||
converter_type = cfg_.pop('type')
|
||||
if converter_type not in registry:
|
||||
raise KeyError(f'Unrecognized converter type {converter_type}')
|
||||
else:
|
||||
converter_cls = registry.get(converter_type)
|
||||
|
||||
converter = converter_cls(*args, **kwargs, **cfg_)
|
||||
return converter
|
||||
|
||||
# create a registry for converters and pass ``build_converter`` function
|
||||
CONVERTERS = Registry('converter', build_func=build_converter)
|
||||
```
|
||||
|
||||
Note: in this example, we demonstrate how to use the `build_func` argument to customize the way to build a class instance.
|
||||
The functionality is similar to the default `build_from_cfg`. In most cases, default one would be sufficient.
|
||||
`build_model_from_cfg` is also implemented to build PyTorch module in `nn.Sequentail`, you may directly use them instead of implementing by yourself.
|
||||
|
||||
## Hierarchy Registry
|
||||
|
||||
You could also build modules from more than one OpenMMLab frameworks, e.g. you could use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), you may also combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) and semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).
|
||||
|
||||
All `MODELS` registries of downstream codebases are children registries of MMCV's `MODELS` registry.
|
||||
Basically, there are two ways to build a module from child or sibling registries.
|
||||
|
||||
1. Build from children registries.
|
||||
|
||||
For example:
|
||||
|
||||
In MMDetection we define:
|
||||
|
||||
```python
|
||||
from mmcv.utils import Registry
|
||||
from mmcv.cnn import MODELS as MMCV_MODELS
|
||||
MODELS = Registry('model', parent=MMCV_MODELS)
|
||||
|
||||
@MODELS.register_module()
|
||||
class NetA(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
```
|
||||
|
||||
In MMClassification we define:
|
||||
|
||||
```python
|
||||
from mmcv.utils import Registry
|
||||
from mmcv.cnn import MODELS as MMCV_MODELS
|
||||
MODELS = Registry('model', parent=MMCV_MODELS)
|
||||
|
||||
@MODELS.register_module()
|
||||
class NetB(nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
```
|
||||
|
||||
We could build two net in either MMDetection or MMClassification by:
|
||||
|
||||
```python
|
||||
from mmdet.models import MODELS
|
||||
net_a = MODELS.build(cfg=dict(type='NetA'))
|
||||
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```python
|
||||
from mmcls.models import MODELS
|
||||
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
|
||||
net_b = MODELS.build(cfg=dict(type='NetB'))
|
||||
```
|
||||
|
||||
2. Build from parent registry.
|
||||
|
||||
The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry):
|
||||
|
||||
```python
|
||||
from mmcv.cnn import MODELS as MMCV_MODELS
|
||||
net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA'))
|
||||
net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB'))
|
||||
```
|
||||
|
|
|
@ -11,6 +11,7 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
|
|||
build_activation_layer, build_conv_layer,
|
||||
build_norm_layer, build_padding_layer, build_plugin_layer,
|
||||
build_upsample_layer, conv_ws_2d, is_norm)
|
||||
from .builder import MODELS, build_model_from_cfg
|
||||
# yapf: enable
|
||||
from .resnet import ResNet, make_res_layer
|
||||
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
|
||||
|
@ -34,5 +35,5 @@ __all__ = [
|
|||
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
|
||||
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
|
||||
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
|
||||
'Caffe2XavierInit'
|
||||
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import torch.nn as nn
|
||||
|
||||
from ..utils import Registry, build_from_cfg
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, registry, default_args=None):
|
||||
"""Build a PyTorch model from config dict(s). Different from
|
||||
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
|
||||
|
||||
Args:
|
||||
cfg (dict, list[dict]): The config of modules, is is either a config
|
||||
dict or a list of config dicts. If cfg is a list, a
|
||||
the built modules will be wrapped with ``nn.Sequential``.
|
||||
registry (:obj:`Registry`): A registry the module belongs to.
|
||||
default_args (dict, optional): Default arguments to build the module.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Module: A built nn module.
|
||||
"""
|
||||
if isinstance(cfg, list):
|
||||
modules = [
|
||||
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
|
||||
]
|
||||
return nn.Sequential(*modules)
|
||||
else:
|
||||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
MODELS = Registry('model', build_func=build_model_from_cfg)
|
|
@ -5,16 +5,107 @@ from functools import partial
|
|||
from .misc import is_seq_of
|
||||
|
||||
|
||||
def build_from_cfg(cfg, registry, default_args=None):
|
||||
"""Build a module from config dict.
|
||||
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
||||
if 'type' not in cfg:
|
||||
if default_args is None or 'type' not in default_args:
|
||||
raise KeyError(
|
||||
'`cfg` or `default_args` must contain the key "type", '
|
||||
f'but got {cfg}\n{default_args}')
|
||||
if not isinstance(registry, Registry):
|
||||
raise TypeError('registry must be an mmcv.Registry object, '
|
||||
f'but got {type(registry)}')
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
args = cfg.copy()
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name} registry')
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
try:
|
||||
return obj_cls(**args)
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
raise type(e)(f'{obj_cls.__name__}: {e}')
|
||||
|
||||
|
||||
class Registry:
|
||||
"""A registry to map strings to classes.
|
||||
|
||||
Registered object could be built from registry.
|
||||
Example:
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = MODELS.build(dict(type='ResNet'))
|
||||
|
||||
Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
|
||||
advanced useage.
|
||||
|
||||
Args:
|
||||
name (str): Registry name.
|
||||
build_func(func, optional): Build function to construct instance from
|
||||
Registry, func:`build_from_cfg` is used if neither ``parent`` or
|
||||
``build_func`` is specified. If ``parent`` is specified and
|
||||
``build_func`` is not given, ``build_func`` will be inherited
|
||||
from ``parent``. Default: None.
|
||||
parent (Registry, optional): Parent registry. The class registered in
|
||||
children registry could be built from parent. Default: None.
|
||||
scope (str, optional): The scope of registry. It is the key to search
|
||||
for children registry. If not specified, scope will be the name of
|
||||
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, name, build_func=None, parent=None, scope=None):
|
||||
self._name = name
|
||||
self._module_dict = dict()
|
||||
self._children = dict()
|
||||
self._scope = self.infer_scope() if scope is None else scope
|
||||
|
||||
# self.build_func will be set with the following priority:
|
||||
# 1. build_func
|
||||
# 2. parent.build_func
|
||||
# 3. build_from_cfg
|
||||
if build_func is None:
|
||||
if parent is not None:
|
||||
self.build_func = parent.build_func
|
||||
else:
|
||||
self.build_func = build_from_cfg
|
||||
else:
|
||||
self.build_func = build_func
|
||||
if parent is not None:
|
||||
assert isinstance(parent, Registry)
|
||||
parent._add_children(self)
|
||||
self.parent = parent
|
||||
else:
|
||||
self.parent = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self._module_dict)
|
||||
|
@ -28,14 +119,68 @@ class Registry:
|
|||
f'items={self._module_dict})'
|
||||
return format_str
|
||||
|
||||
@staticmethod
|
||||
def infer_scope():
|
||||
"""Infer the scope of registry.
|
||||
|
||||
The name of the package where registry is defined will be returned.
|
||||
|
||||
Example:
|
||||
# in mmdet/models/backbone/resnet.py
|
||||
>>> MODELS = Registry('models')
|
||||
>>> @MODELS.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
The scope of ``ResNet`` will be ``mmdet``.
|
||||
|
||||
|
||||
Returns:
|
||||
scope (str): The inferred scope name.
|
||||
"""
|
||||
# inspect.stack() trace where this function is called, the index-2
|
||||
# indicates the frame where `infer_scope()` is called
|
||||
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
|
||||
split_filename = filename.split('.')
|
||||
return split_filename[0]
|
||||
|
||||
@staticmethod
|
||||
def split_scope_key(key):
|
||||
"""Split scope and key.
|
||||
|
||||
The first scope will be split from key.
|
||||
|
||||
Examples:
|
||||
>>> Registry.split_scope_key('mmdet.ResNet')
|
||||
'mmdet', 'ResNet'
|
||||
>>> Registry.split_scope_key('ResNet')
|
||||
None, 'ResNet'
|
||||
|
||||
Return:
|
||||
scope (str, None): The first scope.
|
||||
key (str): The remaining key.
|
||||
"""
|
||||
split_index = key.find('.')
|
||||
if split_index != -1:
|
||||
return key[:split_index], key[split_index + 1:]
|
||||
else:
|
||||
return None, key
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def module_dict(self):
|
||||
return self._module_dict
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
return self._children
|
||||
|
||||
def get(self, key):
|
||||
"""Get the registry record.
|
||||
|
||||
|
@ -45,7 +190,45 @@ class Registry:
|
|||
Returns:
|
||||
class: The corresponding class.
|
||||
"""
|
||||
return self._module_dict.get(key, None)
|
||||
scope, real_key = self.split_scope_key(key)
|
||||
if scope is None or scope == self._scope:
|
||||
# get from self
|
||||
if real_key in self._module_dict:
|
||||
return self._module_dict[real_key]
|
||||
else:
|
||||
# get from self._children
|
||||
if scope in self._children:
|
||||
return self._children[scope].get(real_key)
|
||||
else:
|
||||
# goto root
|
||||
parent = self.parent
|
||||
while parent.parent is not None:
|
||||
parent = parent.parent
|
||||
return parent.get(key)
|
||||
|
||||
def build(self, *args, **kwargs):
|
||||
return self.build_func(*args, **kwargs, registry=self)
|
||||
|
||||
def _add_children(self, registry):
|
||||
"""Add children for a registry.
|
||||
|
||||
The ``registry`` will be added as children based on its scope.
|
||||
The parent registry could build objects from children registry.
|
||||
|
||||
Example:
|
||||
>>> models = Registry('models')
|
||||
>>> mmdet_models = Registry('models', parent=models)
|
||||
>>> @mmdet_models.register_module()
|
||||
>>> class ResNet:
|
||||
>>> pass
|
||||
>>> resnet = models.build(dict(type='mmdet.ResNet'))
|
||||
"""
|
||||
|
||||
assert isinstance(registry, Registry)
|
||||
assert registry.scope is not None
|
||||
assert registry.scope not in self.children, \
|
||||
f'scope {registry.scope} exists in {self.name} registry'
|
||||
self.children[registry.scope] = registry
|
||||
|
||||
def _register_module(self, module_class, module_name=None, force=False):
|
||||
if not inspect.isclass(module_class):
|
||||
|
@ -131,52 +314,3 @@ class Registry:
|
|||
return cls
|
||||
|
||||
return _register
|
||||
|
||||
|
||||
def build_from_cfg(cfg, registry, default_args=None):
|
||||
"""Build a module from config dict.
|
||||
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
||||
if 'type' not in cfg:
|
||||
if default_args is None or 'type' not in default_args:
|
||||
raise KeyError(
|
||||
'`cfg` or `default_args` must contain the key "type", '
|
||||
f'but got {cfg}\n{default_args}')
|
||||
if not isinstance(registry, Registry):
|
||||
raise TypeError('registry must be an mmcv.Registry object, '
|
||||
f'but got {type(registry)}')
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
args = cfg.copy()
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name} registry')
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
try:
|
||||
return obj_cls(**args)
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
raise type(e)(f'{obj_cls.__name__}: {e}')
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import torch.nn as nn
|
||||
|
||||
import mmcv
|
||||
from mmcv.cnn import MODELS, build_model_from_cfg
|
||||
|
||||
|
||||
def test_build_model_from_cfg():
|
||||
BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = [
|
||||
dict(type='ResNet', depth=50),
|
||||
dict(type='ResNeXt', depth=50, stages=3)
|
||||
]
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, nn.Sequential)
|
||||
assert isinstance(model[0], ResNet)
|
||||
assert model[0].depth == 50 and model[0].stages == 4
|
||||
assert isinstance(model[1], ResNeXt)
|
||||
assert model[1].depth == 50 and model[1].stages == 3
|
||||
|
||||
# test inherit `build_func` from parent
|
||||
NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
|
||||
assert NEW_MODELS.build_func is build_model_from_cfg
|
||||
|
||||
# test specify `build_func`
|
||||
def pseudo_build(cfg):
|
||||
return cfg
|
||||
|
||||
NEW_MODELS = mmcv.Registry(
|
||||
'models', parent=MODELS, build_func=pseudo_build)
|
||||
assert NEW_MODELS.build_func is pseudo_build
|
|
@ -132,6 +132,57 @@ def test_registry():
|
|||
# end: test old APIs
|
||||
|
||||
|
||||
def test_multi_scope_registry():
|
||||
DOGS = mmcv.Registry('dogs')
|
||||
assert DOGS.name == 'dogs'
|
||||
assert DOGS.scope == 'test_registry'
|
||||
assert DOGS.module_dict == {}
|
||||
assert len(DOGS) == 0
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
pass
|
||||
|
||||
assert len(DOGS) == 1
|
||||
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
||||
|
||||
HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound')
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
pass
|
||||
|
||||
assert len(HOUNDS) == 1
|
||||
assert HOUNDS.get('BloodHound') is BloodHound
|
||||
assert DOGS.get('hound.BloodHound') is BloodHound
|
||||
assert HOUNDS.get('hound.BloodHound') is BloodHound
|
||||
|
||||
LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound')
|
||||
|
||||
@LITTLE_HOUNDS.register_module()
|
||||
class Dachshund:
|
||||
pass
|
||||
|
||||
assert len(LITTLE_HOUNDS) == 1
|
||||
assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
|
||||
assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
|
||||
assert HOUNDS.get('little_hound.Dachshund') is Dachshund
|
||||
assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
|
||||
|
||||
MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound')
|
||||
|
||||
@MID_HOUNDS.register_module()
|
||||
class Beagle:
|
||||
pass
|
||||
|
||||
assert MID_HOUNDS.get('Beagle') is Beagle
|
||||
assert HOUNDS.get('mid_hound.Beagle') is Beagle
|
||||
assert DOGS.get('hound.mid_hound.Beagle') is Beagle
|
||||
assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
|
||||
assert MID_HOUNDS.get('hound.BloodHound') is BloodHound
|
||||
assert MID_HOUNDS.get('hound.Dachshund') is None
|
||||
|
||||
|
||||
def test_build_from_cfg():
|
||||
BACKBONES = mmcv.Registry('backbone')
|
||||
|
||||
|
|
Loading…
Reference in New Issue