diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py index 2a8fb3d8..fcc80031 100644 --- a/mmengine/evaluator/builder.py +++ b/mmengine/evaluator/builder.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Union +from typing import Optional, Union from ..registry import EVALUATORS from .base import BaseEvaluator @@ -7,14 +7,27 @@ from .composed_evaluator import ComposedEvaluator def build_evaluator( - cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]: + cfg: Union[dict, list], + default_scope: Optional[str] = None +) -> Union[BaseEvaluator, ComposedEvaluator]: """Build function of evaluator. When the evaluator config is a list, it will automatically build composed evaluators. + + Args: + cfg (dict | list): Config of evaluator. When the config is a list, it + will automatically build composed evaluators. + default_scope (str, optional): The ``default_scope`` is used to + reset the current registry. Defaults to None. + + Returns: + BaseEvaluator or ComposedEvaluator: The built evaluator. """ if isinstance(cfg, list): - evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg] + evaluators = [ + EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg + ] return ComposedEvaluator(evaluators=evaluators) else: - return EVALUATORS.build(cfg) + return EVALUATORS.build(cfg, default_scope=default_scope) diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py index 98832cfd..029b7aa3 100644 --- a/mmengine/optim/__init__.py +++ b/mmengine/optim/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .optimizer import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, - DefaultOptimizerConstructor, build_optimizer, - build_optimizer_constructor) + DefaultOptimizerConstructor, build_optimizer) from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, CosineAnnealingLR, CosineAnnealingMomentum, CosineAnnealingParamScheduler, ExponentialLR, @@ -13,11 +12,11 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, __all__ = [ 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optimizer', - 'build_optimizer_constructor', 'DefaultOptimizerConstructor', 'ConstantLR', - 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', - 'ConstantMomentum', 'CosineAnnealingMomentum', 'ExponentialMomentum', - 'LinearMomentum', 'MultiStepMomentum', 'StepMomentum', - 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', - 'ExponentialParamScheduler', 'LinearParamScheduler', - 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler' + 'DefaultOptimizerConstructor', 'ConstantLR', 'CosineAnnealingLR', + 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum', + 'CosineAnnealingMomentum', 'ExponentialMomentum', 'LinearMomentum', + 'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler', + 'CosineAnnealingParamScheduler', 'ExponentialParamScheduler', + 'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler', + '_ParamScheduler' ] diff --git a/mmengine/optim/optimizer/__init__.py b/mmengine/optim/optimizer/__init__.py index 77d5ac18..a74c6b8e 100644 --- a/mmengine/optim/optimizer/__init__.py +++ b/mmengine/optim/optimizer/__init__.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer, - build_optimizer_constructor) +from .builder import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, build_optimizer from .default_constructor import DefaultOptimizerConstructor __all__ = [ 'OPTIMIZER_CONSTRUCTORS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', - 'build_optimizer', 'build_optimizer_constructor' + 'build_optimizer' ] diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index a48f65e2..a3e1612d 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import inspect -from typing import Callable, List +from typing import List, Optional import torch import torch.nn as nn @@ -10,6 +10,11 @@ from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS def register_torch_optimizers() -> List[str]: + """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ torch_optimizers = [] for module_name in dir(torch.optim): if module_name.startswith('__'): @@ -25,19 +30,35 @@ def register_torch_optimizers() -> List[str]: TORCH_OPTIMIZERS = register_torch_optimizers() -def build_optimizer_constructor(cfg: dict) -> Callable: - return OPTIMIZER_CONSTRUCTORS.build(cfg) +def build_optimizer( + model: nn.Module, + cfg: dict, + default_scope: Optional[str] = None) -> torch.optim.Optimizer: + """Build function of optimizer. + If ``constructor`` is set in the ``cfg``, this method will build an + optimizer constructor, and use optimizer constructor to build the + optimizer. If ``constructor`` is not set, the + ``DefaultOptimizerConstructor`` will be used by default. -def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer: + Args: + model (nn.Module): Model to be optimized. + cfg (dict): Config of optimizer and optimizer constructor. + default_scope (str, optional): The ``default_scope`` is used to + reset the current registry. Defaults to None. + + Returns: + torch.optim.Optimizer: The built optimizer. + """ optimizer_cfg = copy.deepcopy(cfg) constructor_type = optimizer_cfg.pop('constructor', 'DefaultOptimizerConstructor') paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) - optim_constructor = build_optimizer_constructor( + optim_constructor = OPTIMIZER_CONSTRUCTORS.build( dict( type=constructor_type, optimizer_cfg=optimizer_cfg, - paramwise_cfg=paramwise_cfg)) - optimizer = optim_constructor(model) + paramwise_cfg=paramwise_cfg), + default_scope=default_scope) + optimizer = optim_constructor(model, default_scope=default_scope) return optimizer diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index be573d3c..18b9db47 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -6,8 +6,7 @@ import torch import torch.nn as nn from torch.nn import GroupNorm, LayerNorm -from mmengine.registry import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, - build_from_cfg) +from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS from mmengine.utils import is_list_of, mmcv_full_available from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm @@ -242,7 +241,9 @@ class DefaultOptimizerConstructor: prefix=child_prefix, is_dcn_module=is_dcn_module) - def __call__(self, model: nn.Module) -> torch.optim.Optimizer: + def __call__(self, + model: nn.Module, + default_scope: Optional[str] = None) -> torch.optim.Optimizer: if hasattr(model, 'module'): model = model.module @@ -250,11 +251,11 @@ class DefaultOptimizerConstructor: # if no paramwise option is specified, just use the global setting if not self.paramwise_cfg: optimizer_cfg['params'] = model.parameters() - return build_from_cfg(optimizer_cfg, OPTIMIZERS) + return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope) # set param-wise lr and weight decay recursively params: List = [] self.add_params(params, model) optimizer_cfg['params'] = params - return build_from_cfg(optimizer_cfg, OPTIMIZERS) + return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope) diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 24890f87..2115c20c 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -7,8 +7,7 @@ import torch import torch.nn as nn from mmengine.optim import (OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, - DefaultOptimizerConstructor, build_optimizer, - build_optimizer_constructor) + DefaultOptimizerConstructor, build_optimizer) from mmengine.optim.optimizer.builder import TORCH_OPTIMIZERS from mmengine.registry import build_from_cfg from mmengine.utils import mmcv_full_available @@ -236,7 +235,7 @@ class TestBuilder(TestCase): type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg, paramwise_cfg=paramwise_cfg) - optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg) optimizer = optim_constructor(self.model) self._check_sgd_optimizer(optimizer, self.model, **paramwise_cfg) @@ -271,7 +270,7 @@ class TestBuilder(TestCase): type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg, paramwise_cfg=paramwise_cfg) - optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + optim_constructor = OPTIMIZER_CONSTRUCTORS.build(optim_constructor_cfg) optimizer = optim_constructor(self.model) param_groups = optimizer.param_groups