[Refactor]: Refactor layer-wise decay learning rate optimizer constructor
parent 2f2813ecd4
commit a10a6bd64f
@@ -1,10 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_optimizer
from .constructor import DefaultOptimizerConstructor
from .layer_decay_optimizer_constructor import \
    LearningRateDecayOptimizerConstructor
from .optimizers import LARS
from .transformer_finetune_constructor import TransformerFinetuneConstructor

__all__ = [
    'LARS', 'build_optimizer', 'TransformerFinetuneConstructor',
    'DefaultOptimizerConstructor'
]
__all__ = ['LARS', 'LearningRateDecayOptimizerConstructor']
@@ -1,47 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv.runner.optimizer.builder import build_optimizer_constructor


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with regular expression as keys
                  to match parameter names and a dict containing options as
                  values. Options include 6 fields: lr, lr_mult, momentum,
                  momentum_mult, weight_decay, weight_decay_mult.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> paramwise_options = {
        >>>     '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1),
        >>>     '\Ahead.': dict(lr_mult=10, momentum=0)}
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001,
        >>>                      paramwise_options=paramwise_options)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """
    optimizer_cfg = copy.deepcopy(optimizer_cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_options', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
@@ -1,81 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re

import torch.distributed as dist
from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg
from mmengine.logging import print_log


@OPTIMIZER_BUILDERS.register_module(force=True)
class DefaultOptimizerConstructor:
    """Rewrote default constructor for optimizers. By default each parameter
    share the same optimizer settings, and we provide an argument
    ``paramwise_cfg`` to specify parameter-wise settings. It is a dict and may
    contain the following fields:

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are
                - `type`: class name of the optimizer.
            Optional fields are
                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
            Defaults to None.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict('bias': dict(weight_decay=0., \
                                 lars_exclude=True))
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        paramwise_options = self.paramwise_cfg

        # if no paramwise option is specified, just use the global setting
        if paramwise_options is None:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            assert isinstance(paramwise_options, dict)
            params = []
            for name, param in model.named_parameters():
                param_group = {'params': [param]}
                if not param.requires_grad:
                    params.append(param_group)
                    continue

                for regexp, options in paramwise_options.items():
                    if re.search(regexp, name):
                        for key, value in options.items():
                            if key.endswith('_mult'):  # is a multiplier
                                key = key[:-5]
                                assert key in optimizer_cfg, \
                                    f'{key} not in optimizer_cfg'
                                value = optimizer_cfg[key] * value
                            param_group[key] = value
                            if not dist.is_initialized() or \
                                    dist.get_rank() == 0:
                                print_log(f'paramwise_options -- \
                                    {name}: {key}={value}')

                # otherwise use the global settings
                params.append(param_group)

            optimizer_cfg['params'] = params
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
@@ -0,0 +1,171 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Dict, List, Optional, Union

import torch
from mmengine.dist import get_dist_info
from mmengine.optim import DefaultOptimizerConstructor
from torch import nn

from mmselfsup.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
from mmselfsup.utils import get_root_logger


def get_layer_id_for_vit(var_name: str, max_layer_id: int) -> int:
    """Get the layer id to set the different learning rates.

    Args:
        var_name (str): The key of the model.
        num_max_layer (int): Maximum number of backbone layers.
    Returns:
        int: Returns the layer id of the key.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.patch_embed'):
        return 0
    elif var_name.startswith('backbone.layers'):
        layer_id = int(var_name.split('.')[2])
        return layer_id + 1
    else:
        return max_layer_id - 1


def get_layer_id_for_swin(var_name: str, max_layer_id: int,
                          depths: List[int]) -> int:
    """Get the layer id to set the different learning rates.

    Args:
        var_name (str): The key of the model.
        num_max_layer (int): Maximum number of backbone layers.
        depths (List[int]): Depths for each stage.
    Returns:
        int: Returns the layer id of the key.
    """
    if 'mask_token' in var_name:
        return 0
    elif 'patch_embed' in var_name:
        return 0
    elif var_name.startswith('backbone.stages'):
        layer_id = int(var_name.split('.')[2])
        block_id = var_name.split('.')[4]
        if block_id == 'reduction' or block_id == 'norm':
            return sum(depths[:layer_id + 1])
        layer_id = sum(depths[:layer_id]) + int(block_id)
        return layer_id + 1
    else:
        return max_layer_id - 1


@OPTIMIZER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Different learning rates are set for different layers of backbone.

    Note: Currently, this optimizer constructor is built for ViT and Swin.
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   optimizer_cfg: Dict,
                   prefix: str = '',
                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            optimizer_cfg (Dict): The configuration of optimizer.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        logger = get_root_logger()

        model_type = optimizer_cfg.pop('model_type', None)
        # model_type should not be None
        assert model_type is not None, 'When using layer-wise learning rate \
            decay, model_type should not be None.'

        # currently, we only support layer-wise learning rate decay for vit
        # and swin.
        assert model_type in ['vit', 'swin'], f'Currently, we do not support \
            layer-wise learning rate decay for {model_type}'

        if model_type == 'vit':
            num_layers = len(module.backbone.layers) + 2
        elif model_type == 'swin':
            num_layers = sum(module.backbone.depths) + 2

        weight_decay = self.base_wd
        # if layer_decay_rate is not provided, not decay
        decay_rate = optimizer_cfg.pop('layer_decay_rate', 1.0)
        parameter_groups = {}

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay

            if model_type == 'vit':
                layer_id = get_layer_id_for_vit(name, num_layers)
            elif model_type == 'swin':
                layer_id = get_layer_id_for_swin(name, num_layers,
                                                 module.backbone.depths)

            group_name = f'layer_{layer_id}_{group_name}'
            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)

        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())

    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
        """When paramwise option is None, also use layer wise learning rate
        decay."""
        if hasattr(model, 'module'):
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()

        # set param-wise lr and weight decay recursively
        params: List = []
        self.add_params(params, model, optimizer_cfg)
        optimizer_cfg['params'] = params

        return OPTIMIZERS.build(optimizer_cfg)
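For reference, here is a minimal usage sketch of the new constructor, modeled on the unit test added later in this commit; the toy backbone and the AdamW hyperparameters below are illustrative assumptions, not part of the diff:

    import torch
    from torch import nn

    from mmselfsup.core import LearningRateDecayOptimizerConstructor


    class ToyViT(nn.Module):
        # Minimal stand-in: for model_type='vit' the constructor only needs
        # model.backbone.layers and parameter names under the backbone.* prefix.
        def __init__(self):
            super().__init__()
            self.backbone = nn.Module()
            self.backbone.cls_token = nn.Parameter(torch.ones(1))
            self.backbone.patch_embed = nn.Parameter(torch.ones(1))
            self.backbone.layers = nn.ModuleList(
                [nn.Conv2d(3, 3, 1) for _ in range(2)])
            self.head = nn.Linear(1, 1)


    optimizer_cfg = dict(
        type='AdamW',
        lr=1e-3,
        betas=(0.9, 0.999),
        weight_decay=0.05,
        model_type='vit',       # 'vit' or 'swin'; anything else fails the assert
        layer_decay_rate=0.65)  # lr_scale of layer i is 0.65 ** (num_layers - i - 1)

    constructor = LearningRateDecayOptimizerConstructor(optimizer_cfg=optimizer_cfg)
    optimizer = constructor(ToyViT())  # torch.optim.AdamW with per-group lr/weight_decay

With a decay rate below 1.0, parameters outside the backbone (the head here) keep lr_scale == 1.0, deeper backbone layers are scaled down less than earlier ones, and the embedding/cls_token group receives the smallest scale; bias and other 1-D parameters additionally land in no_decay groups with weight_decay = 0.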
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner.optimizer.builder import OPTIMIZERS
from torch.optim import *  # noqa: F401,F403
from torch.optim.optimizer import Optimizer, required
from torch.optim.optimizer import Optimizer

from mmselfsup.registry import OPTIMIZERS


@OPTIMIZERS.register_module()
@@ -37,14 +37,14 @@ class LARS(Optimizer):

    def __init__(self,
                 params,
                 lr=required,
                 lr=float,
                 momentum=0,
                 weight_decay=0,
                 dampening=0,
                 eta=0.001,
                 nesterov=False,
                 eps=1e-8):
        if lr is not required and lr < 0.0:
        if not isinstance(lr, float) and lr < 0.0:
            raise ValueError(f'Invalid learning rate: {lr}')
        if momentum < 0.0:
            raise ValueError(f'Invalid momentum value: {momentum}')
@@ -1,158 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re

import torch.distributed as dist
from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg
from mmengine.logging import print_log


@OPTIMIZER_BUILDERS.register_module()
class TransformerFinetuneConstructor:
    """Rewrote default constructor for optimizers.

    By default each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    In addition, we provide two optional parameters, ``model_type`` and
    ``layer_decay`` to set the commonly used layer-wise learning rate decay
    schedule. Currently, we only support layer-wise learning rate schedule
    for swin and vit.

    Args:
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are
                - `type`: class name of the optimizer.
            Optional fields are
                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, model_type, layer_decay, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
            Defaults to None.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001, model_type='vit')
        >>> paramwise_cfg = dict('bias': dict(weight_decay=0., \
                                 lars_exclude=True))
        >>> optim_builder = TransformerFinetuneConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.layer_decay = self.optimizer_cfg.pop('layer_decay', 0.0)
        # Choose which type of layer-wise lr decay to use. Currently, we only
        # support ViT and Swin.
        self.model_type = self.optimizer_cfg.pop('model_type', None)

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        paramwise_options = self.paramwise_cfg

        # generate layer-wise lr decay
        if self.layer_decay > 0:
            if self.model_type == 'swin':
                self._generate_swin_layer_wise_lr_decay(
                    model, paramwise_options)
            elif self.model_type == 'vit':
                self._generate_vit_layer_wise_lr_decay(model,
                                                       paramwise_options)
            else:
                raise NotImplementedError(f'Currently, we do not support \
                    layer-wise lr decay for {self.model_type}')

        # if no paramwise option is specified, just use the global setting
        if paramwise_options is None:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
        else:
            assert isinstance(paramwise_options, dict)
            params = []
            for name, param in model.named_parameters():
                param_group = {'params': [param]}
                if not param.requires_grad:
                    params.append(param_group)
                    continue

                for regexp, options in paramwise_options.items():
                    if re.search(regexp, name):
                        for key, value in options.items():
                            if key.endswith('_mult'):  # is a multiplier
                                key = key[:-5]
                                assert key in optimizer_cfg, \
                                    f'{key} not in optimizer_cfg'
                                value = optimizer_cfg[key] * value
                            param_group[key] = value
                            if not dist.is_initialized() or \
                                    dist.get_rank() == 0:
                                print_log(f'paramwise_options -- \
                                    {name}: {key}={value}')

                # otherwise use the global settings
                params.append(param_group)

            optimizer_cfg['params'] = params
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

    def _generate_swin_layer_wise_lr_decay(self, model, paramwise_options):
        """Generate layer-wise learning rate decay for Swin Transformer."""
        num_layers = sum(model.backbone.depths) + 2
        layer_scales = list(self.layer_decay**i
                            for i in reversed(range(num_layers)))

        for name, _ in model.named_parameters():

            layer_id = self._get_swin_layer(name, num_layers,
                                            model.backbone.depths)
            paramwise_options[name] = dict(lr_mult=layer_scales[layer_id])

    def _get_swin_layer(self, name, num_layers, depths):
        if 'mask_token' in name:
            return 0
        elif 'patch_embed' in name:
            return 0
        elif name.startswith('backbone.stages'):
            layer_id = int(name.split('.')[2])
            block_id = name.split('.')[4]
            if block_id == 'reduction' or block_id == 'norm':
                return sum(depths[:layer_id + 1])
            layer_id = sum(depths[:layer_id]) + int(block_id)
            return layer_id + 1
        else:
            return num_layers - 1

    def _generate_vit_layer_wise_lr_decay(self, model, paramwise_options):
        """Generate layer-wise learning rate decay for Vision Transformer."""
        num_layers = len(model.backbone.layers) + 1
        layer_scales = list(self.layer_decay**(num_layers - i)
                            for i in range(num_layers + 1))

        if 'pos_embed' in paramwise_options:
            paramwise_options['pos_embed'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['pos_embed'] = dict(lr_mult=layer_scales[0])

        if 'cls_token' in paramwise_options:
            paramwise_options['cls_token'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['cls_token'] = dict(lr_mult=layer_scales[0])

        if 'patch_embed' in paramwise_options:
            paramwise_options['patch_embed'].update(
                dict(lr_mult=layer_scales[0]))
        else:
            paramwise_options['patch_embed'] = dict(lr_mult=layer_scales[0])

        for i in range(num_layers - 1):
            paramwise_options[f'backbone\\.layers\\.{i}\\.'] = dict(
                lr_mult=layer_scales[i + 1])
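For comparison, a short sketch of what the removed TransformerFinetuneConstructor effectively generated: layer-wise decay was expanded into regex-keyed lr_mult entries inside paramwise_options rather than explicit parameter groups. The numbers below are illustrative, assuming a 2-layer ViT backbone and layer_decay=0.9:

    # num_layers = len(backbone.layers) + 1 = 3, so the scales are
    # 0.9**3 (embeddings), 0.9**2 (layer 0) and 0.9**1 (layer 1).
    paramwise_options = {
        'pos_embed': dict(lr_mult=0.9 ** 3),
        'cls_token': dict(lr_mult=0.9 ** 3),
        'patch_embed': dict(lr_mult=0.9 ** 3),
        'backbone\\.layers\\.0\\.': dict(lr_mult=0.9 ** 2),
        'backbone\\.layers\\.1\\.': dict(lr_mult=0.9 ** 1),
    }

The new LearningRateDecayOptimizerConstructor replaces this regex expansion with parameter groups that carry lr, lr_scale and weight_decay directly.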
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcls.models import SwinTransformer
from torch import nn

from mmselfsup.core import LearningRateDecayOptimizerConstructor


class ToyViTBackbone(nn.Module):

    def __init__(self):
        super().__init__()
        self.cls_token = nn.Parameter(torch.ones(1))
        self.patch_embed = nn.Parameter(torch.ones(1))
        self.layers = nn.ModuleList()
        for _ in range(2):
            layer = nn.Conv2d(3, 3, 1)
            self.layers.append(layer)


class ToyViT(nn.Module):

    def __init__(self):
        super().__init__()
        # add some variables to meet unit test coverate rate
        self.backbone = ToyViTBackbone()
        self.head = nn.Linear(1, 1)


class ToySwin(nn.Module):

    def __init__(self):
        super().__init__()
        # add some variables to meet unit test coverate rate
        self.backbone = SwinTransformer()
        self.head = nn.Linear(1, 1)


expected_layer_wise_wd_lr_vit = [{
    'weight_decay': 0.0,
    'lr_scale': 8
}, {
    'weight_decay': 0.05,
    'lr_scale': 4
}, {
    'weight_decay': 0.0,
    'lr_scale': 4
}, {
    'weight_decay': 0.05,
    'lr_scale': 2
}, {
    'weight_decay': 0.0,
    'lr_scale': 2
}, {
    'weight_decay': 0.05,
    'lr_scale': 1
}, {
    'weight_decay': 0.0,
    'lr_scale': 1
}]

base_lr = 1.0
base_wd = 0.05


def check_optimizer_lr_wd(optimizer, gt_lr_wd):
    assert isinstance(optimizer, torch.optim.AdamW)
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['weight_decay'] == base_wd
    param_groups = optimizer.param_groups
    assert len(param_groups) == len(gt_lr_wd)
    for i, param_dict in enumerate(param_groups):
        assert param_dict['weight_decay'] == gt_lr_wd[i]['weight_decay']
        assert param_dict['lr_scale'] == gt_lr_wd[i]['lr_scale']
        assert param_dict['lr_scale'] == param_dict['lr']


def test_learning_rate_decay_optimizer_constructor():
    model = ToyViT()
    optimizer_config = dict(
        type='AdamW',
        lr=base_lr,
        betas=(0.9, 0.999),
        weight_decay=base_wd,
        model_type='vit',
        layer_decay_rate=2.0)

    # test when model_type is None
    with pytest.raises(AssertionError):
        optimizer_constructor = LearningRateDecayOptimizerConstructor(
            optimizer_cfg=optimizer_config)
        optimizer_config['model_type'] = None
        optimizer = optimizer_constructor(model)

    # test when model_type is invalid
    with pytest.raises(AssertionError):
        optimizer_constructor = LearningRateDecayOptimizerConstructor(
            optimizer_cfg=optimizer_config)
        optimizer_config['model_type'] = 'invalid'
        optimizer = optimizer_constructor(model)

    # test vit
    optimizer_constructor = LearningRateDecayOptimizerConstructor(
        optimizer_cfg=optimizer_config)
    optimizer_config['model_type'] = 'vit'
    optimizer = optimizer_constructor(model)
    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_vit)

    # test swin
    model = ToySwin()
    optimizer_constructor = LearningRateDecayOptimizerConstructor(
        optimizer_cfg=optimizer_config)
    optimizer_config['model_type'] = 'swin'
    optimizer = optimizer_constructor(model)
    assert optimizer.param_groups[-1]['lr_scale'] == 1.0
    assert optimizer.param_groups[-3]['lr_scale'] == 2.0
    assert optimizer.param_groups[-5]['lr_scale'] == 4.0