[Enhance] Add build function for scheduler. (#372)

* add build function for scheduler

* add unit test

add unit test

* handle convert_to_iter in build_scheduler_from_cfg

* restore deleted code

* format import

* fix lint
pull/414/head
Mashiro 2022-08-08 20:34:16 +08:00 committed by GitHub
parent 99de0951af
commit a07a063306
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 305 additions and 218 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .build_functions import (build_from_cfg, build_model_from_cfg,
build_runner_from_cfg)
build_runner_from_cfg, build_scheduler_from_cfg)
from .default_scope import DefaultScope
from .registry import Registry
from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
@ -17,5 +17,6 @@ __all__ = [
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg'
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
'build_scheduler_from_cfg'
]

View File

@ -1,12 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import logging
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch.nn as nn
from ..config import Config, ConfigDict
from ..utils import ManagerMixin
from .registry import Registry
if TYPE_CHECKING:
from ..optim.scheduler import _ParamScheduler
from ..runner import Runner
def build_from_cfg(
cfg: Union[dict, ConfigDict, Config],
@ -131,7 +137,7 @@ def build_from_cfg(
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
registry: Registry) -> Any:
registry: Registry) -> 'Runner':
"""Build a Runner object.
Examples:
>>> from mmengine.registry import Registry, build_runner_from_cfg
@ -203,7 +209,11 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
f'{cls_location}.py: {e}')
def build_model_from_cfg(cfg, registry, default_args=None):
def build_model_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
nn.Module:
"""Build a PyTorch model from config dict(s). Different from
``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
@ -226,3 +236,69 @@ def build_model_from_cfg(cfg, registry, default_args=None):
return Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
'_ParamScheduler':
"""Builds a ``ParamScheduler`` instance from config.
``ParamScheduler`` supports building instance by its constructor or
method ``build_iter_from_epoch``. Therefore, its registry needs a build
function to handle both cases.
Args:
cfg (dict or ConfigDict or Config): Config dictionary. If it contains
the key ``convert_to_iter_based``, instance will be built by method
``convert_to_iter_based``, otherwise instance will be built by its
constructor.
registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
default_args (dict or ConfigDict or Config, optional): Default
initialization arguments. It must contain key ``optimizer``. If
``convert_to_iter_based`` is defined in ``cfg``, it must
additionally contain key ``epoch_length``. Defaults to None.
Returns:
object: The constructed ``ParamScheduler``.
"""
assert isinstance(
cfg,
(dict, ConfigDict, Config
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
assert isinstance(
registry, Registry), ('registry should be a mmengine.Registry object',
f'but got {type(registry)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
scope = args.pop('_scope_', None)
with registry.switch_scope_and_registry(scope) as registry:
convert_to_iter = args.pop('convert_to_iter_based', False)
if convert_to_iter:
scheduler_type = args.pop('type')
assert 'epoch_length' in args and args.get('by_epoch', True), (
'Only epoch-based parameter scheduler can be converted to '
'iter-based, and `epoch_length` should be set')
if isinstance(scheduler_type, str):
scheduler_cls = registry.get(scheduler_type)
if scheduler_cls is None:
raise KeyError(
f'{scheduler_type} is not in the {registry.name} '
'registry. Please check whether the value of '
f'`{scheduler_type}` is correct or it was registered '
'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
)
elif inspect.isclass(scheduler_type):
scheduler_cls = scheduler_type
else:
raise TypeError('type must be a str or valid type, but got '
f'{type(scheduler_type)}')
return scheduler_cls.build_iter_from_epoch( # type: ignore
**args)
else:
args.pop('epoch_length', None)
return build_from_cfg(args, registry)

View File

@ -6,7 +6,8 @@ More datails can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
from mmengine.registry import build_model_from_cfg, build_runner_from_cfg
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
build_scheduler_from_cfg)
from .registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@ -37,7 +38,8 @@ OPTIM_WRAPPERS = Registry('optim_wrapper')
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
PARAM_SCHEDULERS = Registry(
'parameter scheduler', build_func=build_scheduler_from_cfg)
# manage all kinds of metrics
METRICS = Registry('metric')

View File

@ -1134,34 +1134,16 @@ class Runner:
f'The `end` of {_scheduler["type"]} is not set. '
'Use the max epochs/iters of train loop as default.')
convert_to_iter = _scheduler.pop('convert_to_iter_based',
False)
if convert_to_iter:
assert _scheduler.get(
'by_epoch',
True), ('only epoch-based parameter scheduler can be '
'converted to iter-based')
assert isinstance(self._train_loop, BaseLoop), \
'Scheduler can only be converted to iter-based ' \
'when train loop is built.'
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
param_schedulers.append(
cls.build_iter_from_epoch( # type: ignore
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(
optimizer=optim_wrapper,
**_scheduler,
epoch_length=len(
self.train_dataloader), # type: ignore
))
else:
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(optimizer=optim_wrapper)))
epoch_length=len(self.train_dataloader))))
else:
raise TypeError(
'scheduler should be a _ParamScheduler object or dict, '
f'but got {scheduler}')
return param_schedulers
def build_param_scheduler(

View File

@ -0,0 +1,213 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch.nn as nn
from torch.optim import SGD
from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
Registry, build_from_cfg, build_model_from_cfg)
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
BACKBONES = Registry('backbone')
@BACKBONES.register_module()
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module()
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
# test `cfg` parameter
# `cfg` should be a dict, ConfigDict or Config object
with pytest.raises(
TypeError,
match=('cfg should be a dict, ConfigDict or Config, but got '
"<class 'str'>")):
cfg = 'ResNet'
model = build_from_cfg(cfg, BACKBONES)
# `cfg` is a dict, ConfigDict or Config object
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a dict but it does not contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
# cfg['type'] should be a str or class
with pytest.raises(
TypeError,
match="type must be a str or valid type, but got <class 'int'>"):
cfg = dict(type=1000)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = cfg_type(dict(type=ResNet, depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
cfg = cfg_type(dict(type='VGG'))
model = build_from_cfg(cfg, BACKBONES)
# `cfg` contains unexpected arguments
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
model = build_from_cfg(cfg, BACKBONES)
# test `default_args` parameter
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
# "type" defined using default_args
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# test `registry` parameter
# incorrect registry type
with pytest.raises(
TypeError,
match=('registry must be a mmengine.Registry object, but got '
"<class 'str'>")):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, 'BACKBONES')
VISUALIZER = Registry('visualizer')
@VISUALIZER.register_module()
class Visualizer(ManagerMixin):
def __init__(self, name):
super().__init__(name)
with pytest.raises(RuntimeError):
Visualizer.get_current_instance()
cfg = dict(type='Visualizer', name='visualizer')
build_from_cfg(cfg, VISUALIZER)
Visualizer.get_current_instance()
def test_build_model_from_cfg():
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
class ResNet(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
@BACKBONES.register_module()
class ResNeXt(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
cfg = dict(type='ResNet', depth=50)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = [
dict(type='ResNet', depth=50),
dict(type='ResNeXt', depth=50, stages=3)
]
model = BACKBONES.build(cfg)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], ResNet)
assert model[0].depth == 50 and model[0].stages == 4
assert isinstance(model[1], ResNeXt)
assert model[1].depth == 50 and model[1].stages == 3
# test inherit `build_func` from parent
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
assert NEW_MODELS.build_func is build_model_from_cfg
# test specify `build_func`
def pseudo_build(cfg):
return cfg
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
assert NEW_MODELS.build_func is pseudo_build
def test_build_sheduler_from_cfg():
model = nn.Conv2d(1, 1, 1)
optimizer = SGD(model.parameters(), lr=0.1)
cfg = dict(
type='LinearParamScheduler',
optimizer=optimizer,
param_name='lr',
begin=0,
end=100)
sheduler = PARAM_SCHEDULERS.build(cfg)
assert sheduler.begin == 0
assert sheduler.end == 100
cfg = dict(
type='LinearParamScheduler',
convert_to_iter_based=True,
optimizer=optimizer,
param_name='lr',
begin=0,
end=100,
epoch_length=10)
sheduler = PARAM_SCHEDULERS.build(cfg)
assert sheduler.begin == 0
assert sheduler.end == 1000

View File

@ -2,12 +2,9 @@
import time
import pytest
import torch.nn as nn
from mmengine.config import Config, ConfigDict # type: ignore
from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
build_model_from_cfg)
from mmengine.utils import ManagerMixin
from mmengine.registry import DefaultScope, Registry, build_from_cfg
class TestRegistry:
@ -476,182 +473,3 @@ class TestRegistry:
"<locals>.Munchkin'>")
repr_str += '})'
assert repr(CATS) == repr_str
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
BACKBONES = Registry('backbone')
@BACKBONES.register_module()
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module()
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
# test `cfg` parameter
# `cfg` should be a dict, ConfigDict or Config object
with pytest.raises(
TypeError,
match=('cfg should be a dict, ConfigDict or Config, but got '
"<class 'str'>")):
cfg = 'ResNet'
model = build_from_cfg(cfg, BACKBONES)
# `cfg` is a dict, ConfigDict or Config object
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# `cfg` is a dict but it does not contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
# cfg['type'] should be a str or class
with pytest.raises(
TypeError,
match="type must be a str or valid type, but got <class 'int'>"):
cfg = dict(type=1000)
cfg = cfg_type(cfg)
model = build_from_cfg(cfg, BACKBONES)
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = cfg_type(dict(type=ResNet, depth=50))
model = build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
cfg = cfg_type(dict(type='VGG'))
model = build_from_cfg(cfg, BACKBONES)
# `cfg` contains unexpected arguments
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
model = build_from_cfg(cfg, BACKBONES)
# test `default_args` parameter
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
# "type" defined using default_args
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = cfg_type(dict(depth=50))
model = build_from_cfg(
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# test `registry` parameter
# incorrect registry type
with pytest.raises(
TypeError,
match=('registry must be a mmengine.Registry object, but got '
"<class 'str'>")):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, 'BACKBONES')
VISUALIZER = Registry('visualizer')
@VISUALIZER.register_module()
class Visualizer(ManagerMixin):
def __init__(self, name):
super().__init__(name)
with pytest.raises(RuntimeError):
Visualizer.get_current_instance()
cfg = dict(type='Visualizer', name='visualizer')
build_from_cfg(cfg, VISUALIZER)
Visualizer.get_current_instance()
def test_build_model_from_cfg():
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
class ResNet(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
@BACKBONES.register_module()
class ResNeXt(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
cfg = dict(type='ResNet', depth=50)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = [
dict(type='ResNet', depth=50),
dict(type='ResNeXt', depth=50, stages=3)
]
model = BACKBONES.build(cfg)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], ResNet)
assert model[0].depth == 50 and model[0].stages == 4
assert isinstance(model[1], ResNeXt)
assert model[1].depth == 50 and model[1].stages == 3
# test inherit `build_func` from parent
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
assert NEW_MODELS.build_func is build_model_from_cfg
# test specify `build_func`
def pseudo_build(cfg):
return cfg
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
assert NEW_MODELS.build_func is pseudo_build

View File

@ -993,11 +993,6 @@ class TestRunner(TestCase):
# 5.1 train loop should be built before converting scheduler
cfg = dict(
type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
with self.assertRaisesRegex(
AssertionError,
'Scheduler can only be converted to iter-based when '
'train loop is built.'):
runner.build_param_scheduler(cfg)
# 5.2 convert epoch-based to iter-based scheduler
cfg = dict(