mirror of https://github.com/open-mmlab/mmcv.git
Add type hints for mmcv/runner/optimizer (#2001)
parent 1577f40744
commit 966b742817
mmcv/runner/optimizer/builder.py

@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
+from typing import Dict, List
 
 import torch
 
@@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
 OPTIMIZER_BUILDERS = Registry('optimizer builder')
 
 
-def register_torch_optimizers():
+def register_torch_optimizers() -> List:
     torch_optimizers = []
     for module_name in dir(torch.optim):
         if module_name.startswith('__'):
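The hunk truncates the loop body here. For context, a minimal sketch of what the registration loop does, assuming mmcv's Registry API (a reconstruction, not the verbatim file): every class in torch.optim that subclasses Optimizer is registered by name, and the collected names are what the new `-> List` annotation describes.

import inspect
from typing import List

import torch

from mmcv.utils import Registry

OPTIMIZERS = Registry('optimizer')


def register_torch_optimizers() -> List:
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        # register every torch.optim.Optimizer subclass under its class name
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module()(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers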
@@ -26,11 +27,11 @@ def register_torch_optimizers():
 TORCH_OPTIMIZERS = register_torch_optimizers()
 
 
-def build_optimizer_constructor(cfg):
+def build_optimizer_constructor(cfg: Dict):
     return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
 
 
-def build_optimizer(model, cfg):
+def build_optimizer(model, cfg: Dict):
     optimizer_cfg = copy.deepcopy(cfg)
     constructor_type = optimizer_cfg.pop('constructor',
                                          'DefaultOptimizerConstructor')
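With the new hints, `build_optimizer` still takes the model and a plain config dict. A small usage sketch (the layer and hyperparameters are illustrative):

import torch.nn as nn

from mmcv.runner import build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)
# 'type' can name any optimizer registered from torch.optim, e.g. SGD or Adam
cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, cfg)

An optional 'constructor' key in the config selects a different optimizer constructor; as the popped default above shows, it falls back to 'DefaultOptimizerConstructor'.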
mmcv/runner/optimizer/default_constructor.py

@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, List, Optional, Union
 
 import torch
+import torch.nn as nn
 from torch.nn import GroupNorm, LayerNorm
 
 from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
@@ -93,7 +95,9 @@ class DefaultOptimizerConstructor:
     >>> # model.cls_head is (0.01, 0.95).
     """
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self,
+                 optimizer_cfg: Dict,
+                 paramwise_cfg: Optional[Dict] = None):
         if not isinstance(optimizer_cfg, dict):
             raise TypeError('optimizer_cfg should be a dict',
                             f'but got {type(optimizer_cfg)}')
@@ -103,7 +107,7 @@ class DefaultOptimizerConstructor:
         self.base_wd = optimizer_cfg.get('weight_decay', None)
         self._validate_cfg()
 
-    def _validate_cfg(self):
+    def _validate_cfg(self) -> None:
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
@@ -126,7 +130,7 @@ class DefaultOptimizerConstructor:
             if self.base_wd is None:
                 raise ValueError('base_wd should not be None')
 
-    def _is_in(self, param_group, param_group_list):
+    def _is_in(self, param_group: Dict, param_group_list: List) -> bool:
         assert is_list_of(param_group_list, dict)
         param = set(param_group['params'])
         param_set = set()
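`_is_in` now advertises a bool result: it reports whether any tensor in `param_group` already appears in one of the groups of `param_group_list`, via set membership. A standalone sketch of the same check (the helper name and tensors are illustrative):

from typing import Dict, List

import torch


def is_in(param_group: Dict, param_group_list: List) -> bool:
    # two groups overlap if they share at least one parameter tensor;
    # tensors hash by identity, so a set-disjointness test is cheap
    param = set(param_group['params'])
    param_set = set()
    for group in param_group_list:
        param_set.update(set(group['params']))
    return not param.isdisjoint(param_set)


w = torch.nn.Parameter(torch.zeros(2))
assert is_in({'params': [w]}, [{'params': [w], 'lr': 0.1}])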
@@ -135,7 +139,11 @@ class DefaultOptimizerConstructor:
 
         return not param.isdisjoint(param_set)
 
-    def add_params(self, params, module, prefix='', is_dcn_module=None):
+    def add_params(self,
+                   params: List[Dict],
+                   module: nn.Module,
+                   prefix: str = '',
+                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
         """Add all parameters of module to the params list.
 
         The parameters of the given module will be added to the list of param
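The `params: List[Dict]` hint documents the mutable-accumulator style: `add_params` returns None and instead appends one dict per parameter, each holding the tensor plus any per-parameter overrides. A sketch of the resulting shape (the layer and values are made up):

from typing import Dict, List

import torch.nn as nn

layer = nn.Linear(4, 2)
# what add_params accumulates: one group dict per parameter
params: List[Dict] = [
    {'params': [layer.weight], 'lr': 0.01, 'weight_decay': 0.0001},
    {'params': [layer.bias], 'lr': 0.02, 'weight_decay': 0.0},
]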
@@ -232,7 +240,7 @@ class DefaultOptimizerConstructor:
                 prefix=child_prefix,
                 is_dcn_module=is_dcn_module)
 
-    def __call__(self, model):
+    def __call__(self, model: nn.Module):
         if hasattr(model, 'module'):
             model = model.module
 
@@ -243,7 +251,7 @@ class DefaultOptimizerConstructor:
             return build_from_cfg(optimizer_cfg, OPTIMIZERS)
 
         # set param-wise lr and weight decay recursively
-        params = []
+        params: List[Dict] = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
 
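End to end, the constructor is callable on a model and yields a fully built optimizer. A usage sketch with a param-wise override (model and multiplier chosen for illustration):

import torch.nn as nn

from mmcv.runner import DefaultOptimizerConstructor

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8))
optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
# disable weight decay on normalization layers
paramwise_cfg = dict(norm_decay_mult=0.)
constructor = DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
optimizer = constructor(model)  # param groups carry the per-layer settings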