[Fix] Fix get optimizer_cls (#1324)
commit 170758aefe
parent 714c8eedc3
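Previously, `DefaultOptimWrapperConstructor` resolved a string optimizer type with a direct `OPTIMIZERS.get(...)`, which only searches the root registry and therefore misses optimizers registered under a custom scope in a child registry. This commit resolves the class through `switch_scope_and_registry(None)` instead, so the lookup happens in the registry matching the active default scope, and adds a regression test that builds an optimizer registered under a custom scope.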
@@ -299,7 +299,8 @@ class DefaultOptimWrapperConstructor:
         # `model_params` rather than `params`. Here we get the first argument
         # name and fill it with the model parameters.
         if isinstance(optimizer_cls, str):
-            optimizer_cls = OPTIMIZERS.get(self.optimizer_cfg['type'])
+            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
+                optimizer_cls = registry.get(self.optimizer_cfg['type'])
         fisrt_arg_name = next(
             iter(inspect.signature(optimizer_cls).parameters))
         # if no paramwise option is specified, just use the global setting
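Why the lookup changes: `Registry.get` with an unscoped key searches the current registry and its ancestors, but never its children, so an optimizer registered in a downstream child registry is invisible to a direct `OPTIMIZERS.get(...)` even while that child's scope is the active `DefaultScope`. `switch_scope_and_registry(None)` first switches to the registry matching the active scope and looks the class up there. A minimal sketch of the two paths, assuming the pre-fix behavior this commit addresses; the `demo` scope and `DemoSGD` name are illustrative, not part of the commit:

import torch
from mmengine.optim import OPTIMIZERS
from mmengine.registry import DefaultScope, Registry

# A child registry with its own scope, as a downstream project would create.
DEMO_OPTIMIZERS = Registry('demo optimizer', scope='demo', parent=OPTIMIZERS)
DEMO_OPTIMIZERS.register_module(name='DemoSGD', module=torch.optim.SGD)

with DefaultScope.overwrite_default_scope('demo'):
    # Old path: the root registry does not descend into children for an
    # unscoped key, so the class is not found (the reported bug).
    assert OPTIMIZERS.get('DemoSGD') is None
    # New path: switch to the registry matching the active default scope
    # ('demo' here), where the lookup succeeds.
    with OPTIMIZERS.switch_scope_and_registry(None) as registry:
        assert registry.get('DemoSGD') is torch.optim.SGD

OPTIMIZERS.children.pop('demo')  # undo the registration, as the new test does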
@@ -17,7 +17,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
 from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                               LION_OPTIMIZERS,
                                               TORCH_OPTIMIZERS)
-from mmengine.registry import build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
 from mmengine.utils.version_utils import digit_version
@@ -391,6 +391,22 @@ class TestBuilder(TestCase):
         optim_wrapper = optim_constructor(self.model)
         self._check_default_optimizer(optim_wrapper.optimizer, self.model)
 
+        # Support building custom optimizers
+        CUSTOM_OPTIMIZERS = Registry(
+            'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS)
+
+        class CustomOptimizer(torch.optim.SGD):
+
+            def __init__(self, model_params, *args, **kwargs):
+                super().__init__(params=model_params, *args, **kwargs)
+
+        CUSTOM_OPTIMIZERS.register_module()(CustomOptimizer)
+        optimizer_cfg = dict(optimizer=dict(type='CustomOptimizer', lr=0.1), )
+        with DefaultScope.overwrite_default_scope('custom optimizer'):
+            optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg)
+            optim_wrapper = optim_constructor(self.model)
+        OPTIMIZERS.children.pop('custom optimizer')
+
     def test_default_optimizer_constructor_with_model_wrapper(self):
         # basic config with pseudo data parallel
         model = PseudoDataParallel()
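A note on the test's cleanup: passing `parent=OPTIMIZERS` registers the child registry on the global `OPTIMIZERS`, so the test pops the `'custom optimizer'` entry from `OPTIMIZERS.children` at the end. Without that step the temporary scope would leak into other tests that share the same root registry.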
|