[Enhance] Support building optimizer wrapper from built Optimizer instance (#422)

* support building an optimizer wrapper from a built Optimizer instance

* refine comments
Mashiro authored 2022-08-17 19:17:00 +08:00 · committed by GitHub
parent a706bbc018 · commit e08b9031fc
2 changed files with 46 additions and 30 deletions
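
In short, an `optim_wrapper` config may now carry an already-built `torch.optim.Optimizer` where it previously required an optimizer config dict. A minimal sketch of the new behaviour, mirroring the registry path added to `Runner.build_optim_wrapper` below (the `nn.Linear` model and variable names are illustrative, not part of this commit):

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS

model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Before this commit the config had to describe the optimizer, e.g.
# dict(optimizer=dict(type='SGD', lr=0.1)). A built instance already
# holds its parameter groups, so no constructor pass over the model is
# needed; the wrapper is built straight from the OPTIM_WRAPPERS
# registry ('type' defaults to 'OptimWrapper').
cfg = dict(type='OptimWrapper', optimizer=optimizer)
wrapper = OPTIM_WRAPPERS.build(cfg)
assert isinstance(wrapper, OptimWrapper)
assert wrapper.optimizer is optimizer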

mmengine/runner/runner.py

@@ -33,9 +33,9 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
-                               LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               RUNNERS, VISUALIZERS, DefaultScope,
-                               count_registered_modules)
+                               LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
+                               PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
+                               DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
                             get_git_hash, is_list_of, is_seq_of,
@@ -1067,26 +1067,29 @@ class Runner:
         """
         if isinstance(optim_wrapper, OptimWrapper):
             return optim_wrapper
-        elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
-            # If `optim_wrapper` is a config dict with only one optimizer,
-            # the config dict must contain `optimizer`:
-            # optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
-            # `type` is optional, defaults to `OptimWrapper`.
-            # `optim_wrapper` could also be defined as:
-            # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1))  # noqa: E501
-            # to build specific optimizer wrapper.
-            if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
-            elif 'constructor' not in optim_wrapper:
-                # if `type` and `optimizer` are not defined in `optim_wrapper`,
-                # it should be the case of training with multiple optimizers.
-                # If constructor is not defined in `optim_wrapper`, each value
-                # of `optim_wrapper` must be an `OptimWrapper` instance since
-                # `DefaultOptimizerConstructor` will not handle the case of
-                # training with multiple optimizers. `build_optim_wrapper` will
-                # directly build the `OptimWrapperDict` instance from
-                # `optim_wrapper.`
+        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
+            # `optimizer` must be defined for single-optimizer training.
+            optimizer = optim_wrapper.get('optimizer', None)
+
+            # If `optimizer` is a built `Optimizer` instance, the optimizer
+            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
+            if isinstance(optimizer, Optimizer):
+                optim_wrapper.setdefault('type', 'OptimWrapper')
+                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore
+
+            # If `optimizer` is an optimizer config or `constructor` is
+            # defined, the optimizer wrapper will be built by the optimizer
+            # wrapper constructor, so `build_optim_wrapper` is called.
+            if optimizer is not None or 'constructor' in optim_wrapper:
+                return build_optim_wrapper(self.model, optim_wrapper)
+            else:
+                # If `optimizer` is not defined, it should be the case of
+                # training with multiple optimizers. If `constructor` is not
+                # defined either, each value of `optim_wrapper` must be an
+                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
+                # will not handle the case of training with multiple
+                # optimizers. The values are then packed directly into an
+                # `OptimWrapperDict` instance.
                 optim_wrappers = OrderedDict()
                 for name, optim in optim_wrapper.items():
                     if not isinstance(optim, OptimWrapper):
@@ -1096,11 +1099,6 @@ class Runner:
                             f'optimizer, but got {name}={optim}')
                     optim_wrappers[name] = optim
                 return OptimWrapperDict(**optim_wrappers)
-            # If constructor is defined, directly build the optimizer
-            # wrapper instance from the config dict.
-            else:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')
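
Taken together, the rewritten branch dispatches in this order: a built `OptimWrapper` is returned as-is; a dict whose `optimizer` value is a built `Optimizer` goes through the `OPTIM_WRAPPERS` registry; a dict with an optimizer config or a `constructor` key goes through `build_optim_wrapper`; anything else must be a dict of `OptimWrapper` values and becomes an `OptimWrapperDict`. A small sketch of that last, multi-optimizer case (the GAN-style `gen`/`disc` names are illustrative, not from this commit):

import torch.nn as nn
from torch.optim import SGD, Adam
from mmengine.optim import OptimWrapper, OptimWrapperDict

gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
# No 'optimizer' key and no 'constructor' key, and every value is
# already an OptimWrapper, so the runner falls through to the
# multi-optimizer branch and packs the dict into an OptimWrapperDict.
optim_wrapper = dict(
    gen=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
    disc=OptimWrapper(Adam(disc.parameters(), lr=0.001)))
wrappers = OptimWrapperDict(**optim_wrapper)
assert isinstance(wrappers['gen'], OptimWrapper)
assert isinstance(wrappers['disc'].optimizer, Adam)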

tests/test_runner/test_runner.py

@@ -21,8 +21,9 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
-                            OptimWrapper, OptimWrapperDict, StepLR)
+from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
+                            MultiStepLR, OptimWrapper, OptimWrapperDict,
+                            StepLR)
 from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
                                OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
@@ -926,6 +927,23 @@ class TestRunner(TestCase):
         self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
         self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
 
+        # 2.4 input is a dict which contains an optimizer instance.
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(optimizer=optimizer)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(
+            optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+        self.assertEqual(optim_wrapper._accumulative_counts, 2)
+
     def test_build_param_scheduler(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_param_scheduler'