[Enhance] Support building optimizer wrapper from built Optimizer instance (#422)
* support build optimizer wrapper from built Optimizer instance * refine commentspull/441/head
parent
a706bbc018
commit
e08b9031fc
|
@ -33,9 +33,9 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
|||
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||
build_optim_wrapper)
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
|
||||
LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
||||
RUNNERS, VISUALIZERS, DefaultScope,
|
||||
count_registered_modules)
|
||||
LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
|
||||
PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
|
||||
DefaultScope, count_registered_modules)
|
||||
from mmengine.registry.root import LOG_PROCESSORS
|
||||
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
|
||||
get_git_hash, is_list_of, is_seq_of,
|
||||
|
@ -1067,26 +1067,29 @@ class Runner:
|
|||
"""
|
||||
if isinstance(optim_wrapper, OptimWrapper):
|
||||
return optim_wrapper
|
||||
elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
|
||||
# If `optim_wrapper` is a config dict with only one optimizer,
|
||||
# the config dict must contain `optimizer`:
|
||||
# optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
|
||||
# `type` is optional, defaults to `OptimWrapper`.
|
||||
# `optim_wrapper` could also be defined as:
|
||||
# optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1)) # noqa: E501
|
||||
# to build specific optimizer wrapper.
|
||||
if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
|
||||
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
|
||||
return optim_wrapper
|
||||
elif 'constructor' not in optim_wrapper:
|
||||
# if `type` and `optimizer` are not defined in `optim_wrapper`,
|
||||
# it should be the case of training with multiple optimizers.
|
||||
# If constructor is not defined in `optim_wrapper`, each value
|
||||
# of `optim_wrapper` must be an `OptimWrapper` instance since
|
||||
# `DefaultOptimizerConstructor` will not handle the case of
|
||||
# training with multiple optimizers. `build_optim_wrapper` will
|
||||
# directly build the `OptimWrapperDict` instance from
|
||||
# `optim_wrapper.`
|
||||
if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
|
||||
# optimizer must be defined for single optimizer training.
|
||||
optimizer = optim_wrapper.get('optimizer', None)
|
||||
|
||||
# If optimizer is a built `Optimizer` instance, the optimizer
|
||||
# wrapper should be built by `OPTIM_WRAPPERS` registry.
|
||||
if isinstance(optimizer, Optimizer):
|
||||
optim_wrapper.setdefault('type', 'OptimWrapper')
|
||||
return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore
|
||||
|
||||
# If `optimizer` is not None or `constructor` is defined, it means,
|
||||
# optimizer wrapper will be built by optimizer wrapper
|
||||
# constructor. Therefore, `build_optim_wrapper` should be called.
|
||||
if optimizer is not None or 'constructor' in optim_wrapper:
|
||||
return build_optim_wrapper(self.model, optim_wrapper)
|
||||
else:
|
||||
# if `optimizer` is not defined, it should be the case of
|
||||
# training with multiple optimizers. If `constructor` is not
|
||||
# defined either, Each value of `optim_wrapper` must be an
|
||||
# `OptimWrapper` instance since `DefaultOptimizerConstructor`
|
||||
# will not handle the case of training with multiple
|
||||
# optimizers. `build_optim_wrapper` will directly build the
|
||||
# `OptimWrapperDict` instance from `optim_wrapper.`
|
||||
optim_wrappers = OrderedDict()
|
||||
for name, optim in optim_wrapper.items():
|
||||
if not isinstance(optim, OptimWrapper):
|
||||
|
@ -1096,11 +1099,6 @@ class Runner:
|
|||
f'optimizer, but got {name}={optim}')
|
||||
optim_wrappers[name] = optim
|
||||
return OptimWrapperDict(**optim_wrappers)
|
||||
# If constructor is defined, directly build the optimizer
|
||||
# wrapper instance from the config dict.
|
||||
else:
|
||||
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
|
||||
return optim_wrapper
|
||||
else:
|
||||
raise TypeError('optimizer wrapper should be an OptimWrapper '
|
||||
f'object or dict, but got {optim_wrapper}')
|
||||
|
|
|
@ -21,8 +21,9 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
|||
RuntimeInfoHook)
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||
OptimWrapper, OptimWrapperDict, StepLR)
|
||||
from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
|
||||
MultiStepLR, OptimWrapper, OptimWrapperDict,
|
||||
StepLR)
|
||||
from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
||||
|
@ -926,6 +927,23 @@ class TestRunner(TestCase):
|
|||
self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
|
||||
self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
|
||||
|
||||
# 2.4 input is a dict which contains optimizer instance.
|
||||
model = nn.Linear(1, 1)
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper_cfg = dict(optimizer=optimizer)
|
||||
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
self.assertIsInstance(optim_wrapper, OptimWrapper)
|
||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||
|
||||
model = nn.Linear(1, 1)
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper_cfg = dict(
|
||||
optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
|
||||
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
|
||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||
self.assertEqual(optim_wrapper._accumulative_counts, 2)
|
||||
|
||||
def test_build_param_scheduler(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_param_scheduler'
|
||||
|
|
Loading…
Reference in New Issue