Mirror of https://github.com/open-mmlab/mmselfsup.git (synced 2025-06-03 14:59:38 +08:00)
[Refactor]: Refactor layer wise lr decay (#375)
* [Refactor]: Refactor layer wise lr decay
* [Fix]: Remove is_dcn_module and add log
* [Fix]: Fix lint
* [Fix]: Check relative pos bias table is not decayed
* [Fix]: Fix lint
* [Fix]: Fix UT
This commit is contained in:
parent: 045b1fde8e
commit: a703ba2fcb
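For orientation, a minimal sketch of how this constructor is typically selected in an MMEngine-style optim_wrapper config. Only model_type is read explicitly in the diff below; the layer_decay_rate key, the AdamW settings, and the numeric values are illustrative assumptions rather than part of this commit:

# Hypothetical config snippet; values and the layer_decay_rate key are assumed.
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=5e-4,
        weight_decay=0.05,
        model_type='swin',        # popped by add_params(); 'vit' or 'swin'
        layer_decay_rate=0.9),    # assumed name of the per-layer decay factor
    constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor')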
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
-from typing import Dict, List, Optional, Union
+from typing import List
 
 import torch
 from mmengine.dist import get_dist_info
@@ -13,7 +13,7 @@ from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
 
 
 def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for ViT.
 
     Args:
         var_name (str): The key of the model.
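The body of get_layer_id_for_vit is untouched by this commit and is not shown here. As a rough sketch of the usual BEiT-style mapping it implements (the exact name checks and index arithmetic below are assumptions, not copied from the repository):

# Hypothetical sketch of a ViT parameter-name -> layer-id mapping.
def get_layer_id_for_vit_sketch(var_name: str, max_layer_id: int) -> int:
    # Embedding-level parameters share the first (most decayed) group.
    if var_name in ('backbone.cls_token', 'backbone.pos_embed'):
        return 0
    if var_name.startswith('backbone.patch_embed'):
        return 0
    # Transformer blocks: 'backbone.layers.<i>.' maps to layer i + 1.
    if var_name.startswith('backbone.layers'):
        return int(var_name.split('.')[2]) + 1
    # Anything else (e.g. the final norm) falls into the last group.
    return max_layer_id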
@@ -36,7 +36,7 @@ def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
 
 def get_layer_id_for_swin(var_name: str, max_layer_id: int,
                           depths: List[int]) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for Swin.
 
     Args:
         var_name (str): The key of the model.
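For Swin, the same idea additionally needs the per-stage depths so that a (stage, block) pair can be flattened into one global block index; a simplified, assumed sketch of that idea (not the repository code):

# Hypothetical sketch of a Swin parameter-name -> layer-id mapping.
def get_layer_id_for_swin_sketch(var_name: str, max_layer_id: int,
                                 depths: list) -> int:
    # Patch embedding and other stem parameters go to the first group.
    if var_name.startswith('backbone.patch_embed'):
        return 0
    # 'backbone.stages.<s>.blocks.<b>.' -> cumulative block index + 1,
    # so later stages land in less-decayed groups.
    if var_name.startswith('backbone.stages') and '.blocks.' in var_name:
        parts = var_name.split('.')
        stage_id, block_id = int(parts[2]), int(parts[4])
        return sum(depths[:stage_id]) + block_id + 1
    # Remaining parameters (e.g. downsample layers, final norm) fall back
    # to the last group.
    return max_layer_id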
@@ -65,31 +65,37 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.
 
     Note: Currently, this optimizer constructor is built for ViT and Swin.
+
+    In addition to applying layer-wise learning rate decay schedule, this
+    module will not apply weight decay to ``normalization parameters``,
+    ``bias``, ``position embedding``, ``class token``, and
+    ``relative position bias table, automatically. What's more, the
+    ``paramwise_cfg`` in the base module will be ignored.
     """
 
-    def add_params(self,
-                   params: List[dict],
-                   module: nn.Module,
-                   optimizer_cfg: Dict,
-                   prefix: str = '',
-                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
+    def add_params(self, params: List[dict], module: nn.Module,
+                   optimizer_cfg: dict, **kwargs) -> None:
         """Add all parameters of module to the params list.
 
         The parameters of the given module will be added to the list of param
         groups, with specific rules defined by paramwise_cfg.
 
         Args:
-            params (list[dict]): A list of param groups, it will be modified
+            params (List[dict]): A list of param groups, it will be modified
                 in place.
             module (nn.Module): The module to be added.
-            optimizer_cfg (Dict): The configuration of optimizer.
-            prefix (str): The prefix of the module
-            is_dcn_module (int|float|None): If the current module is a
-                submodule of DCN, `is_dcn_module` will be passed to
-                control conv_offset layer's learning rate. Defaults to None.
+            optimizer_cfg (dict): The configuration of optimizer.
+            prefix (str): The prefix of the module.
         """
         logger = MMLogger.get_current_instance()
 
+        # Check if self.param_cfg is not None
+        if len(self.paramwise_cfg) > 0:
+            logger.info('The paramwise_cfg will be ignored, and normalization \
+                parameters, bias, position embedding, class token and \
+                relative position bias table will not be decayed by \
+                default.')
+
         model_type = optimizer_cfg.pop('model_type', None)
         # model_type should not be None
         assert model_type is not None, 'When using layer-wise learning rate \
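The hunk above stops before the actual grouping; as a hedged sketch of the standard layer-wise decay rule (the helper name and the exact exponent are assumptions, not lifted from this file), each (layer_id, decay/no-decay) pair becomes one param group whose learning rate is scaled geometrically with depth:

# Assumed sketch: shallower layers get smaller lr_scale, the last layer gets 1.0.
def lr_scale_for(layer_id: int, num_layers: int, layer_decay_rate: float) -> float:
    return layer_decay_rate ** (num_layers - layer_id)

# e.g. num_layers=12, layer_decay_rate=0.65:
#   lr_scale_for(0, 12, 0.65)  -> 0.65 ** 12  (embedding, heavily scaled down)
#   lr_scale_for(12, 12, 0.65) -> 1.0         (last block, full learning rate)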
@@ -113,8 +119,11 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
         for name, param in module.named_parameters():
             if not param.requires_grad:
                 continue  # frozen weights
+            # will not decay normalization params, bias, position embedding
+            # class token, relative position bias table
             if len(param.shape) == 1 or name.endswith('.bias') or name in (
-                    'backbone.pos_embed', 'backbone.cls_token'):
+                    'backbone.pos_embed', 'backbone.cls_token'
+            ) or name.endswith('.relative_position_bias_table'):
                 group_name = 'no_decay'
                 this_weight_decay = 0.
             else:
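To make the updated condition concrete, the same test can be written as a standalone predicate (this mirrors the diff above; the example parameter names in the comments are typical rather than taken from the repository):

def is_no_decay(name: str, shape: tuple) -> bool:
    # 1-D tensors (norm weights), biases, embeddings, the class token and
    # relative position bias tables all skip weight decay.
    return (len(shape) == 1 or name.endswith('.bias')
            or name in ('backbone.pos_embed', 'backbone.cls_token')
            or name.endswith('.relative_position_bias_table'))

# is_no_decay('backbone.ln1.weight', (768,))                     -> True
# is_no_decay('backbone.layers.0.attn.qkv.weight', (2304, 768))  -> False
# is_no_decay('...w_msa.relative_position_bias_table', (169, 3)) -> True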
@@ -118,3 +118,7 @@ def test_learning_rate_decay_optimizer_wrapper_constructor():
     assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0
     assert optimizer_wrapper.optimizer.param_groups[-3]['lr_scale'] == 2.0
     assert optimizer_wrapper.optimizer.param_groups[-5]['lr_scale'] == 4.0
+    # check relative pos bias table is not decayed
+    assert optimizer_wrapper.optimizer.param_groups[-4][
+        'weight_decay'] == 0.0 and 'backbone.stages.3.blocks.1.attn.w_msa.relative_position_bias_table' in optimizer_wrapper.optimizer.param_groups[  # noqa
+            -4]['param_names']
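The pre-existing lr_scale assertions are consistent with a geometric scale per layer; reading 2.0 and 4.0 as one and two decay steps away from the final layer suggests the test fixture uses a decay rate of 2.0, although the fixture itself is not part of this hunk and this reading is an inference:

# Assumed relationship behind the 1.0 / 2.0 / 4.0 assertions:
#   lr_scale(layer) = layer_decay_rate ** (num_layers - layer)
layer_decay_rate = 2.0   # inferred test value, not shown in this diff
num_layers = 3           # illustrative only
print([layer_decay_rate ** (num_layers - i) for i in range(num_layers + 1)])
# -> [8.0, 4.0, 2.0, 1.0]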