mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor]: Refactor layer wise lr decay (#375)
* [Refactor]: Refactor layer wise lr decay * [Fix]: Remove is_dcn_module and add log * [Fix]: Fix lint * [Fix]: Check relative pos bias table is not decayed * [Fix]: Fix lint * [Fix]: Fix UT
This commit is contained in:
parent
045b1fde8e
commit
a703ba2fcb
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from mmengine.dist import get_dist_info
|
||||
@ -13,7 +13,7 @@ from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||
|
||||
|
||||
def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
|
||||
"""Get the layer id to set the different learning rates.
|
||||
"""Get the layer id to set the different learning rates for ViT.
|
||||
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
@ -36,7 +36,7 @@ def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
|
||||
|
||||
def get_layer_id_for_swin(var_name: str, max_layer_id: int,
|
||||
depths: List[int]) -> int:
|
||||
"""Get the layer id to set the different learning rates.
|
||||
"""Get the layer id to set the different learning rates for Swin.
|
||||
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
@ -65,31 +65,37 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
||||
"""Different learning rates are set for different layers of backbone.
|
||||
|
||||
Note: Currently, this optimizer constructor is built for ViT and Swin.
|
||||
|
||||
In addition to applying layer-wise learning rate decay schedule, this
|
||||
module will not apply weight decay to ``normalization parameters``,
|
||||
``bias``, ``position embedding``, ``class token``, and
|
||||
``relative position bias table, automatically. What's more, the
|
||||
``paramwise_cfg`` in the base module will be ignored.
|
||||
"""
|
||||
|
||||
def add_params(self,
|
||||
params: List[dict],
|
||||
module: nn.Module,
|
||||
optimizer_cfg: Dict,
|
||||
prefix: str = '',
|
||||
is_dcn_module: Optional[Union[int, float]] = None) -> None:
|
||||
def add_params(self, params: List[dict], module: nn.Module,
|
||||
optimizer_cfg: dict, **kwargs) -> None:
|
||||
"""Add all parameters of module to the params list.
|
||||
|
||||
The parameters of the given module will be added to the list of param
|
||||
groups, with specific rules defined by paramwise_cfg.
|
||||
|
||||
Args:
|
||||
params (list[dict]): A list of param groups, it will be modified
|
||||
params (List[dict]): A list of param groups, it will be modified
|
||||
in place.
|
||||
module (nn.Module): The module to be added.
|
||||
optimizer_cfg (Dict): The configuration of optimizer.
|
||||
prefix (str): The prefix of the module
|
||||
is_dcn_module (int|float|None): If the current module is a
|
||||
submodule of DCN, `is_dcn_module` will be passed to
|
||||
control conv_offset layer's learning rate. Defaults to None.
|
||||
optimizer_cfg (dict): The configuration of optimizer.
|
||||
prefix (str): The prefix of the module.
|
||||
"""
|
||||
logger = MMLogger.get_current_instance()
|
||||
|
||||
# Check if self.param_cfg is not None
|
||||
if len(self.paramwise_cfg) > 0:
|
||||
logger.info('The paramwise_cfg will be ignored, and normalization \
|
||||
parameters, bias, position embedding, class token and \
|
||||
relative position bias table will not be decayed by \
|
||||
default.')
|
||||
|
||||
model_type = optimizer_cfg.pop('model_type', None)
|
||||
# model_type should not be None
|
||||
assert model_type is not None, 'When using layer-wise learning rate \
|
||||
@ -113,8 +119,11 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
||||
for name, param in module.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue # frozen weights
|
||||
# will not decay normalization params, bias, position embedding
|
||||
# class token, relative position bias table
|
||||
if len(param.shape) == 1 or name.endswith('.bias') or name in (
|
||||
'backbone.pos_embed', 'backbone.cls_token'):
|
||||
'backbone.pos_embed', 'backbone.cls_token'
|
||||
) or name.endswith('.relative_position_bias_table'):
|
||||
group_name = 'no_decay'
|
||||
this_weight_decay = 0.
|
||||
else:
|
||||
|
@ -118,3 +118,7 @@ def test_learning_rate_decay_optimizer_wrapper_constructor():
|
||||
assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0
|
||||
assert optimizer_wrapper.optimizer.param_groups[-3]['lr_scale'] == 2.0
|
||||
assert optimizer_wrapper.optimizer.param_groups[-5]['lr_scale'] == 4.0
|
||||
# check relative pos bias table is not decayed
|
||||
assert optimizer_wrapper.optimizer.param_groups[-4][
|
||||
'weight_decay'] == 0.0 and 'backbone.stages.3.blocks.1.attn.w_msa.relative_position_bias_table' in optimizer_wrapper.optimizer.param_groups[ # noqa
|
||||
-4]['param_names']
|
||||
|
Loading…
x
Reference in New Issue
Block a user