Support default_scope when building optimizer and evaluator. (#109)

* Support default_scope when building optimizer and evaluator.

* add docstring

* fix

* fix
This commit is contained in:
RangiLyu 2022-03-08 16:05:29 +08:00 committed by GitHub
parent be6f18988e
commit 49b7d0ce6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 32 deletions

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from typing import Optional, Union
from ..registry import EVALUATORS
from .base import BaseEvaluator
@ -7,14 +7,27 @@ from .composed_evaluator import ComposedEvaluator
def build_evaluator(
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
cfg: Union[dict, list],
default_scope: Optional[str] = None
) -> Union[BaseEvaluator, ComposedEvaluator]:
"""Build function of evaluator.
When the evaluator config is a list, it will automatically build composed
evaluators.
Args:
cfg (dict | list): Config of evaluator. When the config is a list, it
will automatically build composed evaluators.
default_scope (str, optional): The ``default_scope`` is used to
reset the current registry. Defaults to None.
Returns:
BaseEvaluator or ComposedEvaluator: The built evaluator.
"""
if isinstance(cfg, list):
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
evaluators = [
EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
]
return ComposedEvaluator(evaluators=evaluators)
else:
return EVALUATORS.build(cfg)
return EVALUATORS.build(cfg, default_scope=default_scope)

View File

@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
DefaultOptimizerConstructor, build_optimizer)
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
CosineAnnealingLR, CosineAnnealingMomentum,
CosineAnnealingParamScheduler, ExponentialLR,
@ -13,11 +12,11 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR',
'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR',
'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum',
'LinearMomentum', 'MultiStepMomentum', 'StepMomentum',
'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
'_ParamScheduler'
]

View File

@ -1,9 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer,
build_optimizer_constructor)
from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
from .default_constructor import DefaultOptimizerConstructor
__all__ = [
'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
'build_optimizer', 'build_optimizer_constructor'
'build_optimizer'
]

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import Callable, List
from typing import List, Optional
import torch
import torch.nn as nn
@ -10,6 +10,11 @@ from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
def register_torch_optimizers() -> List[str]:
"""Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.
Returns:
List[str]: A list of registered optimizers' name.
"""
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
@ -25,19 +30,35 @@ def register_torch_optimizers() -> List[str]:
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg: dict) -> Callable:
return OPTIMIZER_CONSTRUCTORS.build(cfg)
def build_optimizer(
model: nn.Module,
cfg: dict,
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
"""Build function of optimizer.
If ``constructor`` is set in the ``cfg``, this method will build an
optimizer constructor, and use optimizer constructor to build the
optimizer. If ``constructor`` is not set, the
``DefaultOptimizerConstructor`` will be used by default.
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
Args:
model (nn.Module): Model to be optimized.
cfg (dict): Config of optimizer and optimizer constructor.
default_scope (str, optional): The ``default_scope`` is used to
reset the current registry. Defaults to None.
Returns:
torch.optim.Optimizer: The built optimizer.
"""
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
paramwise_cfg=paramwise_cfg),
default_scope=default_scope)
optimizer = optim_constructor(model, default_scope=default_scope)
return optimizer

View File

@ -6,8 +6,7 @@ import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
build_from_cfg)
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
from mmengine.utils import is_list_of, mmcv_full_available
from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
@ -242,7 +241,9 @@ class DefaultOptimizerConstructor:
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
def __call__(self,
model: nn.Module,
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
if hasattr(model, 'module'):
model = model.module
@ -250,11 +251,11 @@ class DefaultOptimizerConstructor:
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)

View File

@ -7,8 +7,7 @@ import torch
import torch.nn as nn
from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
DefaultOptimizerConstructor, build_optimizer)
from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
from mmengine.registry import build_from_cfg
from mmengine.utils import mmcv_full_available
@ -236,7 +235,7 @@ class TestBuilder(TestCase):
type='DefaultOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
optimizer = optim_constructor(self.model)
self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
@ -271,7 +270,7 @@ class TestBuilder(TestCase):
type='MyOptimizerConstructor',
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
optimizer = optim_constructor(self.model)
param_groups = optimizer.param_groups