[Fix] Fix get optimizer_cls (#1324)

commit 170758aefe
parent 714c8eedc3
Author: Mashiro (committed via GitHub)
Date: 2023-08-28 16:15:00 +08:00
2 changed files with 19 additions and 2 deletions

@@ -299,7 +299,8 @@ class DefaultOptimWrapperConstructor:
         # `model_params` rather than `params`. Here we get the first argument
         # name and fill it with the model parameters.
         if isinstance(optimizer_cls, str):
-            optimizer_cls = OPTIMIZERS.get(self.optimizer_cfg['type'])
+            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
+                optimizer_cls = registry.get(self.optimizer_cfg['type'])
         fisrt_arg_name = next(
             iter(inspect.signature(optimizer_cls).parameters))
         # if no paramwise option is specified, just use the global setting
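Why the one-line lookup was wrong: `OPTIMIZERS.get(...)` searches the root registry and its ancestors only, so an optimizer registered by a downstream project in a scoped child registry is never found, `get` returns None, and the following `inspect.signature(optimizer_cls)` call raises. `switch_scope_and_registry(None)` instead yields the registry that matches the currently active `DefaultScope`. Below is a minimal sketch of the difference, assuming mmengine's public `Registry`/`DefaultScope` API and the miss-returns-None behavior of `get` at this commit; the 'custom' scope and `CustomOptimizer` class are illustrative names only, not part of the commit.

import torch

from mmengine.optim import OPTIMIZERS
from mmengine.registry import DefaultScope, Registry

# Hypothetical downstream registry that lives under its own scope.
CUSTOM_OPTIMIZERS = Registry('optimizer', scope='custom', parent=OPTIMIZERS)


@CUSTOM_OPTIMIZERS.register_module()
class CustomOptimizer(torch.optim.SGD):
    pass


with DefaultScope.overwrite_default_scope('custom'):
    # Old behavior: the root registry searches itself and its ancestors,
    # not its scoped children, so this lookup comes back empty (None).
    assert OPTIMIZERS.get('CustomOptimizer') is None

    # Fixed behavior: first switch to the registry matching the active
    # default scope, where the class is actually registered.
    with OPTIMIZERS.switch_scope_and_registry(None) as registry:
        assert registry.get('CustomOptimizer') is CustomOptimizer

# Detach the temporary child so global state does not leak (the new
# test below does the same via OPTIMIZERS.children.pop).
OPTIMIZERS.children.pop('custom')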

@@ -17,7 +17,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
 from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                               LION_OPTIMIZERS,
                                               TORCH_OPTIMIZERS)
-from mmengine.registry import build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
 from mmengine.utils.version_utils import digit_version
@@ -391,6 +391,22 @@ class TestBuilder(TestCase):
         optim_wrapper = optim_constructor(self.model)
         self._check_default_optimizer(optim_wrapper.optimizer, self.model)
+        # Support building custom optimizers
+        CUSTOM_OPTIMIZERS = Registry(
+            'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS)
+
+        class CustomOptimizer(torch.optim.SGD):
+
+            def __init__(self, model_params, *args, **kwargs):
+                super().__init__(params=model_params, *args, **kwargs)
+
+        CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer)
+
+        optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), )
+        with DefaultScope.overwrite_default_scope('custom optimizer'):
+            optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
+            optim_wrapper = optim_constructor(self.model)
+        OPTIMIZERS.children.pop('custom optimizer')

     def test_default_optimizer_constructor_with_model_wrapper(self):
         # basic config with pseudo data parallel
         model = PseudoDataParallel()
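A note on the new test's design: it reproduces the downstream setup end to end, registering `CustomOptimizer` in a child registry under a separate scope and activating that scope via `DefaultScope.overwrite_default_scope`, which is exactly the configuration the constructor fix targets. The closing `OPTIMIZERS.children.pop('custom optimizer')` detaches the temporary child registry again so the shared `OPTIMIZERS` registry does not leak state into other tests.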