[Feature] Implement layer-wise learning rate decay optimizer constructor. (#1399)
* [Feature] Implement layer-wise learning rate decay optimizer constructor.
* Use num_layers instead of max_depth to avoid being misleading.
* Add unit tests.
* Update docstring.
* Update log info.
* Update LearningRateDecay configs.

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>

parent 827be6e22d
commit 274a67223e
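Taken together, the config hunks below all make the same change: the layer-wise options (model_type, layer_decay_rate) move out of the optimizer dict and into paramwise_cfg, which the new constructor reads. A minimal standalone sketch of the resulting pattern, not part of the diff itself; the concrete values mirror the BEiT configs below:

    optim_wrapper = dict(
        optimizer=dict(
            type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
        constructor='LearningRateDecayOptimWrapperConstructor',
        paramwise_cfg=dict(
            layer_decay_rate=0.65,  # layer-wise lr decay factor
            custom_keys={
                '.ln': dict(decay_mult=0.0),  # no weight decay on LayerNorm
                '.bias': dict(decay_mult=0.0),  # no weight decay on biases
            }))

The model_type field is no longer needed: the constructor now asks the model itself for layer depths through the get_layer_depth methods added to the backbones and classifiers further down.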
@@ -84,16 +84,11 @@ test_dataloader = val_dataloader
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',
-        layer_decay_rate=0.65),
+        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
         _delete_=True,
+        layer_decay_rate=0.65,
         custom_keys={
             # the following configurations are designed for BEiT
             '.ln': dict(decay_mult=0.0),
@@ -77,17 +77,12 @@ test_dataloader = val_dataloader
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=5e-4,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',
-        # 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
-        layer_decay_rate=0.65),
+        type='AdamW', lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
         _delete_=True,
+        # 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
+        layer_decay_rate=0.65,
         custom_keys={
             # the following configurations are designed for BEiT
             '.ln': dict(decay_mult=0.0),
@@ -92,14 +92,10 @@ model = dict(
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=8e-3,
-        betas=(0.9, 0.999),
-        weight_decay=0.05,
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.65),  # layer-wise lr decay factor
+        type='AdamW', lr=8e-3, betas=(0.9, 0.999), weight_decay=0.05),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.65,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -76,15 +76,10 @@ model = dict(
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-4,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.65),  # layer-wise lr decay factor
+        type='AdamW', lr=4e-4, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.65,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -75,15 +75,10 @@ model = dict(
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=2e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.65),  # layer-wise lr decay factor
+        type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.65,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -78,15 +78,10 @@ model = dict(
 # learning rate and layer decay rate are set to 0.004 and 0.75 respectively
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.75),  # layer-wise lr decay factor
+        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.75,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -77,15 +77,10 @@ model = dict(
 # learning rate and layer decay rate are set to 0.004 and 0.75 respectively
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.75),  # layer-wise lr decay factor
+        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.75,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -77,15 +77,10 @@ model = dict(
 # learning rate and layer decay rate are set to 0.004 and 0.75 respectively
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.75),  # layer-wise lr decay factor
+        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.75,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -76,15 +76,10 @@ model = dict(
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=8e-3,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.65),  # layer-wise lr decay factor
+        type='AdamW', lr=8e-3, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.65,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -76,15 +76,10 @@ model = dict(
 # optimizer wrapper
 optim_wrapper = dict(
     optimizer=dict(
-        type='AdamW',
-        lr=4e-4,
-        weight_decay=0.05,
-        eps=1e-8,
-        betas=(0.9, 0.999),
-        model_type='vit',  # layer-wise lr decay type
-        layer_decay_rate=0.65),  # layer-wise lr decay factor
+        type='AdamW', lr=4e-4, weight_decay=0.05, betas=(0.9, 0.999)),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.65,
         custom_keys={
             '.ln': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -92,12 +92,11 @@ optim_wrapper = dict(
     optimizer=dict(
         type='AdamW',
         lr=5e-4 * (8 * 128 / 256),
-        model_type='mixmim',
-        layer_decay_rate=0.7,
         betas=(0.9, 0.999),
         weight_decay=0.05),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.7,
         custom_keys={
             '.ln': dict(decay_mult=0.0),  # do not decay on ln and bias
             '.bias': dict(decay_mult=0.0)
@@ -14,11 +14,11 @@ model = dict(
 # optimizer settings
 optim_wrapper = dict(
     type='AmpOptimWrapper',
-    optimizer=dict(
-        type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9),
+    optimizer=dict(type='AdamW', lr=5e-3),
     clip_grad=dict(max_norm=5.0),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.9,
         custom_keys={
             '.norm': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -57,11 +57,11 @@ model = dict(
 # optimizer settings
 optim_wrapper = dict(
     type='AmpOptimWrapper',
-    optimizer=dict(
-        type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9),
+    optimizer=dict(type='AdamW', lr=5e-3),
     clip_grad=dict(max_norm=5.0),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.9,
         custom_keys={
             '.norm': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -60,11 +60,11 @@ model = dict(
 # optimizer settings
 optim_wrapper = dict(
     type='AmpOptimWrapper',
-    optimizer=dict(
-        type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.7),
+    optimizer=dict(type='AdamW', lr=5e-3),
     clip_grad=dict(max_norm=5.0),
     constructor='LearningRateDecayOptimWrapperConstructor',
     paramwise_cfg=dict(
+        layer_decay_rate=0.7,
         custom_keys={
             '.norm': dict(decay_mult=0.0),
             '.bias': dict(decay_mult=0.0),
@@ -47,3 +47,5 @@ Optimizers
    :nosignatures:

    Lamb
+   LARS
+   LearningRateDecayOptimWrapperConstructor
@@ -2,5 +2,7 @@
 from .adan_t import Adan
 from .lamb import Lamb
 from .lars import LARS
+from .layer_decay_optim_wrapper_constructor import \
+    LearningRateDecayOptimWrapperConstructor

-__all__ = ['Lamb', 'Adan', 'LARS']
+__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor']
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from typing import Callable, List, Optional
+
+from mmengine.logging import MMLogger
+from mmengine.optim import DefaultOptimWrapperConstructor
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
+from torch import nn
+from torch.nn import GroupNorm, LayerNorm
+
+from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
+    """Different learning rates are set for different layers of backbone.
+
+    By default, each parameter shares the same optimizer settings, and we
+    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+    It is a dict and may contain the following fields:
+
+    - ``layer_decay_rate`` (float): The learning rate of a parameter is
+      multiplied by this rate according to the layer depth of the parameter.
+      It is usually less than 1, so that earlier layers get a lower learning
+      rate. Defaults to 1.
+    - ``bias_decay_mult`` (float): It will be multiplied to the weight
+      decay for all bias parameters (except for those in normalization layers).
+    - ``norm_decay_mult`` (float): It will be multiplied to the weight
+      decay for all weight and bias parameters of normalization layers.
+    - ``flat_decay_mult`` (float): It will be multiplied to the weight
+      decay for all one-dimensional parameters.
+    - ``custom_keys`` (dict): Specify parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of one
+      parameter, then the setting of the parameter will be specified by
+      ``custom_keys[key]`` and other settings like ``bias_decay_mult`` will be
+      ignored. It should be a dict and may contain the field ``decay_mult``.
+      (``lr_mult`` is disabled in this constructor.)
+
+    Example:
+
+        In the config file, you can use this constructor as below:
+
+        .. code:: python
+
+            optim_wrapper = dict(
+                optimizer=dict(
+                    type='AdamW',
+                    lr=4e-3,
+                    weight_decay=0.05,
+                    eps=1e-8,
+                    betas=(0.9, 0.999)),
+                constructor='LearningRateDecayOptimWrapperConstructor',
+                paramwise_cfg=dict(
+                    layer_decay_rate=0.75,  # layer-wise lr decay factor
+                    norm_decay_mult=0.,
+                    flat_decay_mult=0.,
+                    custom_keys={
+                        '.cls_token': dict(decay_mult=0.0),
+                        '.pos_embed': dict(decay_mult=0.0)
+                    }))
+    """
+
+    def add_params(self,
+                   params: List[dict],
+                   module: nn.Module,
+                   prefix: str = '',
+                   get_layer_depth: Optional[Callable] = None,
+                   **kwargs) -> None:
+        """Add all parameters of module to the params list.
+
+        The parameters of the given module will be added to the list of param
+        groups, with specific rules defined by paramwise_cfg.
+
+        Args:
+            params (List[dict]): A list of param groups, it will be modified
+                in place.
+            module (nn.Module): The module to be added.
+            prefix (str): The prefix of the module.
+            get_layer_depth (Callable, optional): The function used to compute
+                the layer depth of a parameter. Defaults to None.
+        """
+        # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # sort alphabetically first, then by reversed length of the key
+        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+        logger = MMLogger.get_current_instance()
+
+        # The model should have a `get_layer_depth` method
+        if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
+            raise NotImplementedError(
+                'The layer-wise learning rate decay needs the model '
+                f'{type(module)} to have a `get_layer_depth` method.')
+        else:
+            get_layer_depth = get_layer_depth or module.get_layer_depth
+
+        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
+        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
+        flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
+        decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)
+
+        # special rules for norm layers
+        is_norm = isinstance(module,
+                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+
+        for name, param in module.named_parameters(recurse=False):
+            param_group = {'params': [param]}
+            param_name = prefix + name
+            if not param.requires_grad:
+                continue
+
+            if self.base_wd is not None:
+                base_wd = self.base_wd
+                custom_key = next(
+                    filter(lambda k: k in param_name, sorted_keys), None)
+                # custom parameters decay
+                if custom_key is not None:
+                    custom_cfg = custom_keys[custom_key].copy()
+                    decay_mult = custom_cfg.pop('decay_mult', 1.)
+
+                    param_group['weight_decay'] = base_wd * decay_mult
+                    # add custom settings to param_group
+                    param_group.update(custom_cfg)
+                # norm decay
+                elif is_norm and norm_decay_mult is not None:
+                    param_group['weight_decay'] = base_wd * norm_decay_mult
+                # bias decay
+                elif name == 'bias' and bias_decay_mult is not None:
+                    param_group['weight_decay'] = base_wd * bias_decay_mult
+                # flatten parameters decay
+                elif param.ndim == 1 and flat_decay_mult is not None:
+                    param_group['weight_decay'] = base_wd * flat_decay_mult
+                else:
+                    param_group['weight_decay'] = base_wd
+
+            layer_id, max_id = get_layer_depth(param_name)
+            scale = decay_rate**(max_id - layer_id - 1)
+            param_group['lr'] = self.base_lr * scale
+            param_group['lr_scale'] = scale
+            param_group['layer_id'] = layer_id
+            param_group['param_name'] = param_name
+
+            params.append(param_group)
+
+        # recursively add the parameters of all child modules
+        for child_name, child_mod in module.named_children():
+            child_prefix = f'{prefix}{child_name}.'
+            self.add_params(
+                params,
+                child_mod,
+                prefix=child_prefix,
+                get_layer_depth=get_layer_depth,
+            )
+
+        # at the top level, log the resulting grouping for debugging
+        if prefix == '':
+            layer_params = defaultdict(list)
+            for param in params:
+                layer_params[param['layer_id']].append(param)
+            for layer_id, layer_params in layer_params.items():
+                lr_scale = layer_params[0]['lr_scale']
+                lr = layer_params[0]['lr']
+                msg = [
+                    f'layer {layer_id} params '
+                    f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
+                ]
+                for param in layer_params:
+                    msg.append(f'\t{param["param_name"]}: '
+                               f'weight_decay={param["weight_decay"]:.3g}')
+                logger.debug('\n'.join(msg))
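The scaling rule implemented in add_params above is scale = layer_decay_rate ** (num_layers - layer_depth - 1): the last depth (usually the head) keeps the base learning rate, and each earlier depth is scaled down by one more factor of the decay rate. A small standalone sketch, not part of the commit, using hypothetical values for a 12-layer ViT (stem + 12 blocks + head gives num_layers = 14):

    def lr_for_depth(base_lr, decay_rate, layer_depth, num_layers):
        # Mirrors `scale = decay_rate**(max_id - layer_id - 1)` above.
        return base_lr * decay_rate**(num_layers - layer_depth - 1)

    base_lr, decay_rate, num_layers = 1e-3, 0.65, 14  # illustrative values only
    for depth in (0, 1, 12, 13):
        print(depth, lr_for_depth(base_lr, decay_rate, depth, num_layers))
    # depth 0  (stem)       -> base_lr * 0.65**13
    # depth 13 (head, last) -> base_lr * 0.65**0 == base_lr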
@@ -489,3 +489,45 @@ class MixMIMTransformer(BaseBackbone):
         x = torch.flatten(x, 1)

         return (x, )
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``)
+        """
+        num_layers = sum(self.depths) + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like neck and head
+            if param_name.startswith('neck'):
+                return num_layers - 2, num_layers
+            else:
+                return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed')
+        if any(stem in param_name for stem in stem_layers):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            block_id = param_name.split('.')[3]
+
+            if block_id in ('downsample', 'reduction', 'norm'):
+                layer_depth = sum(self.depths[:layer_id + 1])
+            else:
+                layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1
+        else:
+            layer_depth = num_layers - 2
+
+        return layer_depth, num_layers
@@ -546,3 +546,40 @@ class SwinTransformer(BaseBackbone):
                     # The index buffer need to be re-generated.
                     index_buffer = ckpt_key.replace('bias_table', 'index')
                     del state_dict[index_buffer]
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``)
+        """
+        num_layers = sum(self.depths) + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('stages'):
+            stage_id = int(param_name.split('.')[1])
+            block_id = param_name.split('.')[3]
+            if block_id in ('reduction', 'norm'):
+                layer_depth = sum(self.depths[:stage_id + 1])
+            else:
+                layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
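To make the stage/block arithmetic above concrete: blocks are numbered consecutively across stages, a stage's downsample/norm parameters share the depth of the last block of that stage, and the patch embedding is depth 0. A toy re-statement of those branches, not part of the commit, assuming depths=(2, 2, 6, 2) (Swin-T-like) and hypothetical parameter names:

    depths = (2, 2, 6, 2)         # assumed stage depths, e.g. a Swin-T
    num_layers = sum(depths) + 2  # stem + all blocks + subsequent modules

    def toy_swin_depth(param_name):
        # Mirrors the branches of SwinTransformer.get_layer_depth above.
        if param_name.startswith('patch_embed'):
            return 0
        if param_name.startswith('stages'):
            stage_id = int(param_name.split('.')[1])
            block_id = param_name.split('.')[3]
            if block_id in ('reduction', 'norm'):
                return sum(depths[:stage_id + 1])
            return sum(depths[:stage_id]) + int(block_id) + 1
        return num_layers - 1

    print(toy_swin_depth('patch_embed.projection.weight'))         # 0 (stem)
    print(toy_swin_depth('stages.2.blocks.3.norm1.weight'))        # 2 + 2 + 3 + 1 = 8
    print(toy_swin_depth('stages.2.downsample.reduction.weight'))  # sum(depths[:3]) = 10
    print(toy_swin_depth('norm3.weight'))                          # 13, the last depth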
@@ -461,3 +461,38 @@ class VisionTransformer(BaseBackbone):
                 outs.append(out)

         return tuple(outs)
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``)
+        """
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name in ('cls_token', 'pos_embed'):
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
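The ViT rule is simpler: cls_token, pos_embed and the patch embedding map to depth 0, transformer layer i maps to depth i + 1, and anything else (final norm, head, ...) maps to the last depth. A quick check with hypothetical parameter names, not part of the commit, assuming a 12-layer ViT:

    num_layers = 12 + 2  # assumed: 12 transformer layers, plus stem and head depths

    def toy_vit_depth(param_name):
        # Mirrors the branches of VisionTransformer.get_layer_depth above.
        if param_name in ('cls_token', 'pos_embed'):
            return 0
        if param_name.startswith('patch_embed'):
            return 0
        if param_name.startswith('layers'):
            return int(param_name.split('.')[1]) + 1
        return num_layers - 1

    print(toy_vit_depth('cls_token'))            # 0
    print(toy_vit_depth('layers.5.ln1.weight'))  # 6
    print(toy_vit_depth('ln1.weight'))           # 13, the last depth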
@@ -241,3 +241,19 @@ class ImageClassifier(BaseClassifier):
         """
         feats = self.extract_feat(inputs)
         return self.head.predict(feats, data_samples, **kwargs)
+
+    def get_layer_depth(self, param_name: str):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the max depth.
+        """
+        if hasattr(self.backbone, 'get_layer_depth'):
+            return self.backbone.get_layer_depth(param_name, 'backbone.')
+        else:
+            raise NotImplementedError(
+                f"The backbone {type(self.backbone)} doesn't "
+                'support `get_layer_depth` by now.')
@@ -161,3 +161,19 @@ class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta):
             dict[str, Tensor]: A dictionary of loss components.
         """
         raise NotImplementedError
+
+    def get_layer_depth(self, param_name: str):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the max depth.
+        """
+        if hasattr(self.backbone, 'get_layer_depth'):
+            return self.backbone.get_layer_depth(param_name, 'backbone.')
+        else:
+            raise NotImplementedError(
+                f"The backbone {type(self.backbone)} doesn't "
+                'support `get_layer_depth` by now.')
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from torch import nn
+
+from mmpretrain.engine import LearningRateDecayOptimWrapperConstructor
+from mmpretrain.models import ImageClassifier, VisionTransformer
+
+
+class ToyViTBackbone(nn.Module):
+
+    get_layer_depth = VisionTransformer.get_layer_depth
+
+    def __init__(self, num_layers=2):
+        super().__init__()
+        self.cls_token = nn.Parameter(torch.ones(1))
+        self.pos_embed = nn.Parameter(torch.ones(1))
+        self.num_layers = num_layers
+        self.layers = nn.ModuleList()
+        for _ in range(num_layers):
+            layer = nn.Conv2d(3, 3, 1)
+            self.layers.append(layer)
+
+
+class ToyViT(nn.Module):
+    get_layer_depth = ImageClassifier.get_layer_depth
+
+    def __init__(self):
+        super().__init__()
+        # add some variables to meet the unit test coverage rate
+        self.backbone = ToyViTBackbone()
+        self.head = nn.Linear(1, 1)
+
+
+class TestLearningRateDecayOptimWrapperConstructor(TestCase):
+    base_lr = 1.0
+    base_wd = 0.05
+
+    def test_add_params(self):
+        model = ToyViT()
+        optim_wrapper_cfg = dict(
+            type='OptimWrapper',
+            optimizer=dict(
+                type='AdamW',
+                lr=self.base_lr,
+                betas=(0.9, 0.999),
+                weight_decay=self.base_wd))
+        paramwise_cfg = dict(
+            layer_decay_rate=2.0,
+            bias_decay_mult=0.,
+            custom_keys={
+                '.cls_token': dict(decay_mult=0.0),
+                '.pos_embed': dict(decay_mult=0.0),
+            })
+
+        constructor = LearningRateDecayOptimWrapperConstructor(
+            optim_wrapper_cfg=optim_wrapper_cfg,
+            paramwise_cfg=paramwise_cfg,
+        )
+        optimizer_wrapper = constructor(model)
+
+        expected_groups = [{
+            'weight_decay': 0.0,
+            'lr': 8 * self.base_lr,
+            'param_name': 'backbone.cls_token',
+        }, {
+            'weight_decay': 0.0,
+            'lr': 8 * self.base_lr,
+            'param_name': 'backbone.pos_embed',
+        }, {
+            'weight_decay': self.base_wd,
+            'lr': 4 * self.base_lr,
+            'param_name': 'backbone.layers.0.weight',
+        }, {
+            'weight_decay': 0.0,
+            'lr': 4 * self.base_lr,
+            'param_name': 'backbone.layers.0.bias',
+        }, {
+            'weight_decay': self.base_wd,
+            'lr': 2 * self.base_lr,
+            'param_name': 'backbone.layers.1.weight',
+        }, {
+            'weight_decay': 0.0,
+            'lr': 2 * self.base_lr,
+            'param_name': 'backbone.layers.1.bias',
+        }, {
+            'weight_decay': self.base_wd,
+            'lr': 1 * self.base_lr,
+            'param_name': 'head.weight',
+        }, {
+            'weight_decay': 0.0,
+            'lr': 1 * self.base_lr,
+            'param_name': 'head.bias',
+        }]
+        self.assertIsInstance(optimizer_wrapper.optimizer, torch.optim.AdamW)
+        self.assertEqual(optimizer_wrapper.optimizer.defaults['lr'],
+                         self.base_lr)
+        self.assertEqual(optimizer_wrapper.optimizer.defaults['weight_decay'],
+                         self.base_wd)
+        param_groups = optimizer_wrapper.optimizer.param_groups
+        self.assertEqual(len(param_groups), len(expected_groups))
+
+        for expect, param in zip(expected_groups, param_groups):
+            self.assertEqual(param['param_name'], expect['param_name'])
+            self.assertEqual(param['lr'], expect['lr'])
+            self.assertEqual(param['weight_decay'], expect['weight_decay'])