[Enhance] Add build function for scheduler. (#372)
* add build function for scheduler * add unit test add unit test * handle convert_to_iter in build_scheduler_from_cfg * restore deleted code * format import * fix lintpull/414/head
parent
99de0951af
commit
a07a063306
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .build_functions import (build_from_cfg, build_model_from_cfg,
|
||||
build_runner_from_cfg)
|
||||
build_runner_from_cfg, build_scheduler_from_cfg)
|
||||
from .default_scope import DefaultScope
|
||||
from .registry import Registry
|
||||
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||
|
@ -17,5 +17,6 @@ __all__ = [
|
|||
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
|
||||
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
|
||||
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
|
||||
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg'
|
||||
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
|
||||
'build_scheduler_from_cfg'
|
||||
]
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..config import Config, ConfigDict
|
||||
from ..utils import ManagerMixin
|
||||
from .registry import Registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..optim.scheduler import _ParamScheduler
|
||||
from ..runner import Runner
|
||||
|
||||
|
||||
def build_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
|
@ -131,7 +137,7 @@ def build_from_cfg(
|
|||
|
||||
|
||||
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry) -> Any:
|
||||
registry: Registry) -> 'Runner':
|
||||
"""Build a Runner object.
|
||||
Examples:
|
||||
>>> from mmengine.registry import Registry, build_runner_from_cfg
|
||||
|
@ -203,7 +209,11 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
|||
f'{cls_location}.py: {e}')
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, registry, default_args=None):
|
||||
def build_model_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
|
||||
nn.Module:
|
||||
"""Build a PyTorch model from config dict(s). Different from
|
||||
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
|
||||
|
||||
|
@ -226,3 +236,69 @@ def build_model_from_cfg(cfg, registry, default_args=None):
|
|||
return Sequential(*modules)
|
||||
else:
|
||||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
def build_scheduler_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
|
||||
'_ParamScheduler':
|
||||
"""Builds a ``ParamScheduler`` instance from config.
|
||||
|
||||
``ParamScheduler`` supports building instance by its constructor or
|
||||
method ``build_iter_from_epoch``. Therefore, its registry needs a build
|
||||
function to handle both cases.
|
||||
|
||||
Args:
|
||||
cfg (dict or ConfigDict or Config): Config dictionary. If it contains
|
||||
the key ``convert_to_iter_based``, instance will be built by method
|
||||
``convert_to_iter_based``, otherwise instance will be built by its
|
||||
constructor.
|
||||
registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
|
||||
default_args (dict or ConfigDict or Config, optional): Default
|
||||
initialization arguments. It must contain key ``optimizer``. If
|
||||
``convert_to_iter_based`` is defined in ``cfg``, it must
|
||||
additionally contain key ``epoch_length``. Defaults to None.
|
||||
|
||||
Returns:
|
||||
object: The constructed ``ParamScheduler``.
|
||||
"""
|
||||
assert isinstance(
|
||||
cfg,
|
||||
(dict, ConfigDict, Config
|
||||
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
|
||||
assert isinstance(
|
||||
registry, Registry), ('registry should be a mmengine.Registry object',
|
||||
f'but got {type(registry)}')
|
||||
|
||||
args = cfg.copy()
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
scope = args.pop('_scope_', None)
|
||||
with registry.switch_scope_and_registry(scope) as registry:
|
||||
convert_to_iter = args.pop('convert_to_iter_based', False)
|
||||
if convert_to_iter:
|
||||
scheduler_type = args.pop('type')
|
||||
assert 'epoch_length' in args and args.get('by_epoch', True), (
|
||||
'Only epoch-based parameter scheduler can be converted to '
|
||||
'iter-based, and `epoch_length` should be set')
|
||||
if isinstance(scheduler_type, str):
|
||||
scheduler_cls = registry.get(scheduler_type)
|
||||
if scheduler_cls is None:
|
||||
raise KeyError(
|
||||
f'{scheduler_type} is not in the {registry.name} '
|
||||
'registry. Please check whether the value of '
|
||||
f'`{scheduler_type}` is correct or it was registered '
|
||||
'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||
)
|
||||
elif inspect.isclass(scheduler_type):
|
||||
scheduler_cls = scheduler_type
|
||||
else:
|
||||
raise TypeError('type must be a str or valid type, but got '
|
||||
f'{type(scheduler_type)}')
|
||||
return scheduler_cls.build_iter_from_epoch( # type: ignore
|
||||
**args)
|
||||
else:
|
||||
args.pop('epoch_length', None)
|
||||
return build_from_cfg(args, registry)
|
||||
|
|
|
@ -6,7 +6,8 @@ More datails can be found at
|
|||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
"""
|
||||
|
||||
from mmengine.registry import build_model_from_cfg, build_runner_from_cfg
|
||||
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
|
||||
build_scheduler_from_cfg)
|
||||
from .registry import Registry
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
|
@ -37,7 +38,8 @@ OPTIM_WRAPPERS = Registry('optim_wrapper')
|
|||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
|
||||
# mangage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry('parameter scheduler')
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', build_func=build_scheduler_from_cfg)
|
||||
|
||||
# manage all kinds of metrics
|
||||
METRICS = Registry('metric')
|
||||
|
|
|
@ -1134,34 +1134,16 @@ class Runner:
|
|||
f'The `end` of {_scheduler["type"]} is not set. '
|
||||
'Use the max epochs/iters of train loop as default.')
|
||||
|
||||
convert_to_iter = _scheduler.pop('convert_to_iter_based',
|
||||
False)
|
||||
if convert_to_iter:
|
||||
assert _scheduler.get(
|
||||
'by_epoch',
|
||||
True), ('only epoch-based parameter scheduler can be '
|
||||
'converted to iter-based')
|
||||
assert isinstance(self._train_loop, BaseLoop), \
|
||||
'Scheduler can only be converted to iter-based ' \
|
||||
'when train loop is built.'
|
||||
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
|
||||
param_schedulers.append(
|
||||
cls.build_iter_from_epoch( # type: ignore
|
||||
param_schedulers.append(
|
||||
PARAM_SCHEDULERS.build(
|
||||
_scheduler,
|
||||
default_args=dict(
|
||||
optimizer=optim_wrapper,
|
||||
**_scheduler,
|
||||
epoch_length=len(
|
||||
self.train_dataloader), # type: ignore
|
||||
))
|
||||
else:
|
||||
param_schedulers.append(
|
||||
PARAM_SCHEDULERS.build(
|
||||
_scheduler,
|
||||
default_args=dict(optimizer=optim_wrapper)))
|
||||
epoch_length=len(self.train_dataloader))))
|
||||
else:
|
||||
raise TypeError(
|
||||
'scheduler should be a _ParamScheduler object or dict, '
|
||||
f'but got {scheduler}')
|
||||
|
||||
return param_schedulers
|
||||
|
||||
def build_param_scheduler(
|
||||
|
|
|
@ -0,0 +1,213 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
|
||||
Registry, build_from_cfg, build_model_from_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||
def test_build_from_cfg(cfg_type):
|
||||
BACKBONES = Registry('backbone')
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
# test `cfg` parameter
|
||||
# `cfg` should be a dict, ConfigDict or Config object
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('cfg should be a dict, ConfigDict or Config, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = 'ResNet'
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` is a dict, ConfigDict or Config object
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a dict but it does not contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50, stages=4)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# cfg['type'] should be a str or class
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="type must be a str or valid type, but got <class 'int'>"):
|
||||
cfg = dict(type=1000)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = cfg_type(dict(type=ResNet, depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# non-registered class
|
||||
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
||||
cfg = cfg_type(dict(type='VGG'))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` contains unexpected arguments
|
||||
with pytest.raises(TypeError):
|
||||
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# test `default_args` parameter
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
# default_args must be a dict or None
|
||||
with pytest.raises(TypeError):
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
||||
|
||||
# cfg or default_args should contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
|
||||
|
||||
# "type" defined using default_args
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# test `registry` parameter
|
||||
# incorrect registry type
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('registry must be a mmengine.Registry object, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, 'BACKBONES')
|
||||
|
||||
VISUALIZER = Registry('visualizer')
|
||||
|
||||
@VISUALIZER.register_module()
|
||||
class Visualizer(ManagerMixin):
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Visualizer.get_current_instance()
|
||||
cfg = dict(type='Visualizer', name='visualizer')
|
||||
build_from_cfg(cfg, VISUALIZER)
|
||||
Visualizer.get_current_instance()
|
||||
|
||||
|
||||
def test_build_model_from_cfg():
|
||||
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = [
|
||||
dict(type='ResNet', depth=50),
|
||||
dict(type='ResNeXt', depth=50, stages=3)
|
||||
]
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, nn.Sequential)
|
||||
assert isinstance(model[0], ResNet)
|
||||
assert model[0].depth == 50 and model[0].stages == 4
|
||||
assert isinstance(model[1], ResNeXt)
|
||||
assert model[1].depth == 50 and model[1].stages == 3
|
||||
|
||||
# test inherit `build_func` from parent
|
||||
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
|
||||
assert NEW_MODELS.build_func is build_model_from_cfg
|
||||
|
||||
# test specify `build_func`
|
||||
def pseudo_build(cfg):
|
||||
return cfg
|
||||
|
||||
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
|
||||
assert NEW_MODELS.build_func is pseudo_build
|
||||
|
||||
|
||||
def test_build_sheduler_from_cfg():
|
||||
model = nn.Conv2d(1, 1, 1)
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
cfg = dict(
|
||||
type='LinearParamScheduler',
|
||||
optimizer=optimizer,
|
||||
param_name='lr',
|
||||
begin=0,
|
||||
end=100)
|
||||
sheduler = PARAM_SCHEDULERS.build(cfg)
|
||||
assert sheduler.begin == 0
|
||||
assert sheduler.end == 100
|
||||
|
||||
cfg = dict(
|
||||
type='LinearParamScheduler',
|
||||
convert_to_iter_based=True,
|
||||
optimizer=optimizer,
|
||||
param_name='lr',
|
||||
begin=0,
|
||||
end=100,
|
||||
epoch_length=10)
|
||||
|
||||
sheduler = PARAM_SCHEDULERS.build(cfg)
|
||||
assert sheduler.begin == 0
|
||||
assert sheduler.end == 1000
|
|
@ -2,12 +2,9 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.config import Config, ConfigDict # type: ignore
|
||||
from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
|
||||
build_model_from_cfg)
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.registry import DefaultScope, Registry, build_from_cfg
|
||||
|
||||
|
||||
class TestRegistry:
|
||||
|
@ -476,182 +473,3 @@ class TestRegistry:
|
|||
"<locals>.Munchkin'>")
|
||||
repr_str += '})'
|
||||
assert repr(CATS) == repr_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
||||
def test_build_from_cfg(cfg_type):
|
||||
BACKBONES = Registry('backbone')
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
# test `cfg` parameter
|
||||
# `cfg` should be a dict, ConfigDict or Config object
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('cfg should be a dict, ConfigDict or Config, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = 'ResNet'
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` is a dict, ConfigDict or Config object
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a dict but it does not contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50, stages=4)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# cfg['type'] should be a str or class
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="type must be a str or valid type, but got <class 'int'>"):
|
||||
cfg = dict(type=1000)
|
||||
cfg = cfg_type(cfg)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = cfg_type(dict(type=ResNet, depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# non-registered class
|
||||
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
||||
cfg = cfg_type(dict(type='VGG'))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` contains unexpected arguments
|
||||
with pytest.raises(TypeError):
|
||||
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# test `default_args` parameter
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
# default_args must be a dict or None
|
||||
with pytest.raises(TypeError):
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
||||
|
||||
# cfg or default_args should contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
|
||||
|
||||
# "type" defined using default_args
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = cfg_type(dict(depth=50))
|
||||
model = build_from_cfg(
|
||||
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# test `registry` parameter
|
||||
# incorrect registry type
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('registry must be a mmengine.Registry object, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = cfg_type(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, 'BACKBONES')
|
||||
|
||||
VISUALIZER = Registry('visualizer')
|
||||
|
||||
@VISUALIZER.register_module()
|
||||
class Visualizer(ManagerMixin):
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Visualizer.get_current_instance()
|
||||
cfg = dict(type='Visualizer', name='visualizer')
|
||||
build_from_cfg(cfg, VISUALIZER)
|
||||
Visualizer.get_current_instance()
|
||||
|
||||
|
||||
def test_build_model_from_cfg():
|
||||
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt(nn.Module):
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = [
|
||||
dict(type='ResNet', depth=50),
|
||||
dict(type='ResNeXt', depth=50, stages=3)
|
||||
]
|
||||
model = BACKBONES.build(cfg)
|
||||
assert isinstance(model, nn.Sequential)
|
||||
assert isinstance(model[0], ResNet)
|
||||
assert model[0].depth == 50 and model[0].stages == 4
|
||||
assert isinstance(model[1], ResNeXt)
|
||||
assert model[1].depth == 50 and model[1].stages == 3
|
||||
|
||||
# test inherit `build_func` from parent
|
||||
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
|
||||
assert NEW_MODELS.build_func is build_model_from_cfg
|
||||
|
||||
# test specify `build_func`
|
||||
def pseudo_build(cfg):
|
||||
return cfg
|
||||
|
||||
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
|
||||
assert NEW_MODELS.build_func is pseudo_build
|
||||
|
|
|
@ -993,11 +993,6 @@ class TestRunner(TestCase):
|
|||
# 5.1 train loop should be built before converting scheduler
|
||||
cfg = dict(
|
||||
type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
'Scheduler can only be converted to iter-based when '
|
||||
'train loop is built.'):
|
||||
runner.build_param_scheduler(cfg)
|
||||
|
||||
# 5.2 convert epoch-based to iter-based scheduler
|
||||
cfg = dict(
|
||||
|
|
Loading…
Reference in New Issue