From e08b9031fc2cfe2b352806838ab372edda3a21f0 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 17 Aug 2022 19:17:00 +0800
Subject: [PATCH] [Enhance] Support building optimizer wrapper from built
 Optimizer instance (#422)

* support building an optimizer wrapper from a built Optimizer instance

* refine comments
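
A minimal usage sketch of the configs `Runner.build_optim_wrapper` now
accepts (illustrative only; the `model` and hyperparameters below are
placeholders, and `type` defaults to `OptimWrapper` when omitted):

    import torch.nn as nn
    from torch.optim import SGD

    model = nn.Linear(1, 1)

    # A config dict for the optimizer still works as before:
    optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))

    # With this patch, a built Optimizer instance is also accepted:
    optim_wrapper = dict(optimizer=SGD(model.parameters(), lr=0.1))

    # A specific wrapper type can still be requested explicitly:
    optim_wrapper = dict(
        type='AmpOptimWrapper',
        optimizer=SGD(model.parameters(), lr=0.1),
        accumulative_counts=2)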
---
 mmengine/runner/runner.py        | 54 +++++++++++++++----------------
 tests/test_runner/test_runner.py | 22 +++++++++++--
 2 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index e1a3fe84..25b5004a 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -33,9 +33,9 @@ from mmengine.model import (BaseModel, MMDistributedDataParallel,
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
-                               LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               RUNNERS, VISUALIZERS, DefaultScope,
-                               count_registered_modules)
+                               LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS,
+                               PARAM_SCHEDULERS, RUNNERS, VISUALIZERS,
+                               DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
                             get_git_hash, is_list_of, is_seq_of,
@@ -1067,26 +1067,29 @@ class Runner:
         """
         if isinstance(optim_wrapper, OptimWrapper):
             return optim_wrapper
-        elif isinstance(optim_wrapper, (dict, ConfigDict, Config)):
-            # If `optim_wrapper` is a config dict with only one optimizer,
-            # the config dict must contain `optimizer`:
-            # optim_wrapper = dict(optimizer=dict(type='SGD', lr=0.1))
-            # `type` is optional, defaults to `OptimWrapper`.
-            # `optim_wrapper` could also be defined as:
-            # optim_wrapper = dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.1))  # noqa: E501
-            # to build specific optimizer wrapper.
-            if 'type' in optim_wrapper or 'optimizer' in optim_wrapper:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
-            elif 'constructor' not in optim_wrapper:
-                # if `type` and `optimizer` are not defined in `optim_wrapper`,
-                # it should be the case of training with multiple optimizers.
-                # If constructor is not defined in `optim_wrapper`, each value
-                # of `optim_wrapper` must be an `OptimWrapper` instance since
-                # `DefaultOptimizerConstructor` will not handle the case of
-                # training with multiple optimizers. `build_optim_wrapper` will
-                # directly build the `OptimWrapperDict` instance from
-                # `optim_wrapper.`
+        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
+            # `optimizer` must be defined for single-optimizer training.
+            optimizer = optim_wrapper.get('optimizer', None)
+
+            # If `optimizer` is a built `Optimizer` instance, the optimizer
+            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
+            if isinstance(optimizer, Optimizer):
+                optim_wrapper.setdefault('type', 'OptimWrapper')
+                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore
+
+            # If `optimizer` is not None or `constructor` is defined, the
+            # optimizer wrapper will be built by the optimizer wrapper
+            # constructor, so `build_optim_wrapper` should be called.
+            if optimizer is not None or 'constructor' in optim_wrapper:
+                return build_optim_wrapper(self.model, optim_wrapper)
+            else:
+                # If `optimizer` is not defined, it should be the case of
+                # training with multiple optimizers. If `constructor` is not
+                # defined either, each value of `optim_wrapper` must be an
+                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
+                # will not handle the case of training with multiple
+                # optimizers. `build_optim_wrapper` will directly build the
+                # `OptimWrapperDict` instance from `optim_wrapper`.
                 optim_wrappers = OrderedDict()
                 for name, optim in optim_wrapper.items():
                     if not isinstance(optim, OptimWrapper):
@@ -1096,11 +1099,6 @@
                             f'optimizer, but got {name}={optim}')
                     optim_wrappers[name] = optim
                 return OptimWrapperDict(**optim_wrappers)
-            # If constructor is defined, directly build the optimizer
-            # wrapper instance from the config dict.
-            else:
-                optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
-                return optim_wrapper
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index c414e188..40494d92 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -21,8 +21,9 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
-                            OptimWrapper, OptimWrapperDict, StepLR)
+from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
+                            MultiStepLR, OptimWrapper, OptimWrapperDict,
+                            StepLR)
 from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                                LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
                                OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
@@ -926,6 +927,23 @@ class TestRunner(TestCase):
         self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD)
         self.assertIsInstance(optim_wrapper['linear2'].optimizer, Adam)
 
+        # 2.4 input is a dict which contains a built optimizer instance.
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(optimizer=optimizer)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, OptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+
+        model = nn.Linear(1, 1)
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper_cfg = dict(
+            optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
+        optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
+        self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
+        self.assertIs(optim_wrapper.optimizer, optimizer)
+        self.assertEqual(optim_wrapper._accumulative_counts, 2)
+
     def test_build_param_scheduler(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_param_scheduler'
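
Note: in the multi-optimizer branch above, each value of `optim_wrapper`
must already be a built `OptimWrapper`, and the runner packs them into an
`OptimWrapperDict`. A minimal sketch (the `generator`/`discriminator`
names and the two linear modules are illustrative placeholders):

    import torch.nn as nn
    from torch.optim import Adam, SGD
    from mmengine.optim import OptimWrapper

    gen, disc = nn.Linear(1, 1), nn.Linear(1, 1)
    optim_wrapper = dict(
        generator=OptimWrapper(SGD(gen.parameters(), lr=0.01)),
        discriminator=OptimWrapper(Adam(disc.parameters(), lr=0.001)))
    # Runner.build_optim_wrapper(optim_wrapper) returns an OptimWrapperDict
    # keyed by 'generator' and 'discriminator'.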