From 966b7428172230dddace6a7322c588dad90a9f98 Mon Sep 17 00:00:00 2001 From: tripleMu <92794867+triple-Mu@users.noreply.github.com> Date: Sat, 28 May 2022 23:45:42 +0800 Subject: [PATCH] Add type hints for mmcv/runner/optimizer (#2001) --- mmcv/runner/optimizer/builder.py | 7 ++++--- mmcv/runner/optimizer/default_constructor.py | 20 ++++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mmcv/runner/optimizer/builder.py b/mmcv/runner/optimizer/builder.py index f9234eed8..49d8f05a2 100644 --- a/mmcv/runner/optimizer/builder.py +++ b/mmcv/runner/optimizer/builder.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import inspect +from typing import Dict, List import torch @@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer') OPTIMIZER_BUILDERS = Registry('optimizer builder') -def register_torch_optimizers(): +def register_torch_optimizers() -> List: torch_optimizers = [] for module_name in dir(torch.optim): if module_name.startswith('__'): @@ -26,11 +27,11 @@ def register_torch_optimizers(): TORCH_OPTIMIZERS = register_torch_optimizers() -def build_optimizer_constructor(cfg): +def build_optimizer_constructor(cfg: Dict): return build_from_cfg(cfg, OPTIMIZER_BUILDERS) -def build_optimizer(model, cfg): +def build_optimizer(model, cfg: Dict): optimizer_cfg = copy.deepcopy(cfg) constructor_type = optimizer_cfg.pop('constructor', 'DefaultOptimizerConstructor') diff --git a/mmcv/runner/optimizer/default_constructor.py b/mmcv/runner/optimizer/default_constructor.py index ae97db880..2039b43ad 100644 --- a/mmcv/runner/optimizer/default_constructor.py +++ b/mmcv/runner/optimizer/default_constructor.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +from typing import Dict, List, Optional, Union import torch +import torch.nn as nn from torch.nn import GroupNorm, LayerNorm from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of @@ -93,7 +95,9 @@ class DefaultOptimizerConstructor: >>> # model.cls_head is (0.01, 0.95). """ - def __init__(self, optimizer_cfg, paramwise_cfg=None): + def __init__(self, + optimizer_cfg: Dict, + paramwise_cfg: Optional[Dict] = None): if not isinstance(optimizer_cfg, dict): raise TypeError('optimizer_cfg should be a dict', f'but got {type(optimizer_cfg)}') @@ -103,7 +107,7 @@ class DefaultOptimizerConstructor: self.base_wd = optimizer_cfg.get('weight_decay', None) self._validate_cfg() - def _validate_cfg(self): + def _validate_cfg(self) -> None: if not isinstance(self.paramwise_cfg, dict): raise TypeError('paramwise_cfg should be None or a dict, ' f'but got {type(self.paramwise_cfg)}') @@ -126,7 +130,7 @@ class DefaultOptimizerConstructor: if self.base_wd is None: raise ValueError('base_wd should not be None') - def _is_in(self, param_group, param_group_list): + def _is_in(self, param_group: Dict, param_group_list: List) -> bool: assert is_list_of(param_group_list, dict) param = set(param_group['params']) param_set = set() @@ -135,7 +139,11 @@ class DefaultOptimizerConstructor: return not param.isdisjoint(param_set) - def add_params(self, params, module, prefix='', is_dcn_module=None): + def add_params(self, + params: List[Dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Optional[Union[int, float]] = None) -> None: """Add all parameters of module to the params list. The parameters of the given module will be added to the list of param @@ -232,7 +240,7 @@ class DefaultOptimizerConstructor: prefix=child_prefix, is_dcn_module=is_dcn_module) - def __call__(self, model): + def __call__(self, model: nn.Module): if hasattr(model, 'module'): model = model.module @@ -243,7 +251,7 @@ class DefaultOptimizerConstructor: return build_from_cfg(optimizer_cfg, OPTIMIZERS) # set param-wise lr and weight decay recursively - params = [] + params: List[Dict] = [] self.add_params(params, model) optimizer_cfg['params'] = params