[Refactor]: Refactor layer-wise lr decay (#375)

* [Refactor]: Refactor layer-wise lr decay

* [Fix]: Remove is_dcn_module and add log

* [Fix]: Fix lint

* [Fix]: Check relative pos bias table is not decayed

* [Fix]: Fix lint

* [Fix]: Fix UT
Yuan Liu 2022-08-05 17:38:53 +08:00 committed by GitHub
parent 045b1fde8e
commit a703ba2fcb
2 changed files with 29 additions and 16 deletions

File 1 of 2:

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
-from typing import Dict, List, Optional, Union
+from typing import List

 import torch
 from mmengine.dist import get_dist_info
@@ -13,7 +13,7 @@ from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
 def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for ViT.

     Args:
         var_name (str): The key of the model.
@@ -36,7 +36,7 @@ def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
 def get_layer_id_for_swin(var_name: str, max_layer_id: int,
                           depths: List[int]) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for Swin.

     Args:
         var_name (str): The key of the model.
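
For context (the two hunks above only retitle the docstrings): these helpers map a parameter name to a layer id, which later determines the lr scale applied to that parameter's group. Below is a minimal, illustrative sketch of the kind of mapping get_layer_id_for_vit performs; the exact prefixes and fallback rules in the repository may differ.

# Illustrative sketch only, not the repository implementation.
def get_layer_id_for_vit_sketch(var_name: str, max_layer_id: int) -> int:
    # Tokens, position embedding and patch embedding share the earliest
    # (most decayed) group.
    if var_name in ('backbone.cls_token', 'backbone.pos_embed'):
        return 0
    if var_name.startswith('backbone.patch_embed'):
        return 0
    if var_name.startswith('backbone.layers'):
        # e.g. 'backbone.layers.3.attn.qkv.weight' -> block index 3
        return int(var_name.split('.')[2]) + 1
    # Everything outside the backbone (e.g. the head) gets the last group.
    return max_layer_id - 1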
@@ -65,31 +65,37 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.

     Note: Currently, this optimizer constructor is built for ViT and Swin.
+
+    In addition to applying the layer-wise learning rate decay schedule, this
+    module will not apply weight decay to ``normalization parameters``,
+    ``bias``, ``position embedding``, ``class token``, and
+    ``relative position bias table``, automatically. What's more, the
+    ``paramwise_cfg`` in the base module will be ignored.
     """

-    def add_params(self,
-                   params: List[dict],
-                   module: nn.Module,
-                   optimizer_cfg: Dict,
-                   prefix: str = '',
-                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
+    def add_params(self, params: List[dict], module: nn.Module,
+                   optimizer_cfg: dict, **kwargs) -> None:
         """Add all parameters of module to the params list.

         The parameters of the given module will be added to the list of param
         groups, with specific rules defined by paramwise_cfg.

         Args:
-            params (list[dict]): A list of param groups, it will be modified
+            params (List[dict]): A list of param groups, it will be modified
                 in place.
             module (nn.Module): The module to be added.
-            optimizer_cfg (Dict): The configuration of optimizer.
-            prefix (str): The prefix of the module
-            is_dcn_module (int|float|None): If the current module is a
-                submodule of DCN, `is_dcn_module` will be passed to
-                control conv_offset layer's learning rate. Defaults to None.
+            optimizer_cfg (dict): The configuration of optimizer.
+            prefix (str): The prefix of the module.
         """
         logger = MMLogger.get_current_instance()
+        # If paramwise_cfg is not empty, note that it will be ignored.
+        if len(self.paramwise_cfg) > 0:
+            logger.info('The paramwise_cfg will be ignored, and normalization \
+                parameters, bias, position embedding, class token and \
+                relative position bias table will not be decayed by \
+                default.')
+
         model_type = optimizer_cfg.pop('model_type', None)

         # model_type should not be None
         assert model_type is not None, 'When using layer-wise learning rate \
@@ -113,8 +119,11 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
         for name, param in module.named_parameters():
             if not param.requires_grad:
                 continue  # frozen weights
+            # will not decay normalization params, bias, position embedding,
+            # class token, relative position bias table
             if len(param.shape) == 1 or name.endswith('.bias') or name in (
-                    'backbone.pos_embed', 'backbone.cls_token'):
+                    'backbone.pos_embed', 'backbone.cls_token'
+            ) or name.endswith('.relative_position_bias_table'):
                 group_name = 'no_decay'
                 this_weight_decay = 0.
             else:
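
The net effect of this hunk: any parameter whose name ends with .relative_position_bias_table now lands in the no_decay group, alongside 1-D parameters, biases, the position embedding and the class token. A self-contained sketch of that grouping rule, simplified to collecting names instead of building optimizer param groups (the parameter names are whatever the backbone provides):

import torch.nn as nn


def split_decay_groups(model: nn.Module) -> dict:
    """Sketch of the no-decay rule from the diff above."""
    groups = {'decay': [], 'no_decay': []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (len(param.shape) == 1 or name.endswith('.bias')
                or name in ('backbone.pos_embed', 'backbone.cls_token')
                or name.endswith('.relative_position_bias_table')):
            groups['no_decay'].append(name)  # weight_decay = 0. for these
        else:
            groups['decay'].append(name)
    return groups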

File 2 of 2:

@@ -118,3 +118,7 @@ def test_learning_rate_decay_optimizer_wrapper_constructor():
     assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0
     assert optimizer_wrapper.optimizer.param_groups[-3]['lr_scale'] == 2.0
     assert optimizer_wrapper.optimizer.param_groups[-5]['lr_scale'] == 4.0
+    # check relative pos bias table is not decayed
+    assert optimizer_wrapper.optimizer.param_groups[-4][
+        'weight_decay'] == 0.0 and 'backbone.stages.3.blocks.1.attn.w_msa.relative_position_bias_table' in optimizer_wrapper.optimizer.param_groups[  # noqa
+            -4]['param_names']
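
The lr_scale assertions above follow the usual geometric layer-decay rule: the last (deepest) group keeps the base learning rate and each step toward the input multiplies the scale by the decay rate, which is consistent with a rate of 2.0 in this test's configuration. A small sketch of that rule, independent of the constructor (the max_layer_id value below is arbitrary):

def lr_scale_for_layer(layer_id: int, max_layer_id: int,
                       layer_decay_rate: float) -> float:
    # The deepest layer keeps scale 1.0; each earlier layer is multiplied
    # by the decay rate once more.
    return layer_decay_rate ** (max_layer_id - layer_id)


# Mirrors the 1.0 / 2.0 / 4.0 progression asserted above (rate assumed 2.0).
assert lr_scale_for_layer(12, 12, 2.0) == 1.0
assert lr_scale_for_layer(11, 12, 2.0) == 2.0
assert lr_scale_for_layer(10, 12, 2.0) == 4.0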