[Feature] Implement layer-wise learning rate decay optimizer constructor. (#1399)

* [Feature] Implement layer-wise learning rate decay optimizer constructor.

* Use num_layers instead of max_depth to avoid confusion

* Add unit tests

* Update docstring

* Update log info

* Update LearningRateDecay configs

---------

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Ma Zerun 2023-03-07 17:30:39 +08:00 committed by GitHub
parent 827be6e22d
commit 274a67223e
23 changed files with 452 additions and 79 deletions
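Every config change in this diff follows the same pattern: the layer-wise decay options (`model_type`, `layer_decay_rate`) move out of the optimizer dict into `paramwise_cfg`, and the new constructor is selected explicitly. A condensed before/after sketch, with illustrative values taken from one of the BEiT configs below:

```python
# Before: layer-wise decay options were packed into the optimizer dict.
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999),
        model_type='vit', layer_decay_rate=0.65))

# After: the optimizer dict keeps only optimizer arguments; layer-wise
# decay is configured through the dedicated constructor and paramwise_cfg.
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
    constructor='LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(layer_decay_rate=0.65))
```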


@@ -84,16 +84,11 @@ test_dataloader = val_dataloader
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit',
layer_decay_rate=0.65),
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
_delete_=True,
layer_decay_rate=0.65,
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),


@@ -77,17 +77,12 @@ test_dataloader = val_dataloader
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=5e-4,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit',
# 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
layer_decay_rate=0.65),
type='AdamW', lr=5e-4, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
_delete_=True,
# 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs
layer_decay_rate=0.65,
custom_keys={
# the following configurations are designed for BEiT
'.ln': dict(decay_mult=0.0),


@@ -92,14 +92,10 @@ model = dict(
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=8e-3,
betas=(0.9, 0.999),
weight_decay=0.05,
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
type='AdamW', lr=8e-3, betas=(0.9, 0.999), weight_decay=0.05),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -76,15 +76,10 @@ model = dict(
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-4,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
type='AdamW', lr=4e-4, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -75,15 +75,10 @@ model = dict(
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=2e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -78,15 +78,10 @@ model = dict(
# learning rate and layer decay rate are set to 0.004 and 0.75 respectively
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.75), # layer-wise lr decay factor
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -77,15 +77,10 @@ model = dict(
# learning rate and layer decay rate are set to 0.004 and 0.75 respectively
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.75), # layer-wise lr decay factor
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -77,15 +77,10 @@ model = dict(
# learning rate and layer decay rate are set to 0.004 and 0.75 respectively
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.75), # layer-wise lr decay factor
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -76,15 +76,10 @@ model = dict(
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=8e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
type='AdamW', lr=8e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -76,15 +76,10 @@ model = dict(
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-4,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
model_type='vit', # layer-wise lr decay type
layer_decay_rate=0.65), # layer-wise lr decay factor
type='AdamW', lr=4e-4, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -92,12 +92,11 @@ optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=5e-4 * (8 * 128 / 256),
model_type='mixmim',
layer_decay_rate=0.7,
betas=(0.9, 0.999),
weight_decay=0.05),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.7,
custom_keys={
'.ln': dict(decay_mult=0.0), # do not decay on ln and bias
'.bias': dict(decay_mult=0.0)


@@ -14,11 +14,11 @@ model = dict(
# optimizer settings
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9),
optimizer=dict(type='AdamW', lr=5e-3),
clip_grad=dict(max_norm=5.0),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.9,
custom_keys={
'.norm': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -57,11 +57,11 @@ model = dict(
# optimizer settings
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9),
optimizer=dict(type='AdamW', lr=5e-3),
clip_grad=dict(max_norm=5.0),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.9,
custom_keys={
'.norm': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -60,11 +60,11 @@ model = dict(
# optimizer settings
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=dict(
type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.7),
optimizer=dict(type='AdamW', lr=5e-3),
clip_grad=dict(max_norm=5.0),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.7,
custom_keys={
'.norm': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),


@@ -47,3 +47,5 @@ Optimizers
:nosignatures:
Lamb
LARS
LearningRateDecayOptimWrapperConstructor


@@ -2,5 +2,7 @@
from .adan_t import Adan
from .lamb import Lamb
from .lars import LARS
from .layer_decay_optim_wrapper_constructor import \
LearningRateDecayOptimWrapperConstructor
__all__ = ['Lamb', 'Adan', 'LARS']
__all__ = ['Lamb', 'Adan', 'LARS', 'LearningRateDecayOptimWrapperConstructor']


@@ -0,0 +1,166 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, List, Optional
from mmengine.logging import MMLogger
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from torch import nn
from torch.nn import GroupNorm, LayerNorm
from mmpretrain.registry import OPTIM_WRAPPER_CONSTRUCTORS
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
By default, each parameter share the same optimizer settings, and we
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
It is a dict and may contain the following fields:
- ``layer_decay_rate`` (float): The learning rate of a parameter will
multiply it by multiple times according to the layer depth of the
parameter. Usually, it's less than 1, so that the earlier layers will
have a lower learning rate. Defaults to 1.
- ``bias_decay_mult`` (float): It will be multiplied to the weight
decay for all bias parameters (except for those in normalization layers).
- ``norm_decay_mult`` (float): It will be multiplied to the weight
decay for all weight and bias parameters of normalization layers.
- ``flat_decay_mult`` (float): It will be multiplied to the weight
decay for all one-dimensional parameters
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
one of the keys in ``custom_keys`` is a substring of the name of one
parameter, then the setting of the parameter will be specified by
``custom_keys[key]`` and other setting like ``bias_decay_mult`` will be
ignored. It should be a dict and may contain fields ``decay_mult``.
(The ``lr_mult`` is disabled in this constructor).
Example:
In the config file, you can use this constructor as below:
.. code:: python
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=4e-3,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75, # layer-wise lr decay factor
norm_decay_mult=0.,
flat_decay_mult=0.,
custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
"""
def add_params(self,
params: List[dict],
module: nn.Module,
prefix: str = '',
get_layer_depth: Optional[Callable] = None,
**kwargs) -> None:
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
Args:
params (List[dict]): A list of param groups; it will be modified
in place.
module (nn.Module): The module to be added.
prefix (str): The prefix of the module. Defaults to ''.
get_layer_depth (Callable, optional): The function to get the
layer-wise depth and the number of layers of a parameter. Defaults
to the ``get_layer_depth`` method of the root module.
"""
# get param-wise options
custom_keys = self.paramwise_cfg.get('custom_keys', {})
# sort keys alphabetically first, then by key length in descending order
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
logger = MMLogger.get_current_instance()
# The model should have a `get_layer_depth` method
if get_layer_depth is None and not hasattr(module, 'get_layer_depth'):
raise NotImplementedError('Layer-wise learning rate decay needs'
f' the model {type(module)} to have a'
' `get_layer_depth` method.')
else:
get_layer_depth = get_layer_depth or module.get_layer_depth
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
decay_rate = self.paramwise_cfg.get('layer_decay_rate', 1.0)
# special rule for normalization layers
is_norm = isinstance(module,
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
for name, param in module.named_parameters(recurse=False):
param_group = {'params': [param]}
param_name = prefix + name
if not param.requires_grad:
continue
if self.base_wd is not None:
base_wd = self.base_wd
custom_key = next(
filter(lambda k: k in param_name, sorted_keys), None)
# custom parameters decay
if custom_key is not None:
custom_cfg = custom_keys[custom_key].copy()
decay_mult = custom_cfg.pop('decay_mult', 1.)
param_group['weight_decay'] = base_wd * decay_mult
# add custom settings to param_group
param_group.update(custom_cfg)
# norm decay
elif is_norm and norm_decay_mult is not None:
param_group['weight_decay'] = base_wd * norm_decay_mult
# bias decay
elif name == 'bias' and bias_decay_mult is not None:
param_group['weight_decay'] = base_wd * bias_decay_mult
# flatten parameters decay
elif param.ndim == 1 and flat_decay_mult is not None:
param_group['weight_decay'] = base_wd * flat_decay_mult
else:
param_group['weight_decay'] = base_wd
layer_id, max_id = get_layer_depth(param_name)
scale = decay_rate**(max_id - layer_id - 1)
param_group['lr'] = self.base_lr * scale
param_group['lr_scale'] = scale
param_group['layer_id'] = layer_id
param_group['param_name'] = param_name
params.append(param_group)
for child_name, child_mod in module.named_children():
child_prefix = f'{prefix}{child_name}.'
self.add_params(
params,
child_mod,
prefix=child_prefix,
get_layer_depth=get_layer_depth,
)
if prefix == '':
layer_params = defaultdict(list)
for param in params:
layer_params[param['layer_id']].append(param)
for layer_id, layer_params in layer_params.items():
lr_scale = layer_params[0]['lr_scale']
lr = layer_params[0]['lr']
msg = [
f'layer {layer_id} params '
f'(lr={lr:.3g}, lr_scale={lr_scale:.3g}):'
]
for param in layer_params:
msg.append(f'\t{param["param_name"]}: '
f'weight_decay={param["weight_decay"]:.3g}')
logger.debug('\n'.join(msg))
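For reference, the scaling rule applied in `add_params` above gives each parameter group `lr = base_lr * layer_decay_rate ** (num_layers - layer_depth - 1)`. A standalone sketch with illustrative numbers (not part of the diff):

```python
# Illustrative only: mirrors the scaling rule in `add_params` above.
base_lr = 4e-3
layer_decay_rate = 0.75
num_layers = 14  # e.g. a 12-layer ViT plus the stem and head buckets

for layer_depth in range(num_layers):
    scale = layer_decay_rate ** (num_layers - layer_depth - 1)
    print(f'depth {layer_depth:2d}: lr_scale={scale:.3f}, '
          f'lr={base_lr * scale:.2e}')
```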


@@ -489,3 +489,45 @@ class MixMIMTransformer(BaseBackbone):
x = torch.flatten(x, 1)
return (x, )
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``).
"""
num_layers = sum(self.depths) + 2
if not param_name.startswith(prefix):
# For subsequent module like neck and head
if param_name.startswith('neck'):
return num_layers - 2, num_layers
else:
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
stem_layers = ('patch_embed', 'absolute_pos_embed', 'pos_embed')
if any(stem in param_name for stem in stem_layers):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
block_id = param_name.split('.')[3]
if block_id in ('downsample', 'reduction', 'norm'):
layer_depth = sum(self.depths[:layer_id + 1])
else:
layer_depth = sum(self.depths[:layer_id]) + int(block_id) + 1
else:
layer_depth = num_layers - 2
return layer_depth, num_layers


@@ -546,3 +546,40 @@ class SwinTransformer(BaseBackbone):
# The index buffer need to be re-generated.
index_buffer = ckpt_key.replace('bias_table', 'index')
del state_dict[index_buffer]
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``).
"""
num_layers = sum(self.depths) + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('stages'):
stage_id = int(param_name.split('.')[1])
block_id = param_name.split('.')[3]
if block_id in ('reduction', 'norm'):
layer_depth = sum(self.depths[:stage_id + 1])
else:
layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers


@@ -461,3 +461,38 @@ class VisionTransformer(BaseBackbone):
outs.append(out)
return tuple(outs)
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``).
"""
num_layers = self.num_layers + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name in ('cls_token', 'pos_embed'):
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
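To make the mapping above concrete, the snippet below borrows `VisionTransformer.get_layer_depth` on a minimal stand-in module (the same trick the new unit test uses); the parameter names are hypothetical:

```python
from torch import nn

from mmpretrain.models import VisionTransformer


class TinyViTBackbone(nn.Module):
    # Borrow the unbound method, as the unit test in this PR does.
    get_layer_depth = VisionTransformer.get_layer_depth

    def __init__(self):
        super().__init__()
        self.num_layers = 12  # pretend there are 12 transformer layers


backbone = TinyViTBackbone()
# num_layers = 12 + 2 = 14 depth buckets in total.
print(backbone.get_layer_depth('cls_token'))                 # (0, 14)
print(backbone.get_layer_depth('layers.0.attn.qkv.weight'))  # (1, 14)
print(backbone.get_layer_depth('layers.11.ffn.weight'))      # (12, 14)
print(backbone.get_layer_depth('ln1.weight'))                # (13, 14)
```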


@@ -241,3 +241,19 @@ class ImageClassifier(BaseClassifier):
"""
feats = self.extract_feat(inputs)
return self.head.predict(feats, data_samples, **kwargs)
def get_layer_depth(self, param_name: str):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
"""
if hasattr(self.backbone, 'get_layer_depth'):
return self.backbone.get_layer_depth(param_name, 'backbone.')
else:
raise NotImplementedError(
f"The babckone {type(self.backbone)} doesn't "
'support `get_layer_depth` by now.')


@@ -161,3 +161,19 @@ class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta):
dict[str, Tensor]: A dictionary of loss components.
"""
raise NotImplementedError
def get_layer_depth(self, param_name: str):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
Returns:
Tuple[int, int]: The layer-wise depth and the number of layers.
"""
if hasattr(self.backbone, 'get_layer_depth'):
return self.backbone.get_layer_depth(param_name, 'backbone.')
else:
raise NotImplementedError(
f"The babckone {type(self.backbone)} doesn't "
'support `get_layer_depth` by now.')


@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from torch import nn
from mmpretrain.engine import LearningRateDecayOptimWrapperConstructor
from mmpretrain.models import ImageClassifier, VisionTransformer
class ToyViTBackbone(nn.Module):
get_layer_depth = VisionTransformer.get_layer_depth
def __init__(self, num_layers=2):
super().__init__()
self.cls_token = nn.Parameter(torch.ones(1))
self.pos_embed = nn.Parameter(torch.ones(1))
self.num_layers = num_layers
self.layers = nn.ModuleList()
for _ in range(num_layers):
layer = nn.Conv2d(3, 3, 1)
self.layers.append(layer)
class ToyViT(nn.Module):
get_layer_depth = ImageClassifier.get_layer_depth
def __init__(self):
super().__init__()
# add some variables to improve unit test coverage
self.backbone = ToyViTBackbone()
self.head = nn.Linear(1, 1)
class TestLearningRateDecayOptimWrapperConstructor(TestCase):
base_lr = 1.0
base_wd = 0.05
def test_add_params(self):
model = ToyViT()
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=self.base_lr,
betas=(0.9, 0.999),
weight_decay=self.base_wd))
paramwise_cfg = dict(
layer_decay_rate=2.0,
bias_decay_mult=0.,
custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
})
constructor = LearningRateDecayOptimWrapperConstructor(
optim_wrapper_cfg=optim_wrapper_cfg,
paramwise_cfg=paramwise_cfg,
)
optimizer_wrapper = constructor(model)
expected_groups = [{
'weight_decay': 0.0,
'lr': 8 * self.base_lr,
'param_name': 'backbone.cls_token',
}, {
'weight_decay': 0.0,
'lr': 8 * self.base_lr,
'param_name': 'backbone.pos_embed',
}, {
'weight_decay': self.base_wd,
'lr': 4 * self.base_lr,
'param_name': 'backbone.layers.0.weight',
}, {
'weight_decay': 0.0,
'lr': 4 * self.base_lr,
'param_name': 'backbone.layers.0.bias',
}, {
'weight_decay': self.base_wd,
'lr': 2 * self.base_lr,
'param_name': 'backbone.layers.1.weight',
}, {
'weight_decay': 0.0,
'lr': 2 * self.base_lr,
'param_name': 'backbone.layers.1.bias',
}, {
'weight_decay': self.base_wd,
'lr': 1 * self.base_lr,
'param_name': 'head.weight',
}, {
'weight_decay': 0.0,
'lr': 1 * self.base_lr,
'param_name': 'head.bias',
}]
self.assertIsInstance(optimizer_wrapper.optimizer, torch.optim.AdamW)
self.assertEqual(optimizer_wrapper.optimizer.defaults['lr'],
self.base_lr)
self.assertEqual(optimizer_wrapper.optimizer.defaults['weight_decay'],
self.base_wd)
param_groups = optimizer_wrapper.optimizer.param_groups
self.assertEqual(len(param_groups), len(expected_groups))
for expect, param in zip(expected_groups, param_groups):
self.assertEqual(param['param_name'], expect['param_name'])
self.assertEqual(param['lr'], expect['lr'])
self.assertEqual(param['weight_decay'], expect['weight_decay'])
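The expected learning rates above follow directly from the toy setup: `ToyViTBackbone` has 2 layers, so `get_layer_depth` reports `num_layers = 2 + 2 = 4`, and `layer_decay_rate=2.0` is deliberately larger than 1 so the scaling is easy to spot. A minimal sketch of the arithmetic:

```python
# Illustrative arithmetic behind `expected_groups` above.
layer_decay_rate = 2.0
num_layers = 4  # ToyViTBackbone: 2 layers + stem and head buckets

for name, depth in [('cls_token / pos_embed', 0), ('layers.0.*', 1),
                    ('layers.1.*', 2), ('head.*', 3)]:
    scale = layer_decay_rate ** (num_layers - depth - 1)
    print(f'{name}: lr = {scale:g} * base_lr')
# -> 8, 4, 2 and 1 times base_lr, matching the `lr` values above.
```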