[Refactor]: Rename optimizer wrapper constructor

pull/352/head
YuanLiuuuuuu 2022-06-02 13:32:48 +00:00 committed by fangyixiao18
parent 301ae8ded2
commit 62c909d3d2
5 changed files with 43 additions and 41 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
from .layer_decay_optim_wrapper_constructor import \
LearningRateDecayOptimWrapperConstructor
from .optimizers import LARS
__all__ = ['LARS', 'LearningRateDecayOptimizerConstructor']
__all__ = ['LARS', 'LearningRateDecayOptimWrapperConstructor']

View File

@ -4,10 +4,10 @@ from typing import Dict, List, Optional, Union
import torch
from mmengine.dist import get_dist_info
from mmengine.optim import DefaultOptimizerConstructor
from mmengine.optim import DefaultOptimWrapperConstructor
from torch import nn
from mmselfsup.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
from mmselfsup.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from mmselfsup.utils import get_root_logger
@ -59,8 +59,8 @@ def get_layer_id_for_swin(var_name: str, max_layer_id: int,
return max_layer_id - 1
@OPTIMIZER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
Note: Currently, this optimizer constructor is built for ViT and Swin.

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATASETS, HOOKS, MODELS, OPTIMIZER_CONSTRUCTORS,
from .registry import (DATASETS, HOOKS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
OPTIMIZERS, TRANSFORMS)
__all__ = [
'MODELS', 'DATASETS', 'TRANSFORMS', 'HOOKS', 'OPTIMIZERS',
'OPTIMIZER_CONSTRUCTORS'
'OPTIM_WRAPPER_CONSTRUCTORS'
]

View File

@ -7,7 +7,7 @@ from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import \
OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
from mmengine.registry import \
@ -47,8 +47,8 @@ WEIGHT_INITIALIZERS = Registry(
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage constructors that customize the optimization hyperparameters.
OPTIMIZER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)

View File

@ -4,7 +4,7 @@ import torch
from mmcls.models import SwinTransformer
from torch import nn
from mmselfsup.core import LearningRateDecayOptimizerConstructor
from mmselfsup.core import LearningRateDecayOptimWrapperConstructor
class ToyViTBackbone(nn.Module):
@ -76,43 +76,45 @@ def check_optimizer_lr_wd(optimizer, gt_lr_wd):
assert param_dict['lr_scale'] == param_dict['lr']
def test_learning_rate_decay_optimizer_constructor():
def test_learning_rate_decay_optimizer_wrapper_constructor():
model = ToyViT()
optimizer_config = dict(
type='AdamW',
lr=base_lr,
betas=(0.9, 0.999),
weight_decay=base_wd,
model_type='vit',
layer_decay_rate=2.0)
optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=base_lr,
betas=(0.9, 0.999),
weight_decay=base_wd,
model_type='vit',
layer_decay_rate=2.0))
# test when model_type is None
with pytest.raises(AssertionError):
optimizer_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg=optimizer_config)
optimizer_config['model_type'] = None
optimizer = optimizer_constructor(model)
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = None
optimizer_wrapper = optimizer_wrapper_constructor(model)
# test when model_type is invalid
with pytest.raises(AssertionError):
optimizer_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg=optimizer_config)
optimizer_config['model_type'] = 'invalid'
optimizer = optimizer_constructor(model)
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'invalid'
optimizer_wrapper = optimizer_wrapper_constructor(model)
# test vit
optimizer_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg=optimizer_config)
optimizer_config['model_type'] = 'vit'
optimizer = optimizer_constructor(model)
check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_vit)
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'vit'
optimizer_wrapper = optimizer_wrapper_constructor(model)
check_optimizer_lr_wd(optimizer_wrapper, expected_layer_wise_wd_lr_vit)
# test swin
model = ToySwin()
optimizer_constructor = LearningRateDecayOptimizerConstructor(
optimizer_cfg=optimizer_config)
optimizer_config['model_type'] = 'swin'
optimizer = optimizer_constructor(model)
assert optimizer.param_groups[-1]['lr_scale'] == 1.0
assert optimizer.param_groups[-3]['lr_scale'] == 2.0
assert optimizer.param_groups[-5]['lr_scale'] == 4.0
optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
optim_wrapper_cfg=optim_wrapper_cfg)
optim_wrapper_cfg['optimizer']['model_type'] = 'swin'
optimizer_wrapper = optimizer_wrapper_constructor(model)
assert optimizer_wrapper.param_groups[-1]['lr_scale'] == 1.0
assert optimizer_wrapper.param_groups[-3]['lr_scale'] == 2.0
assert optimizer_wrapper.param_groups[-5]['lr_scale'] == 4.0