diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 0b4f67d4..95233d86 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -299,7 +299,8 @@ class DefaultOptimWrapperConstructor:
         # `model_params` rather than `params`. Here we get the first argument
         # name and fill it with the model parameters.
         if isinstance(optimizer_cls, str):
-            optimizer_cls = OPTIMIZERS.get(self.optimizer_cfg['type'])
+            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
+                optimizer_cls = registry.get(self.optimizer_cfg['type'])
         fisrt_arg_name = next(
             iter(inspect.signature(optimizer_cls).parameters))
         # if no paramwise option is specified, just use the global setting
diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py
index e4089a4e..9f851bd3 100644
--- a/tests/test_optim/test_optimizer/test_optimizer.py
+++ b/tests/test_optim/test_optimizer/test_optimizer.py
@@ -17,7 +17,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
 from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                               LION_OPTIMIZERS,
                                               TORCH_OPTIMIZERS)
-from mmengine.registry import build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
 from mmengine.utils.version_utils import digit_version
@@ -391,6 +391,22 @@ class TestBuilder(TestCase):
         optim_wrapper = optim_constructor(self.model)
         self._check_default_optimizer(optim_wrapper.optimizer, self.model)
 
+        # Support building custom optimizers
+        CUSTOM_OPTIMIZERS = Registry(
+            'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS)
+
+        class CustomOptimizer(torch.optim.SGD):
+
+            def __init__(self, model_params, *args, **kwargs):
+                super().__init__(params=model_params, *args, **kwargs)
+
+        CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer)
+        optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), )
+        with DefaultScope.overwrite_default_scope('custom optimizer'):
+            optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
+            optim_wrapper = optim_constructor(self.model)
+        OPTIMIZERS.children.pop('custom optimizer')
+
     def test_default_optimizer_constructor_with_model_wrapper(self):
         # basic config with pseudo data parallel
         model = PseudoDataParallel()