Support default_scope when building optimizer and evaluator. (#109)

* Support default_scope when building optimizer and evaluator.

* add docstring

* fix

* fix
Author: RangiLyu, 2022-03-08 16:05:29 +08:00 (committed by GitHub)
parent be6f18988e
commit 49b7d0ce6f
6 changed files with 64 additions and 32 deletions
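For context: ``default_scope`` tells a parent registry which child registry (typically one owned by a downstream repository) should be searched first when resolving a config's ``type``. The builders changed below now simply forward this argument to ``Registry.build``. A rough sketch of the idea follows; the ``'toy'`` scope, the ``ToyOptimizer`` class, and the exact child-registry behavior are illustrative assumptions, not part of this commit:

    from mmengine.registry import OPTIMIZERS, Registry

    # Hypothetical child registry that a downstream package might define.
    TOY_OPTIMIZERS = Registry('optimizer', parent=OPTIMIZERS, scope='toy')

    @TOY_OPTIMIZERS.register_module()
    class ToyOptimizer:

        def __init__(self, lr=0.01):
            self.lr = lr

    # With default_scope='toy', the lookup is redirected to the child registry,
    # so the parent can build types it does not register itself.
    toy_opt = OPTIMIZERS.build(dict(type='ToyOptimizer', lr=0.1), default_scope='toy')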


@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Union
+from typing import Optional, Union

 from ..registry import EVALUATORS
 from .base import BaseEvaluator
@@ -7,14 +7,27 @@ from .composed_evaluator import ComposedEvaluator


 def build_evaluator(
-        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
+        cfg: Union[dict, list],
+        default_scope: Optional[str] = None
+) -> Union[BaseEvaluator, ComposedEvaluator]:
     """Build function of evaluator.

     When the evaluator config is a list, it will automatically build composed
     evaluators.
+
+    Args:
+        cfg (dict | list): Config of evaluator. When the config is a list, it
+            will automatically build composed evaluators.
+        default_scope (str, optional): The ``default_scope`` is used to
+            reset the current registry. Defaults to None.
+
+    Returns:
+        BaseEvaluator or ComposedEvaluator: The built evaluator.
     """
     if isinstance(cfg, list):
-        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
+        evaluators = [
+            EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
+        ]
         return ComposedEvaluator(evaluators=evaluators)
     else:
-        return EVALUATORS.build(cfg)
+        return EVALUATORS.build(cfg, default_scope=default_scope)
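A minimal usage sketch of the updated ``build_evaluator``: a single dict builds one evaluator, while a list builds every entry with the same ``default_scope`` and wraps the results in a ``ComposedEvaluator``. The import path, metric names, and scope below are illustrative assumptions:

    from mmengine.evaluator import build_evaluator  # assumed export location

    # A list config builds each entry via EVALUATORS.build(_cfg, default_scope=...)
    # and returns a ComposedEvaluator; a single dict returns one evaluator.
    evaluator = build_evaluator(
        [dict(type='ToyAccuracy'), dict(type='ToyRecall')],  # illustrative types
        default_scope=None)  # or a downstream scope name such as 'toy'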


@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
-                        DefaultOptimizerConstructor, build_optimizer,
-                        build_optimizer_constructor)
+                        DefaultOptimizerConstructor, build_optimizer)
 from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         CosineAnnealingLR, CosineAnnealingMomentum,
                         CosineAnnealingParamScheduler, ExponentialLR,
@@ -13,11 +12,11 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,

 __all__ = [
     'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer',
-    'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR',
-    'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR',
-    'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum',
-    'LinearMomentum', 'MultiStepMomentum', 'StepMomentum',
-    'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
-    'ExponentialParamScheduler', 'LinearParamScheduler',
-    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+    'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR',
+    'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum',
+    'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum',
+    'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
+    'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
+    'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
+    '_ParamScheduler'
 ]


@@ -1,9 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer,
-                      build_optimizer_constructor)
+from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer
 from .default_constructor import DefaultOptimizerConstructor

 __all__ = [
     'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
-    'build_optimizer', 'build_optimizer_constructor'
+    'build_optimizer'
 ]


@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
-from typing import Callable, List
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -10,6 +10,11 @@ from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS


 def register_torch_optimizers() -> List[str]:
+    """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.
+
+    Returns:
+        List[str]: A list of registered optimizers' name.
+    """
     torch_optimizers = []
     for module_name in dir(torch.optim):
         if module_name.startswith('__'):
@@ -25,19 +30,35 @@ def register_torch_optimizers() -> List[str]:
 TORCH_OPTIMIZERS = register_torch_optimizers()


-def build_optimizer_constructor(cfg: dict) -> Callable:
-    return OPTIMIZER_CONSTRUCTORS.build(cfg)
-
-
-def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
+def build_optimizer(
+        model: nn.Module,
+        cfg: dict,
+        default_scope: Optional[str] = None) -> torch.optim.Optimizer:
+    """Build function of optimizer.
+
+    If ``constructor`` is set in the ``cfg``, this method will build an
+    optimizer constructor, and use optimizer constructor to build the
+    optimizer. If ``constructor`` is not set, the
+    ``DefaultOptimizerConstructor`` will be used by default.
+
+    Args:
+        model (nn.Module): Model to be optimized.
+        cfg (dict): Config of optimizer and optimizer constructor.
+        default_scope (str, optional): The ``default_scope`` is used to
+            reset the current registry. Defaults to None.
+
+    Returns:
+        torch.optim.Optimizer: The built optimizer.
+    """
     optimizer_cfg = copy.deepcopy(cfg)
     constructor_type = optimizer_cfg.pop('constructor',
                                          'DefaultOptimizerConstructor')
     paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
-    optim_constructor = build_optimizer_constructor(
+    optim_constructor = OPTIMIZER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
             optimizer_cfg=optimizer_cfg,
-            paramwise_cfg=paramwise_cfg))
-    optimizer = optim_constructor(model)
+            paramwise_cfg=paramwise_cfg),
+        default_scope=default_scope)
+    optimizer = optim_constructor(model, default_scope=default_scope)
     return optimizer
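As the new docstring describes, ``build_optimizer`` pops ``constructor`` and ``paramwise_cfg`` out of the config, builds the constructor through ``OPTIMIZER_CONSTRUCTORS`` (forwarding ``default_scope``), and then calls it on the model. A brief sketch of the call site; the model and the per-parameter options are placeholders:

    import torch.nn as nn
    from mmengine.optim import build_optimizer

    model = nn.Linear(4, 2)  # placeholder model
    optimizer = build_optimizer(
        model,
        dict(
            type='SGD',  # torch.optim.SGD, registered by register_torch_optimizers()
            lr=0.01,
            momentum=0.9,
            constructor='DefaultOptimizerConstructor',  # optional; this is the default
            paramwise_cfg=dict(bias_lr_mult=2.)),       # optional per-parameter options
        default_scope=None)  # or a downstream scope name (illustrative)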


@@ -6,8 +6,7 @@ import torch
 import torch.nn as nn
 from torch.nn import GroupNorm, LayerNorm

-from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
-                               build_from_cfg)
+from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
 from mmengine.utils import is_list_of, mmcv_full_available
 from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm
@@ -242,7 +241,9 @@ class DefaultOptimizerConstructor:
                     prefix=child_prefix,
                     is_dcn_module=is_dcn_module)

-    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
+    def __call__(self,
+                 model: nn.Module,
+                 default_scope: Optional[str] = None) -> torch.optim.Optimizer:
         if hasattr(model, 'module'):
             model = model.module
@@ -250,11 +251,11 @@ class DefaultOptimizerConstructor:

         # if no paramwise option is specified, just use the global setting
         if not self.paramwise_cfg:
             optimizer_cfg['params'] = model.parameters()
-            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+            return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)

         # set param-wise lr and weight decay recursively
         params: List = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
-        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+        return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
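Since ``__call__`` now accepts ``default_scope`` as well, a constructor obtained from ``OPTIMIZER_CONSTRUCTORS`` can pass the scope straight through to ``OPTIMIZERS.build`` when it finally creates the optimizer. A short sketch of driving the constructor directly (model and settings are placeholders):

    import torch.nn as nn
    from mmengine.optim import DefaultOptimizerConstructor

    constructor = DefaultOptimizerConstructor(
        optimizer_cfg=dict(type='SGD', lr=0.01, momentum=0.9),
        paramwise_cfg=dict(bias_lr_mult=2.))  # placeholder per-parameter tweak
    # default_scope=None keeps the lookup in mmengine's own OPTIMIZERS registry.
    optimizer = constructor(nn.Conv2d(3, 8, 3), default_scope=None)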


@@ -7,8 +7,7 @@ import torch
 import torch.nn as nn

 from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
-                            DefaultOptimizerConstructor, build_optimizer,
-                            build_optimizer_constructor)
+                            DefaultOptimizerConstructor, build_optimizer)
 from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS
 from mmengine.registry import build_from_cfg
 from mmengine.utils import mmcv_full_available
@@ -236,7 +235,7 @@ class TestBuilder(TestCase):
             type='DefaultOptimizerConstructor',
             optimizer_cfg=optimizer_cfg,
             paramwise_cfg=paramwise_cfg)
-        optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+        optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
         optimizer = optim_constructor(self.model)
         self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg)
@@ -271,7 +270,7 @@ class TestBuilder(TestCase):
             type='MyOptimizerConstructor',
             optimizer_cfg=optimizer_cfg,
             paramwise_cfg=paramwise_cfg)
-        optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+        optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg)
         optimizer = optim_constructor(self.model)
         param_groups = optimizer.param_groups