diff --git a/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py
index 558a60fd..85c236e9 100644
--- a/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py
+++ b/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
-from typing import Dict, List, Optional, Union
+from typing import List
 
 import torch
 from mmengine.dist import get_dist_info
@@ -13,7 +13,7 @@ from mmselfsup.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
 
 
 def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for ViT.
 
     Args:
         var_name (str): The key of the model.
@@ -36,7 +36,7 @@ def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
 
 def get_layer_id_for_swin(var_name: str, max_layer_id: int,
                           depths: List[int]) -> int:
-    """Get the layer id to set the different learning rates.
+    """Get the layer id to set the different learning rates for Swin.
 
     Args:
         var_name (str): The key of the model.
@@ -65,31 +65,37 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.
 
     Note: Currently, this optimizer constructor is built for ViT and Swin.
+
+    In addition to applying the layer-wise learning rate decay schedule, this
+    module will not apply weight decay to ``normalization parameters``,
+    ``bias``, ``position embedding``, ``class token``, and
+    ``relative position bias table`` automatically. What's more, the
+    ``paramwise_cfg`` in the base module will be ignored.
     """
 
-    def add_params(self,
-                   params: List[dict],
-                   module: nn.Module,
-                   optimizer_cfg: Dict,
-                   prefix: str = '',
-                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
+    def add_params(self, params: List[dict], module: nn.Module,
+                   optimizer_cfg: dict, **kwargs) -> None:
         """Add all parameters of module to the params list.
 
         The parameters of the given module will be added to the list of param
         groups, with specific rules defined by paramwise_cfg.
 
         Args:
-            params (list[dict]): A list of param groups, it will be modified
+            params (List[dict]): A list of param groups, it will be modified
                 in place.
             module (nn.Module): The module to be added.
-            optimizer_cfg (Dict): The configuration of optimizer.
-            prefix (str): The prefix of the module
-            is_dcn_module (int|float|None): If the current module is a
-                submodule of DCN, `is_dcn_module` will be passed to
-                control conv_offset layer's learning rate. Defaults to None.
+            optimizer_cfg (dict): The configuration of optimizer.
+            prefix (str): The prefix of the module.
""" logger = MMLogger.get_current_instance() + # Check if self.param_cfg is not None + if len(self.paramwise_cfg) > 0: + logger.info('The paramwise_cfg will be ignored, and normalization \ + parameters, bias, position embedding, class token and \ + relative position bias table will not be decayed by \ + default.') + model_type = optimizer_cfg.pop('model_type', None) # model_type should not be None assert model_type is not None, 'When using layer-wise learning rate \ @@ -113,8 +119,11 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights + # will not decay normalization params, bias, position embedding + # class token, relative position bias table if len(param.shape) == 1 or name.endswith('.bias') or name in ( - 'backbone.pos_embed', 'backbone.cls_token'): + 'backbone.pos_embed', 'backbone.cls_token' + ) or name.endswith('.relative_position_bias_table'): group_name = 'no_decay' this_weight_decay = 0. else: diff --git a/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py b/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py index a2e33c86..5d6a40da 100644 --- a/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py +++ b/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py @@ -118,3 +118,7 @@ def test_learning_rate_decay_optimizer_wrapper_constructor(): assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0 assert optimizer_wrapper.optimizer.param_groups[-3]['lr_scale'] == 2.0 assert optimizer_wrapper.optimizer.param_groups[-5]['lr_scale'] == 4.0 + # check relative pos bias table is not decayed + assert optimizer_wrapper.optimizer.param_groups[-4][ + 'weight_decay'] == 0.0 and 'backbone.stages.3.blocks.1.attn.w_msa.relative_position_bias_table' in optimizer_wrapper.optimizer.param_groups[ # noqa + -4]['param_names']