mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Support default_scope when building optimizer and evaluator. (#109)
* Support default_scope when building optimizer and evaluator. * add docstring * fix * fix
This commit is contained in:
parent
be6f18988e
commit
49b7d0ce6f
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ..registry import EVALUATORS
|
from ..registry import EVALUATORS
|
||||||
from .base import BaseEvaluator
|
from .base import BaseEvaluator
|
||||||
@ -7,14 +7,27 @@ from .composed_evaluator import ComposedEvaluator
|
|||||||
|
|
||||||
|
|
||||||
def build_evaluator(
|
def build_evaluator(
|
||||||
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
|
cfg: Union[dict, list],
|
||||||
|
default_scope: Optional[str] = None
|
||||||
|
) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||||
"""Build function of evaluator.
|
"""Build function of evaluator.
|
||||||
|
|
||||||
When the evaluator config is a list, it will automatically build composed
|
When the evaluator config is a list, it will automatically build composed
|
||||||
evaluators.
|
evaluators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict | list): Config of evaluator. When the config is a list, it
|
||||||
|
will automatically build composed evaluators.
|
||||||
|
default_scope (str, optional): The ``default_scope`` is used to
|
||||||
|
reset the current registry. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseEvaluator or ComposedEvaluator: The built evaluator.
|
||||||
"""
|
"""
|
||||||
if isinstance(cfg, list):
|
if isinstance(cfg, list):
|
||||||
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
|
evaluators = [
|
||||||
|
EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
|
||||||
|
]
|
||||||
return ComposedEvaluator(evaluators=evaluators)
|
return ComposedEvaluator(evaluators=evaluators)
|
||||||
else:
|
else:
|
||||||
return EVALUATORS.build(cfg)
|
return EVALUATORS.build(cfg, default_scope=default_scope)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||||
DefaultOptimizerConstructor, build_optimizer,
|
DefaultOptimizerConstructor, build_optimizer)
|
||||||
build_optimizer_constructor)
|
|
||||||
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
||||||
CosineAnnealingLR, CosineAnnealingMomentum,
|
CosineAnnealingLR, CosineAnnealingMomentum,
|
||||||
CosineAnnealingParamScheduler, ExponentialLR,
|
CosineAnnealingParamScheduler, ExponentialLR,
|
||||||
@ -13,11 +12,11 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
|
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
|
||||||
'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR',
|
'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
|
||||||
'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR',
|
'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
|
||||||
'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum',
|
'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
|
||||||
'LinearMomentum', 'MultiStepMomentum', 'StepMomentum',
|
'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
|
||||||
'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
|
'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
|
||||||
'ExponentialParamScheduler', 'LinearParamScheduler',
|
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
|
||||||
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
|
'_ParamScheduler'
|
||||||
]
|
]
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer,
|
from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
|
||||||
build_optimizer_constructor)
|
|
||||||
from .default_constructor import DefaultOptimizerConstructor
|
from .default_constructor import DefaultOptimizerConstructor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
|
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
|
||||||
'build_optimizer', 'build_optimizer_constructor'
|
'build_optimizer'
|
||||||
]
|
]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -10,6 +10,11 @@ from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
|||||||
|
|
||||||
|
|
||||||
def register_torch_optimizers() -> List[str]:
|
def register_torch_optimizers() -> List[str]:
|
||||||
|
"""Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of registered optimizers' name.
|
||||||
|
"""
|
||||||
torch_optimizers = []
|
torch_optimizers = []
|
||||||
for module_name in dir(torch.optim):
|
for module_name in dir(torch.optim):
|
||||||
if module_name.startswith('__'):
|
if module_name.startswith('__'):
|
||||||
@ -25,19 +30,35 @@ def register_torch_optimizers() -> List[str]:
|
|||||||
TORCH_OPTIMIZERS = register_torch_optimizers()
|
TORCH_OPTIMIZERS = register_torch_optimizers()
|
||||||
|
|
||||||
|
|
||||||
def build_optimizer_constructor(cfg: dict) -> Callable:
|
def build_optimizer(
|
||||||
return OPTIMIZER_CONSTRUCTORS.build(cfg)
|
model: nn.Module,
|
||||||
|
cfg: dict,
|
||||||
|
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
||||||
|
"""Build function of optimizer.
|
||||||
|
|
||||||
|
If ``constructor`` is set in the ``cfg``, this method will build an
|
||||||
|
optimizer constructor, and use optimizer constructor to build the
|
||||||
|
optimizer. If ``constructor`` is not set, the
|
||||||
|
``DefaultOptimizerConstructor`` will be used by default.
|
||||||
|
|
||||||
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
|
Args:
|
||||||
|
model (nn.Module): Model to be optimized.
|
||||||
|
cfg (dict): Config of optimizer and optimizer constructor.
|
||||||
|
default_scope (str, optional): The ``default_scope`` is used to
|
||||||
|
reset the current registry. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.optim.Optimizer: The built optimizer.
|
||||||
|
"""
|
||||||
optimizer_cfg = copy.deepcopy(cfg)
|
optimizer_cfg = copy.deepcopy(cfg)
|
||||||
constructor_type = optimizer_cfg.pop('constructor',
|
constructor_type = optimizer_cfg.pop('constructor',
|
||||||
'DefaultOptimizerConstructor')
|
'DefaultOptimizerConstructor')
|
||||||
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
|
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
|
||||||
optim_constructor = build_optimizer_constructor(
|
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
|
||||||
dict(
|
dict(
|
||||||
type=constructor_type,
|
type=constructor_type,
|
||||||
optimizer_cfg=optimizer_cfg,
|
optimizer_cfg=optimizer_cfg,
|
||||||
paramwise_cfg=paramwise_cfg))
|
paramwise_cfg=paramwise_cfg),
|
||||||
optimizer = optim_constructor(model)
|
default_scope=default_scope)
|
||||||
|
optimizer = optim_constructor(model, default_scope=default_scope)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
@ -6,8 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import GroupNorm, LayerNorm
|
from torch.nn import GroupNorm, LayerNorm
|
||||||
|
|
||||||
from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
||||||
build_from_cfg)
|
|
||||||
from mmengine.utils import is_list_of, mmcv_full_available
|
from mmengine.utils import is_list_of, mmcv_full_available
|
||||||
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||||
|
|
||||||
@ -242,7 +241,9 @@ class DefaultOptimizerConstructor:
|
|||||||
prefix=child_prefix,
|
prefix=child_prefix,
|
||||||
is_dcn_module=is_dcn_module)
|
is_dcn_module=is_dcn_module)
|
||||||
|
|
||||||
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
|
def __call__(self,
|
||||||
|
model: nn.Module,
|
||||||
|
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
|
|
||||||
@ -250,11 +251,11 @@ class DefaultOptimizerConstructor:
|
|||||||
# if no paramwise option is specified, just use the global setting
|
# if no paramwise option is specified, just use the global setting
|
||||||
if not self.paramwise_cfg:
|
if not self.paramwise_cfg:
|
||||||
optimizer_cfg['params'] = model.parameters()
|
optimizer_cfg['params'] = model.parameters()
|
||||||
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
|
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
||||||
|
|
||||||
# set param-wise lr and weight decay recursively
|
# set param-wise lr and weight decay recursively
|
||||||
params: List = []
|
params: List = []
|
||||||
self.add_params(params, model)
|
self.add_params(params, model)
|
||||||
optimizer_cfg['params'] = params
|
optimizer_cfg['params'] = params
|
||||||
|
|
||||||
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
|
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
|
||||||
|
@ -7,8 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
|
||||||
DefaultOptimizerConstructor, build_optimizer,
|
DefaultOptimizerConstructor, build_optimizer)
|
||||||
build_optimizer_constructor)
|
|
||||||
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
|
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
|
||||||
from mmengine.registry import build_from_cfg
|
from mmengine.registry import build_from_cfg
|
||||||
from mmengine.utils import mmcv_full_available
|
from mmengine.utils import mmcv_full_available
|
||||||
@ -236,7 +235,7 @@ class TestBuilder(TestCase):
|
|||||||
type='DefaultOptimizerConstructor',
|
type='DefaultOptimizerConstructor',
|
||||||
optimizer_cfg=optimizer_cfg,
|
optimizer_cfg=optimizer_cfg,
|
||||||
paramwise_cfg=paramwise_cfg)
|
paramwise_cfg=paramwise_cfg)
|
||||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||||
optimizer = optim_constructor(self.model)
|
optimizer = optim_constructor(self.model)
|
||||||
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
|
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
|
||||||
|
|
||||||
@ -271,7 +270,7 @@ class TestBuilder(TestCase):
|
|||||||
type='MyOptimizerConstructor',
|
type='MyOptimizerConstructor',
|
||||||
optimizer_cfg=optimizer_cfg,
|
optimizer_cfg=optimizer_cfg,
|
||||||
paramwise_cfg=paramwise_cfg)
|
paramwise_cfg=paramwise_cfg)
|
||||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
|
||||||
optimizer = optim_constructor(self.model)
|
optimizer = optim_constructor(self.model)
|
||||||
|
|
||||||
param_groups = optimizer.param_groups
|
param_groups = optimizer.param_groups
|
||||||
|
Loading…
x
Reference in New Issue
Block a user