mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* Support default_scope when building optimizer and evaluator. * add docstring * fix * fix
65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import inspect
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmengine.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
|
|
|
|
|
|
def register_torch_optimizers() -> List[str]:
    """Add every optimizer class exposed by ``torch.optim`` to ``OPTIMIZERS``.

    Each non-dunder attribute of ``torch.optim`` is inspected; those that are
    classes deriving from ``torch.optim.Optimizer`` are registered.

    Returns:
        List[str]: Attribute names of the optimizer classes that were
        registered, in the order yielded by ``dir(torch.optim)``.
    """
    registered_names = []
    for attr_name in dir(torch.optim):
        # Dunder attributes (``__name__``, ``__file__``, ...) are never
        # optimizer classes, so skip them before touching the attribute.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # ``issubclass`` raises on non-class objects, hence the class check
        # has to come first.
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, torch.optim.Optimizer):
            continue
        OPTIMIZERS.register_module(module=candidate)
        registered_names.append(attr_name)
    return registered_names
|
|
|
|
|
|
# Names of the ``torch.optim`` optimizer classes. This module-level call runs
# when the module is imported and, as a side effect, registers those classes
# in the ``OPTIMIZERS`` registry.
TORCH_OPTIMIZERS = register_torch_optimizers()
|
|
|
|
|
|
def build_optimizer(
        model: nn.Module,
        cfg: dict,
        default_scope: Optional[str] = None) -> torch.optim.Optimizer:
    """Create an optimizer for ``model`` through an optimizer constructor.

    The ``constructor`` key of ``cfg`` selects which optimizer constructor is
    built from ``OPTIMIZER_CONSTRUCTORS``; when the key is absent,
    ``'DefaultOptimizerConstructor'`` is used. The ``paramwise_cfg`` key, if
    present, is handed to the constructor separately, and whatever remains
    in the config is passed on as ``optimizer_cfg``.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of optimizer and optimizer constructor. It is
            deep-copied first, so the caller's dict is left untouched.
        default_scope (str, optional): Scope forwarded both to the registry
            when building the constructor and to the constructor call
            itself. Defaults to None.

    Returns:
        torch.optim.Optimizer: The built optimizer.
    """
    # Work on a private copy because the keys below are popped out of it.
    remaining_cfg = copy.deepcopy(cfg)
    constructor_name = remaining_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    per_param_cfg = remaining_cfg.pop('paramwise_cfg', None)

    constructor_cfg = {
        'type': constructor_name,
        'optimizer_cfg': remaining_cfg,
        'paramwise_cfg': per_param_cfg,
    }
    constructor = OPTIMIZER_CONSTRUCTORS.build(
        constructor_cfg, default_scope=default_scope)
    return constructor(model, default_scope=default_scope)
|