Add type hints for mmcv/runner/optimizer (#2001)

pull/2023/head
tripleMu 2022-05-28 23:45:42 +08:00 committed by GitHub
parent 1577f40744
commit 966b742817
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 9 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import Dict, List
import torch
@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')
def register_torch_optimizers():
def register_torch_optimizers() -> List:
torch_optimizers = []
for module_name in dir(torch.optim):
if module_name.startswith('__'):
@ -26,11 +27,11 @@ def register_torch_optimizers():
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
def build_optimizer_constructor(cfg: Dict):
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
def build_optimizer(model, cfg):
def build_optimizer(model, cfg: Dict):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')

View File

@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
@ -93,7 +95,9 @@ class DefaultOptimizerConstructor:
>>> # model.cls_head is (0.01, 0.95).
"""
def __init__(self, optimizer_cfg, paramwise_cfg=None):
def __init__(self,
optimizer_cfg: Dict,
paramwise_cfg: Optional[Dict] = None):
if not isinstance(optimizer_cfg, dict):
raise TypeError('optimizer_cfg should be a dict',
f'but got {type(optimizer_cfg)}')
@ -103,7 +107,7 @@ class DefaultOptimizerConstructor:
self.base_wd = optimizer_cfg.get('weight_decay', None)
self._validate_cfg()
def _validate_cfg(self):
def _validate_cfg(self) -> None:
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
@ -126,7 +130,7 @@ class DefaultOptimizerConstructor:
if self.base_wd is None:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group, param_group_list):
def _is_in(self, param_group: Dict, param_group_list: List) -> bool:
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
@ -135,7 +139,11 @@ class DefaultOptimizerConstructor:
return not param.isdisjoint(param_set)
def add_params(self, params, module, prefix='', is_dcn_module=None):
def add_params(self,
params: List[Dict],
module: nn.Module,
prefix: str = '',
is_dcn_module: Optional[Union[int, float]] = None) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
@ -232,7 +240,7 @@ class DefaultOptimizerConstructor:
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self, model):
def __call__(self, model: nn.Module):
if hasattr(model, 'module'):
model = model.module
@ -243,7 +251,7 @@ class DefaultOptimizerConstructor:
return build_from_cfg(optimizer_cfg, OPTIMIZERS)
# set param-wise lr and weight decay recursively
params = []
params: List[Dict] = []
self.add_params(params, model)
optimizer_cfg['params'] = params