mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Support default_scope when building optimizer and evaluator. (#109)
* Support default_scope when building optimizer and evaluator. * add docstring * fix * fix
This commit is contained in:
parent
be6f18988e
commit
49b7d0ce6f
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ..registry import EVALUATORS
|
||||
from .base import BaseEvaluator
|
||||
@ -7,14 +7,27 @@ from .composed_evaluator import ComposedEvaluator
|
||||
|
||||
|
||||
def build_evaluator(
|
||||
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||
cfg: Union[dict, list],
|
||||
default_scope: Optional[str] = None
|
||||
) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||
"""Build function of evaluator.
|
||||
|
||||
When the evaluator config is a list, it will automatically build composed
|
||||
evaluators.
|
||||
|
||||
Args:
|
||||
cfg (dict | list): Config of evaluator. When the config is a list, it
|
||||
will automatically build composed evaluators.
|
||||
default_scope (str, optional): The ``default_scope`` is used to
|
||||
reset the current registry. Defaults to None.
|
||||
|
||||
Returns:
|
||||
BaseEvaluator or ComposedEvaluator: The built evaluator.
|
||||
"""
|
||||
if isinstance(cfg, list):
|
||||
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
|
||||
evaluators = [
|
||||
EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
|
||||
]
|
||||
return ComposedEvaluator(evaluators=evaluators)
|
||||
else:
|
||||
return EVALUATORS.build(cfg)
|
||||
return EVALUATORS.build(cfg, default_scope=default_scope)
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||
DefaultOptimizerConstructor, build_optimizer,
|
||||
build_optimizer_constructor)
|
||||
DefaultOptimizerConstructor, build_optimizer)
|
||||
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
||||
CosineAnnealingLR, CosineAnnealingMomentum,
|
||||
CosineAnnealingParamScheduler, ExponentialLR,
|
||||
@ -13,11 +12,11 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
||||
|
||||
__all__ = [
|
||||
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
|
||||
'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR',
|
||||
'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR',
|
||||
'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum',
|
||||
'LinearMomentum', 'MultiStepMomentum', 'StepMomentum',
|
||||
'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
|
||||
'ExponentialParamScheduler', 'LinearParamScheduler',
|
||||
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
|
||||
'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
|
||||
'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
|
||||
'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
|
||||
'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
|
||||
'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
|
||||
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
|
||||
'_ParamScheduler'
|
||||
]
|
||||
|
@ -1,9 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer,
|
||||
build_optimizer_constructor)
|
||||
from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
|
||||
from .default_constructor import DefaultOptimizerConstructor
|
||||
|
||||
__all__ = [
|
||||
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
|
||||
'build_optimizer', 'build_optimizer_constructor'
|
||||
'build_optimizer'
|
||||
]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import inspect
|
||||
from typing import Callable, List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,6 +10,11 @@ from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
||||
|
||||
|
||||
def register_torch_optimizers() -> List[str]:
|
||||
"""Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of registered optimizers' name.
|
||||
"""
|
||||
torch_optimizers = []
|
||||
for module_name in dir(torch.optim):
|
||||
if module_name.startswith('__'):
|
||||
@ -25,19 +30,35 @@ def register_torch_optimizers() -> List[str]:
|
||||
TORCH_OPTIMIZERS = register_torch_optimizers()
|
||||
|
||||
|
||||
def build_optimizer_constructor(cfg: dict) -> Callable:
|
||||
return OPTIMIZER_CONSTRUCTORS.build(cfg)
|
||||
def build_optimizer(
|
||||
model: nn.Module,
|
||||
cfg: dict,
|
||||
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
||||
"""Build function of optimizer.
|
||||
|
||||
If ``constructor`` is set in the ``cfg``, this method will build an
|
||||
optimizer constructor, and use optimizer constructor to build the
|
||||
optimizer. If ``constructor`` is not set, the
|
||||
``DefaultOptimizerConstructor`` will be used by default.
|
||||
|
||||
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
|
||||
Args:
|
||||
model (nn.Module): Model to be optimized.
|
||||
cfg (dict): Config of optimizer and optimizer constructor.
|
||||
default_scope (str, optional): The ``default_scope`` is used to
|
||||
reset the current registry. Defaults to None.
|
||||
|
||||
Returns:
|
||||
torch.optim.Optimizer: The built optimizer.
|
||||
"""
|
||||
optimizer_cfg = copy.deepcopy(cfg)
|
||||
constructor_type = optimizer_cfg.pop('constructor',
|
||||
'DefaultOptimizerConstructor')
|
||||
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
|
||||
optim_constructor = build_optimizer_constructor(
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
|
||||
dict(
|
||||
type=constructor_type,
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
paramwise_cfg=paramwise_cfg))
|
||||
optimizer = optim_constructor(model)
|
||||
paramwise_cfg=paramwise_cfg),
|
||||
default_scope=default_scope)
|
||||
optimizer = optim_constructor(model, default_scope=default_scope)
|
||||
return optimizer
|
||||
|
@ -6,8 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import GroupNorm, LayerNorm
|
||||
|
||||
from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||
build_from_cfg)
|
||||
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
||||
from mmengine.utils import is_list_of, mmcv_full_available
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
|
||||
@ -242,7 +241,9 @@ class DefaultOptimizerConstructor:
|
||||
prefix=child_prefix,
|
||||
is_dcn_module=is_dcn_module)
|
||||
|
||||
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
|
||||
def __call__(self,
|
||||
model: nn.Module,
|
||||
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
|
||||
@ -250,11 +251,11 @@ class DefaultOptimizerConstructor:
|
||||
# if no paramwise option is specified, just use the global setting
|
||||
if not self.paramwise_cfg:
|
||||
optimizer_cfg['params'] = model.parameters()
|
||||
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
|
||||
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
||||
|
||||
# set param-wise lr and weight decay recursively
|
||||
params: List = []
|
||||
self.add_params(params, model)
|
||||
optimizer_cfg['params'] = params
|
||||
|
||||
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
|
||||
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
||||
|
@ -7,8 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||
DefaultOptimizerConstructor, build_optimizer,
|
||||
build_optimizer_constructor)
|
||||
DefaultOptimizerConstructor, build_optimizer)
|
||||
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
|
||||
from mmengine.registry import build_from_cfg
|
||||
from mmengine.utils import mmcv_full_available
|
||||
@ -236,7 +235,7 @@ class TestBuilder(TestCase):
|
||||
type='DefaultOptimizerConstructor',
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
paramwise_cfg=paramwise_cfg)
|
||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
|
||||
|
||||
@ -271,7 +270,7 @@ class TestBuilder(TestCase):
|
||||
type='MyOptimizerConstructor',
|
||||
optimizer_cfg=optimizer_cfg,
|
||||
paramwise_cfg=paramwise_cfg)
|
||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
||||
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||
optimizer = optim_constructor(self.model)
|
||||
|
||||
param_groups = optimizer.param_groups
|
||||
|
Loading…
x
Reference in New Issue
Block a user