Mirror of https://github.com/open-mmlab/mmengine.git
[Enhance] Support building optimizer wrapper from built Optimizer instance (#422)
* support building optimizer wrapper from built Optimizer instance
* refine comments
Parent: a706bbc018
Commit: e08b9031fc
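What this enables: previously `optim_wrapper.optimizer` had to be a config dict such as dict(type='SGD', lr=0.1); after this commit it may also be an already-constructed torch.optim.Optimizer. A minimal sketch of the new path, mirroring the branch added to `Runner.build_optim_wrapper` below:

import torch.nn as nn
from torch.optim import SGD

from mmengine.registry import OPTIM_WRAPPERS

model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# When `optimizer` is a built instance, the config is sent straight to the
# OPTIM_WRAPPERS registry, with `type` defaulting to 'OptimWrapper'.
cfg = dict(optimizer=optimizer)
cfg.setdefault('type', 'OptimWrapper')
optim_wrapper = OPTIM_WRAPPERS.build(cfg)
assert optim_wrapper.optimizer is optimizer  # the instance is kept as-is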
mmengine/runner/runner.py
@@ -33,9 +33,9 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
-                               LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               RUNNERS, VISUALIZERS, DefaultScope,
-                               count_registered_modules)
+                               LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
+                               PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
+                               DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
                             get_git_hash, is_list_of, is_seq_of,
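Why `OPTIM_WRAPPERS` joins the imports above: the new branch builds the wrapper through this registry instead of the constructor, so any custom wrapper registered there also becomes reachable from a config dict carrying a prebuilt optimizer. A hedged sketch; `LoggingOptimWrapper` is a hypothetical name, not part of this commit:

from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS


@OPTIM_WRAPPERS.register_module()
class LoggingOptimWrapper(OptimWrapper):
    """Hypothetical wrapper that prints the loss before each update."""

    def update_params(self, loss):
        print(f'loss: {loss.item():.4f}')
        super().update_params(loss)

A config like dict(type='LoggingOptimWrapper', optimizer=optimizer) would then build it through the same code path.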
@@ -1067,26 +1067,29 @@ class Runner:
         """
         if isinstance(optim_wrapper, OptimWrapper):
             return optim_wrapper
-        elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
-            # If `optim_wrapper` is a config dict with only one optimizer,
-            # the config dict must contain `optimizer`:
-            # optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
-            # `type` is optional, defaults to `OptimWrapper`.
-            # `optim_wrapper` could also be defined as:
-            # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1))  # noqa: E501
-            # to build specific optimizer wrapper.
-            if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
-            elif 'constructor' not in optim_wrapper:
-                # if `type` and `optimizer` are not defined in `optim_wrapper`,
-                # it should be the case of training with multiple optimizers.
-                # If constructor is not defined in `optim_wrapper`, each value
-                # of `optim_wrapper` must be an `OptimWrapper` instance since
-                # `DefaultOptimizerConstructor` will not handle the case of
-                # training with multiple optimizers. `build_optim_wrapper` will
-                # directly build the `OptimWrapperDict` instance from
-                # `optim_wrapper`.
+        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
+            # `optimizer` must be defined for single-optimizer training.
+            optimizer = optim_wrapper.get('optimizer', None)
+
+            # If `optimizer` is a built `Optimizer` instance, the optimizer
+            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
+            if isinstance(optimizer, Optimizer):
+                optim_wrapper.setdefault('type', 'OptimWrapper')
+                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore
+
+            # If `optimizer` is not None or `constructor` is defined, the
+            # optimizer wrapper will be built by the optimizer wrapper
+            # constructor. Therefore, `build_optim_wrapper` should be called.
+            if optimizer is not None or 'constructor' in optim_wrapper:
+                return build_optim_wrapper(self.model, optim_wrapper)
+            else:
+                # If `optimizer` is not defined, it should be the case of
+                # training with multiple optimizers. If `constructor` is not
+                # defined either, each value of `optim_wrapper` must be an
+                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
+                # will not handle the case of training with multiple
+                # optimizers. `build_optim_wrapper` will directly build the
+                # `OptimWrapperDict` instance from `optim_wrapper`.
                 optim_wrappers = OrderedDict()
                 for name, optim in optim_wrapper.items():
                     if not isinstance(optim, OptimWrapper):
@@ -1096,11 +1099,6 @@ class Runner:
                             f'optimizer, but got {name}={optim}')
                     optim_wrappers[name] = optim
                 return OptimWrapperDict(**optim_wrappers)
-            # If constructor is defined, directly build the optimizer
-            # wrapper instance from the config dict.
-            else:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')
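With the old constructor-only branch removed, the dict case now splits cleanly: a built Optimizer goes to the registry, a config-style `optimizer` or explicit `constructor` goes to `build_optim_wrapper`, and anything else is treated as multi-optimizer training whose values must already be OptimWrapper instances. A minimal sketch of that last path; the generator/discriminator names are illustrative only:

import torch.nn as nn
from torch.optim import SGD, Adam

from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
optim_wrapper = dict(
    generator=OptimWrapper(SGD(gen.parameters(), lr=0.1)),
    discriminator=OptimWrapper(Adam(disc.parameters(), lr=1e-3)))

# Runner.build_optim_wrapper returns the equivalent of:
wrappers = OptimWrapperDict(**optim_wrapper)
assert isinstance(wrappers['generator'], OptimWrapper)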
tests/test_runner/test_runner.py
@@ -21,8 +21,9 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
-                            OptimWrapper, OptimWrapperDict, StepLR)
+from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
+                            MultiStepLR, OptimWrapper, OptimWrapperDict,
+                            StepLR)
 from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
                                OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
@@ -926,6 +927,23 @@ class TestRunner(TestCase):
         self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
         self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
 
+        # 2.4 input is a dict which contains a built optimizer instance.
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(optimizer=optimizer)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(
+            optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+        self.assertEqual(optim_wrapper._accumulative_counts, 2)
+
     def test_build_param_scheduler(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_param_scheduler'
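For context on the final assertion: accumulative_counts=2 makes the wrapper accumulate gradients and step the underlying optimizer once every two update_params calls. A minimal standalone sketch using the plain OptimWrapper (AmpOptimWrapper behaves the same, but additionally applies mixed-precision loss scaling):

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
wrapper = OptimWrapper(
    SGD(model.parameters(), lr=0.1), accumulative_counts=2)
for _ in range(4):
    loss = model(torch.randn(4, 1)).mean()
    # backward() runs on every call; optimizer.step() and zero_grad()
    # fire only on every second call.
    wrapper.update_params(loss)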