[Refactor]: Rename optimizer wrapper constructor
parent: 301ae8ded2
commit: 62c909d3d2
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .layer_decay_optimizer_constructor import \
-    LearningRateDecayOptimizerConstructor
+from .layer_decay_optim_wrapper_constructor import \
+    LearningRateDecayOptimWrapperConstructor
 from .optimizers import LARS

-__all__ = ['LARS', 'LearningRateDecayOptimizerConstructor']
+__all__ = ['LARS', 'LearningRateDecayOptimWrapperConstructor']
@@ -4,10 +4,10 @@ from typing import Dict, List, Optional, Union

 import torch
 from mmengine.dist import get_dist_info
-from mmengine.optim import DefaultOptimizerConstructor
+from mmengine.optim import DefaultOptimWrapperConstructor
 from torch import nn

-from mmselfsup.registry import OPTIMIZER_CONSTRUCTORS, OPTIMIZERS
+from mmselfsup.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
 from mmselfsup.utils import get_root_logger

@@ -59,8 +59,8 @@ def get_layer_id_for_swin(var_name: str, max_layer_id: int,
     return max_layer_id - 1


-@OPTIMIZER_CONSTRUCTORS.register_module()
-class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.

     Note: Currently, this optimizer constructor is built for ViT and Swin.
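For orientation, here is a minimal sketch of how the renamed constructor might be selected from a training config after this change. The field values and the top-level `constructor` key are assumptions for illustration; only the `type='OptimWrapper'` / nested `optimizer` layout and the `model_type` / `layer_decay_rate` keys are taken from the updated test further below.

# Hypothetical config snippet (assumed usage, values for illustration only).
optim_wrapper = dict(
    type='OptimWrapper',
    constructor='LearningRateDecayOptimWrapperConstructor',
    optimizer=dict(
        type='AdamW',
        lr=1.5e-4,
        betas=(0.9, 0.999),
        weight_decay=0.05,
        model_type='vit',        # consumed by the constructor, not by AdamW
        layer_decay_rate=0.65))  # per-layer decay factor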
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .registry import (DATASETS, HOOKS, MODELS, OPTIMIZER_CONSTRUCTORS,
+from .registry import (DATASETS, HOOKS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
                        OPTIMIZERS, TRANSFORMS)

 __all__ = [
     'MODELS', 'DATASETS', 'TRANSFORMS', 'HOOKS', 'OPTIMIZERS',
-    'OPTIMIZER_CONSTRUCTORS'
+    'OPTIM_WRAPPER_CONSTRUCTORS'
 ]
@@ -7,7 +7,7 @@ from mmengine.registry import METRICS as MMENGINE_METRICS
 from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
 from mmengine.registry import MODELS as MMENGINE_MODELS
 from mmengine.registry import \
-    OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
+    OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
 from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
 from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
 from mmengine.registry import \
@@ -47,8 +47,8 @@ WEIGHT_INITIALIZERS = Registry(
 # mangage all kinds of optimizers like `SGD` and `Adam`
 OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
 # manage constructors that customize the optimization hyperparameters.
-OPTIMIZER_CONSTRUCTORS = Registry(
-    'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
+OPTIM_WRAPPER_CONSTRUCTORS = Registry(
+    'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
 # mangage all kinds of parameter schedulers like `MultiStepLR`
 PARAM_SCHEDULERS = Registry(
     'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
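As a rough sketch of what the registry rename means for downstream code (assumed usage, not part of this diff), the constructor can now be built through the new `OPTIM_WRAPPER_CONSTRUCTORS` child registry, which falls back to MMEngine's registry of the same name for anything mmselfsup does not register itself:

# Assumed usage sketch; the config values here are illustrative only.
from mmselfsup.registry import OPTIM_WRAPPER_CONSTRUCTORS

constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
    dict(
        type='LearningRateDecayOptimWrapperConstructor',
        optim_wrapper_cfg=dict(
            type='OptimWrapper',
            optimizer=dict(
                type='AdamW',
                lr=1.5e-4,
                weight_decay=0.05,
                model_type='vit',
                layer_decay_rate=0.65))))
# `model` is assumed to be an nn.Module with a ViT backbone,
# e.g. the ToyViT used in the updated test below.
optim_wrapper = constructor(model)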
@@ -4,7 +4,7 @@ import torch
 from mmcls.models import SwinTransformer
 from torch import nn

-from mmselfsup.core import LearningRateDecayOptimizerConstructor
+from mmselfsup.core import LearningRateDecayOptimWrapperConstructor


 class ToyViTBackbone(nn.Module):
@@ -76,43 +76,45 @@ def check_optimizer_lr_wd(optimizer, gt_lr_wd):
         assert param_dict['lr_scale'] == param_dict['lr']


-def test_learning_rate_decay_optimizer_constructor():
+def test_learning_rate_decay_optimizer_wrapper_constructor():
     model = ToyViT()
-    optimizer_config = dict(
-        type='AdamW',
-        lr=base_lr,
-        betas=(0.9, 0.999),
-        weight_decay=base_wd,
-        model_type='vit',
-        layer_decay_rate=2.0)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=dict(
+            type='AdamW',
+            lr=base_lr,
+            betas=(0.9, 0.999),
+            weight_decay=base_wd,
+            model_type='vit',
+            layer_decay_rate=2.0))

     # test when model_type is None
     with pytest.raises(AssertionError):
-        optimizer_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg=optimizer_config)
-        optimizer_config['model_type'] = None
-        optimizer = optimizer_constructor(model)
+        optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(  # noqa
+            optim_wrapper_cfg=optim_wrapper_cfg)
+        optim_wrapper_cfg['optimizer']['model_type'] = None
+        optimizer_wrapper = optimizer_wrapper_constructor(model)

     # test when model_type is invalid
     with pytest.raises(AssertionError):
-        optimizer_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg=optimizer_config)
-        optimizer_config['model_type'] = 'invalid'
-        optimizer = optimizer_constructor(model)
+        optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(  # noqa
+            optim_wrapper_cfg=optim_wrapper_cfg)
+        optim_wrapper_cfg['optimizer']['model_type'] = 'invalid'
+        optimizer_wrapper = optimizer_wrapper_constructor(model)

     # test vit
-    optimizer_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg=optimizer_config)
-    optimizer_config['model_type'] = 'vit'
-    optimizer = optimizer_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_vit)
+    optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
+        optim_wrapper_cfg=optim_wrapper_cfg)
+    optim_wrapper_cfg['optimizer']['model_type'] = 'vit'
+    optimizer_wrapper = optimizer_wrapper_constructor(model)
+    check_optimizer_lr_wd(optimizer_wrapper, expected_layer_wise_wd_lr_vit)

     # test swin
     model = ToySwin()
-    optimizer_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg=optimizer_config)
-    optimizer_config['model_type'] = 'swin'
-    optimizer = optimizer_constructor(model)
-    assert optimizer.param_groups[-1]['lr_scale'] == 1.0
-    assert optimizer.param_groups[-3]['lr_scale'] == 2.0
-    assert optimizer.param_groups[-5]['lr_scale'] == 4.0
+    optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor(
+        optim_wrapper_cfg=optim_wrapper_cfg)
+    optim_wrapper_cfg['optimizer']['model_type'] = 'swin'
+    optimizer_wrapper = optimizer_wrapper_constructor(model)
+    assert optimizer_wrapper.param_groups[-1]['lr_scale'] == 1.0
+    assert optimizer_wrapper.param_groups[-3]['lr_scale'] == 2.0
+    assert optimizer_wrapper.param_groups[-5]['lr_scale'] == 4.0
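The Swin assertions above amount to a geometric layer-wise schedule: with `layer_decay_rate=2.0`, each layer group closer to the stem gets double the `lr_scale` of the group after it, with the head fixed at 1.0. A tiny sketch of that assumed relation (the constructor's exact layer indexing may differ):

# Assumed relation behind the asserted values 1.0, 2.0, 4.0.
layer_decay_rate = 2.0
scales = [layer_decay_rate**depth for depth in range(3)]  # depth 0 = last group
assert scales == [1.0, 2.0, 4.0]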