[Enhance] Support building optimizer wrapper from built Optimizer instance (#422)

* support building optimizer wrapper from a built Optimizer instance

* refine comments
Mashiro 2022-08-17 19:17:00 +08:00 committed by GitHub
parent a706bbc018
commit e08b9031fc
2 changed files with 46 additions and 30 deletions
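
In short, `optim_wrapper.optimizer` may now be an already-constructed `torch.optim.Optimizer` instance rather than only a config dict. A minimal sketch of the new usage (the model here is illustrative, not part of the diff):

    import torch.nn as nn
    from torch.optim import SGD

    # Illustrative model; any nn.Module works.
    model = nn.Linear(1, 1)

    # Previously, the optimizer had to be described by a config dict:
    #     optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
    # With this change, a built Optimizer instance is also accepted;
    # `type` still defaults to 'OptimWrapper'.
    optimizer = SGD(model.parameters(), lr=0.1)
    optim_wrapper = dict(optimizer=optimizer)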

mmengine/runner/runner.py

@@ -33,9 +33,9 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
-                               LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               RUNNERS, VISUALIZERS, DefaultScope,
-                               count_registered_modules)
+                               LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
+                               PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
+                               DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
                             get_git_hash, is_list_of, is_seq_of,
@@ -1067,26 +1067,29 @@ class Runner:
         """
         if isinstance(optim_wrapper, OptimWrapper):
             return optim_wrapper
-        elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
-            # If `optim_wrapper` is a config dict with only one optimizer,
-            # the config dict must contain `optimizer`:
-            # optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
-            # `type` is optional, defaults to `OptimWrapper`.
-            # `optim_wrapper` could also be defined as:
-            # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1))  # noqa: E501
-            # to build a specific optimizer wrapper.
-            if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
-            elif 'constructor' not in optim_wrapper:
-                # If `type` and `optimizer` are not defined in `optim_wrapper`,
-                # it should be the case of training with multiple optimizers.
-                # If `constructor` is not defined in `optim_wrapper`, each value
-                # of `optim_wrapper` must be an `OptimWrapper` instance since
-                # `DefaultOptimizerConstructor` will not handle the case of
-                # training with multiple optimizers. `build_optim_wrapper` will
-                # directly build the `OptimWrapperDict` instance from
-                # `optim_wrapper`.
+        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
+            # optimizer must be defined for single-optimizer training.
+            optimizer = optim_wrapper.get('optimizer', None)
+
+            # If optimizer is a built `Optimizer` instance, the optimizer
+            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
+            if isinstance(optimizer, Optimizer):
+                optim_wrapper.setdefault('type', 'OptimWrapper')
+                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore
+
+            # If `optimizer` is a config dict or `constructor` is defined,
+            # the optimizer wrapper will be built by an optimizer wrapper
+            # constructor, so `build_optim_wrapper` should be called.
+            if optimizer is not None or 'constructor' in optim_wrapper:
+                return build_optim_wrapper(self.model, optim_wrapper)
+            else:
+                # If `optimizer` is not defined, it should be the case of
+                # training with multiple optimizers. If `constructor` is not
+                # defined either, each value of `optim_wrapper` must be an
+                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
+                # will not handle the case of training with multiple
+                # optimizers. `build_optim_wrapper` will directly build the
+                # `OptimWrapperDict` instance from `optim_wrapper`.
                 optim_wrappers = OrderedDict()
                 for name, optim in optim_wrapper.items():
                     if not isinstance(optim, OptimWrapper):
@@ -1096,11 +1099,6 @@ class Runner:
                             f'optimizer, but got {name}={optim}')
                     optim_wrappers[name] = optim
                 return OptimWrapperDict(**optim_wrappers)
-            # If `constructor` is defined, directly build the optimizer
-            # wrapper instance from the config dict.
-            else:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')
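
For context, the multi-optimizer branch kept above accepts a plain dict whose values are already-built `OptimWrapper` instances. A minimal sketch (the two-branch model and the `linear1`/`linear2` names are illustrative, mirroring the tests below):

    import torch.nn as nn
    from torch.optim import SGD, Adam
    from mmengine.optim import OptimWrapper

    # Illustrative two-branch model.
    model = nn.ModuleDict(dict(linear1=nn.Linear(1, 1), linear2=nn.Linear(1, 1)))

    # No `optimizer` or `constructor` key, so every value must already be an
    # OptimWrapper; Runner.build_optim_wrapper bundles them into an
    # OptimWrapperDict.
    optim_wrapper = dict(
        linear1=OptimWrapper(SGD(model['linear1'].parameters(), lr=0.1)),
        linear2=OptimWrapper(Adam(model['linear2'].parameters(), lr=0.01)))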

tests/test_runner/test_runner.py

@@ -21,8 +21,9 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
-                            OptimWrapper, OptimWrapperDict, StepLR)
+from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
+                            MultiStepLR, OptimWrapper, OptimWrapperDict,
+                            StepLR)
 from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
                                OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
@@ -926,6 +927,23 @@ class TestRunner(TestCase):
         self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
         self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
 
+        # 2.4 input is a dict which contains a built optimizer instance.
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(optimizer=optimizer)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(
+            optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+        self.assertEqual(optim_wrapper._accumulative_counts, 2)
+
     def test_build_param_scheduler(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_param_scheduler'
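
As the second new test shows, wrapper-level arguments such as `accumulative_counts` pass through the registry build untouched. That option enables gradient accumulation on any wrapper type; a minimal sketch with the plain `OptimWrapper` (model, loop, and values are illustrative, not from the diff):

    import torch
    import torch.nn as nn
    from torch.optim import SGD
    from mmengine.optim import OptimWrapper

    model = nn.Linear(1, 1)
    optim_wrapper = OptimWrapper(
        SGD(model.parameters(), lr=0.1), accumulative_counts=2)

    # With accumulative_counts=2, update_params scales each loss for
    # accumulation and only steps the optimizer every second call.
    for _ in range(4):
        loss = model(torch.randn(2, 1)).sum()
        optim_wrapper.update_params(loss)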