[Refactor]: Refactor layer wise lr decay (#375)

* [Refactor]: Refactor layer wise lr decay

* [Fix]: Remove is_dcn_module and add log

* [Fix]: Fix lint

* [Fix]: Check relative pos bias table is not decayed

* [Fix]: Fix lint

* [Fix]: Fix UT
Yuan Liu 2022-08-05 17:38:53 +08:00 committed by GitHub
parent 045b1fde8e
commit a703ba2fcb
2 changed files with 29 additions and 16 deletions


@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
-from typing import Dict, List, Optional, Union
+from typing import List
import torch
from mmengine.dist import get_dist_info
@@ -13,7 +13,7 @@ from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
"""Get the layer id to set the different learning rates.
"""Get the layer id to set the different learning rates for ViT.
Args:
var_name (str): The key of the model.
@@ -36,7 +36,7 @@ def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
def get_layer_id_for_swin(var_name: str, max_layer_id: int,
depths: List[int]) -> int:
"""Get the layer id to set the different learning rates.
"""Get the layer id to set the different learning rates for Swin.
Args:
var_name (str): The key of the model.
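For context on how these layer ids are consumed: each parameter group's learning rate is scaled by a factor that decays geometrically with depth. Below is a minimal sketch of the usual scheme; the helper bodies are unchanged by this commit and not shown here, so the exact exponent offset and the `decay_rate` name are assumptions based on the common BEiT-style formulation.

```python
# Hedged sketch of the standard layer-wise lr decay scheme (assumption:
# the constructor maps the layer id to an lr scale roughly like this;
# the exact implementation is not part of this diff).
def lr_scale_for_layer(layer_id: int, num_layers: int,
                       decay_rate: float) -> float:
    # Shallow layers (small layer_id) get geometrically smaller scales;
    # the deepest group ends up with a scale of decay_rate ** 0 == 1.0.
    return decay_rate ** (num_layers - 1 - layer_id)


# With decay_rate=2.0 this would reproduce the 4.0 / 2.0 / 1.0 progression
# asserted in the updated unit test further down.
print([lr_scale_for_layer(i, 3, 2.0) for i in range(3)])  # [4.0, 2.0, 1.0]
```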
@@ -65,31 +65,37 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
Note: Currently, this optimizer constructor is built for ViT and Swin.
In addition to applying layer-wise learning rate decay schedule, this
module will not apply weight decay to ``normalization parameters``,
``bias``, ``position embedding``, ``class token``, and
``relative position bias table``, automatically. What's more, the
``paramwise_cfg`` in the base module will be ignored.
"""
-def add_params(self,
-               params: List[dict],
-               module: nn.Module,
-               optimizer_cfg: Dict,
-               prefix: str = '',
-               is_dcn_module: Optional[Union[int, float]] = None) -> None:
+def add_params(self, params: List[dict], module: nn.Module,
+               optimizer_cfg: dict, **kwargs) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
-params (list[dict]): A list of param groups, it will be modified
+params (List[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
-optimizer_cfg (Dict): The configuration of optimizer.
-prefix (str): The prefix of the module
-is_dcn_module (int|float|None): If the current module is a
-    submodule of DCN, `is_dcn_module` will be passed to
-    control conv_offset layer's learning rate. Defaults to None.
+optimizer_cfg (dict): The configuration of optimizer.
+prefix (str): The prefix of the module.
"""
logger = MMLogger.get_current_instance()
# Check if self.paramwise_cfg is not empty
if len(self.paramwise_cfg) > 0:
logger.info('The paramwise_cfg will be ignored, and normalization \
parameters, bias, position embedding, class token and \
relative position bias table will not be decayed by \
default.')
model_type = optimizer_cfg.pop('model_type', None)
# model_type should not be None
assert model_type is not None, 'When using layer-wise learning rate \
@@ -113,8 +119,11 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
# will not decay normalization params, bias, position embedding
# class token, relative position bias table
if len(param.shape) == 1 or name.endswith('.bias') or name in (
-    'backbone.pos_embed', 'backbone.cls_token'):
+    'backbone.pos_embed', 'backbone.cls_token'
+) or name.endswith('.relative_position_bias_table'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
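For reference, a hedged sketch of how this constructor is typically selected from a training config after the refactor: `model_type` is popped from the optimizer config (as in the hunk above), while the decay-factor key (written here as `layer_decay_rate`) and every numeric value are illustrative assumptions, not taken from this commit.

```python
# Minimal config sketch; only `model_type` and the constructor class name
# are grounded in this commit, everything else is an assumed placeholder.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=5e-4,
        weight_decay=0.05,
        model_type='swin',       # popped inside add_params(); 'vit' or 'swin'
        layer_decay_rate=0.9),   # assumed key for the per-layer decay factor
    constructor='LearningRateDecayOptimWrapperConstructor')
```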


@@ -118,3 +118,7 @@ def test_learning_rate_decay_optimizer_wrapper_constructor():
assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0
assert optimizer_wrapper.optimizer.param_groups[-3]['lr_scale'] == 2.0
assert optimizer_wrapper.optimizer.param_groups[-5]['lr_scale'] == 4.0
# check relative pos bias table is not decayed
assert optimizer_wrapper.optimizer.param_groups[-4][
'weight_decay'] == 0.0 and 'backbone.stages.3.blocks.1.attn.w_msa.relative_position_bias_table' in optimizer_wrapper.optimizer.param_groups[ # noqa
-4]['param_names']
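The new assertion checks one group by index; a more general way to confirm the rule this commit adds (no weight decay for the relative position bias table, alongside bias, 1-D params, position embedding and class token) is to scan all groups. A hedged sketch, assuming each group carries the `param_names` and `weight_decay` keys that the test above already relies on:

```python
def check_no_decay(optimizer) -> None:
    """Assert that every parameter matched by the no-decay rule landed in a
    group whose weight_decay is 0.0 (sketch; group keys assumed as above)."""
    no_decay_suffixes = ('.bias', '.relative_position_bias_table')
    no_decay_names = ('backbone.pos_embed', 'backbone.cls_token')
    for group in optimizer.param_groups:
        for name in group.get('param_names', []):
            if name.endswith(no_decay_suffixes) or name in no_decay_names:
                assert group['weight_decay'] == 0.0, name


# Usage with the wrapper built in this test (illustrative):
# check_no_decay(optimizer_wrapper.optimizer)
```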