mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix get optimizer_cls (#1324)
This commit is contained in:
parent
714c8eedc3
commit
170758aefe
@ -299,7 +299,8 @@ class DefaultOptimWrapperConstructor:
|
||||
# `model_params` rather than `params`. Here we get the first argument
|
||||
# name and fill it with the model parameters.
|
||||
if isinstance(optimizer_cls, str):
|
||||
optimizer_cls = OPTIMIZERS.get(self.optimizer_cfg['type'])
|
||||
with OPTIMIZERS.switch_scope_and_registry(None) as registry:
|
||||
optimizer_cls = registry.get(self.optimizer_cfg['type'])
|
||||
fisrt_arg_name = next(
|
||||
iter(inspect.signature(optimizer_cls).parameters))
|
||||
# if no paramwise option is specified, just use the global setting
|
||||
|
@ -17,7 +17,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
|
||||
from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
|
||||
LION_OPTIMIZERS,
|
||||
TORCH_OPTIMIZERS)
|
||||
from mmengine.registry import build_from_cfg
|
||||
from mmengine.registry import DefaultScope, Registry, build_from_cfg
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
|
||||
from mmengine.utils.version_utils import digit_version
|
||||
@ -391,6 +391,22 @@ class TestBuilder(TestCase):
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, self.model)
|
||||
|
||||
# Support building custom optimizers
|
||||
CUSTOM_OPTIMIZERS = Registry(
|
||||
'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS)
|
||||
|
||||
class CustomOptimizer(torch.optim.SGD):
|
||||
|
||||
def __init__(self, model_params, *args, **kwargs):
|
||||
super().__init__(params=model_params, *args, **kwargs)
|
||||
|
||||
CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer)
|
||||
optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), )
|
||||
with DefaultScope.overwrite_default_scope('custom optimizer'):
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
OPTIMIZERS.children.pop('custom optimizer')
|
||||
|
||||
def test_default_optimizer_constructor_with_model_wrapper(self):
|
||||
# basic config with pseudo data parallel
|
||||
model = PseudoDataParallel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user