Merge branch 'zhengmiao/fix-optim-constructor' into 'refactor_dev'

[Fix] Optimizer -> OptimWrapper

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!25
pull/1801/head
zhengmiao 2022-06-02 05:19:37 +00:00
commit 8c4e35304f
11 changed files with 86 additions and 98 deletions

View File

@@ -61,7 +61,7 @@ from .cocktail_optimizer import CocktailOptimizer
 @OPTIMIZER_BUILDERS.register_module
 class CocktailOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
    def __call__(self, model):

View File

@@ -91,7 +91,7 @@ from .my_optimizer import MyOptimizer
 @OPTIMIZER_BUILDERS.register_module()
 class MyOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
    def __call__(self, model):

View File

@@ -59,7 +59,7 @@ from .cocktail_optimizer import CocktailOptimizer
 @OPTIMIZER_BUILDERS.register_module
 class CocktailOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
    def __call__(self, model):

View File

@@ -93,7 +93,7 @@ from .my_optimizer import MyOptimizer
 @OPTIMIZER_BUILDERS.register_module()
 class MyOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
    def __call__(self, model):

View File

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import build_optimizer, build_optimizer_constructor
+from .builder import build_optimizer
 from .data_structures import *  # noqa: F401, F403
 from .evaluation import *  # noqa: F401, F403
 from .optimizers import *  # noqa: F401, F403
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
 
-__all__ = ['build_optimizer', 'build_optimizer_constructor']
+__all__ = ['build_optimizer']

View File

@@ -1,27 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
-
-
-def build_optimizer_constructor(cfg):
-    constructor_type = cfg.get('type')
-    if constructor_type in OPTIMIZER_CONSTRUCTORS:
-        return OPTIMIZER_CONSTRUCTORS.build(cfg)
-    else:
-        raise KeyError(f'{constructor_type} is not registered '
-                       'in the optimizer builder registry.')
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
 
 
 def build_optimizer(model, cfg):
-    optimizer_cfg = copy.deepcopy(cfg)
-    constructor_type = optimizer_cfg.pop('constructor',
-                                         'DefaultOptimizerConstructor')
-    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
-    optim_constructor = build_optimizer_constructor(
+    optim_wrapper_cfg = copy.deepcopy(cfg)
+    constructor_type = optim_wrapper_cfg.pop('constructor',
+                                             'DefaultOptimWrapperConstructor')
+    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
+    optim_wrapper_builder = OPTIM_WRAPPER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
-            optimizer_cfg=optimizer_cfg,
+            optim_wrapper_cfg=optim_wrapper_cfg,
             paramwise_cfg=paramwise_cfg))
-    optimizer = optim_constructor(model)
-    return optimizer
+    optim_wrapper = optim_wrapper_builder(model)
+    return optim_wrapper
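Reviewer note: with this rewrite, `build_optimizer` consumes the wrapper-style config rather than a bare optimizer config — the torch optimizer moves under an `optimizer` key, and the optional `constructor`/`paramwise_cfg` keys select and parameterize the constructor. A minimal usage sketch mirroring the updated `test_build_optimizer` further down (the model and SGD hyperparameters are only illustrative):

import torch
import torch.nn as nn

from mmseg.core.builder import build_optimizer

# Any nn.Module works; a single layer keeps the sketch small.
model = nn.Conv2d(3, 8, kernel_size=3)

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    # The torch optimizer now sits under the `optimizer` key.
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
    # Optional; falls back to 'DefaultOptimWrapperConstructor' when omitted.
    constructor='DefaultOptimWrapperConstructor')

optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)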

View File

@@ -3,9 +3,9 @@ import json
 import warnings
 
 from mmengine.dist import get_dist_info
-from mmengine.optim import DefaultOptimizerConstructor
+from mmengine.optim import DefaultOptimWrapperConstructor
 
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
 from mmseg.utils import get_root_logger
@@ -100,8 +100,8 @@ def get_layer_id_for_vit(var_name, max_layer_id):
         return max_layer_id - 1
 
 
-@OPTIMIZER_CONSTRUCTORS.register_module()
-class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.
 
     Note: Currently, this optimizer constructor is built for ConvNeXt,
@@ -186,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
         params.extend(parameter_groups.values())
 
 
-@OPTIMIZER_CONSTRUCTORS.register_module()
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
 class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.
@@ -195,7 +195,7 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
     Please use ``LearningRateDecayOptimizerConstructor`` instead.
     """
 
-    def __init__(self, optimizer_cfg, paramwise_cfg):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg):
         warnings.warn('DeprecationWarning: Original '
                       'LayerDecayOptimizerConstructor of BEiT '
                       'will be deprecated. Please use '
@@ -206,4 +206,4 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
                           'be deleted, please use decay_rate instead.')
             paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
         super(LayerDecayOptimizerConstructor,
-              self).__init__(optimizer_cfg, paramwise_cfg)
+              self).__init__(optim_wrapper_cfg, paramwise_cfg)
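Reviewer note: since the constructor is now registered in `OPTIM_WRAPPER_CONSTRUCTORS`, configs select it through the `constructor` field of the optim wrapper config instead of instantiating it with an `optimizer_cfg`. A sketch of the expected config shape, modeled on the updated tests below (the AdamW hyperparameters, decay_rate and num_layers are illustrative values, not recommendations):

# This dict is exactly what build_optimizer(model, cfg) above consumes.
optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05),
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(
        decay_rate=0.9, decay_type='layer_wise', num_layers=6))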

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
-                       MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
+                       MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
                        OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
                        RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
                        VISUALIZERS, WEIGHT_INITIALIZERS)
@@ -8,6 +8,6 @@ from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
 __all__ = [
     'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
     'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
-    'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
+    'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
     'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
 ]

View File

@@ -14,7 +14,7 @@ from mmengine.registry import METRICS as MMENGINE_METRICS
 from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
 from mmengine.registry import MODELS as MMENGINE_MODELS
 from mmengine.registry import \
-    OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
+    OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
 from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
 from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
 from mmengine.registry import \
@@ -54,8 +54,8 @@ WEIGHT_INITIALIZERS = Registry(
 # manage all kinds of optimizers like `SGD` and `Adam`
 OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
 # manage constructors that customize the optimization hyperparameters.
-OPTIMIZER_CONSTRUCTORS = Registry(
-    'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
+OPTIM_WRAPPER_CONSTRUCTORS = Registry(
+    'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
 # manage all kinds of parameter schedulers like `MultiStepLR`
 PARAM_SCHEDULERS = Registry(
     'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
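Reviewer note: downstream code that registered custom constructors into `OPTIMIZER_CONSTRUCTORS` should now target `OPTIM_WRAPPER_CONSTRUCTORS`; because it is a child of MMEngine's registry of the same name, constructors defined upstream (e.g. `DefaultOptimWrapperConstructor`) still resolve through it. A minimal registration sketch — `MyOptimWrapperConstructor` is a hypothetical example, not part of this MR:

from mmengine.optim import DefaultOptimWrapperConstructor

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS


# Registering the subclass here makes it selectable via the `constructor`
# key handled by mmseg.core.builder.build_optimizer.
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MyOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    pass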

View File

@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# from copyreg import constructor
 import pytest
 import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
-from mmseg.core.optimizers.layer_decay_optimizer_constructor import (
-    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
+from mmseg.core.builder import build_optimizer
+from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
+    LearningRateDecayOptimizerConstructor
 
 base_lr = 1
 decay_rate = 2
@@ -209,22 +211,30 @@ def test_learning_rate_decay_optimizer_constructor():
     # Test lr wd for ConvNeXT
     backbone = ToyConvNeXt()
     model = PseudoDataParallel(ToySegmentor(backbone))
-    optimizer_cfg = dict(
-        type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
     # stagewise decay
     stagewise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, stagewise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext)
+    optimizer_cfg = dict(
+        type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=stagewise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_stage_wise_lr_wd_convnext)
     # layerwise decay
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_lr_wd_convnext)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_lr_wd_convnext)
 
     # Test lr wd for BEiT
     backbone = ToyBEiT()
@@ -232,22 +242,26 @@ def test_learning_rate_decay_optimizer_constructor():
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)
 
     # Test invalidation of lr wd for Vit
     backbone = ToyViT()
     model = PseudoDataParallel(ToySegmentor(backbone))
     with pytest.raises(NotImplementedError):
         optim_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg, layerwise_paramwise_cfg)
-        optimizer = optim_constructor(model)
+            optim_wrapper_cfg, layerwise_paramwise_cfg)
+        optim_constructor(model)
     with pytest.raises(NotImplementedError):
         optim_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg, stagewise_paramwise_cfg)
-        optimizer = optim_constructor(model)
+            optim_wrapper_cfg, stagewise_paramwise_cfg)
+        optim_constructor(model)
 
     # Test lr wd for MAE
     backbone = ToyMAE()
@@ -255,10 +269,14 @@ def test_learning_rate_decay_optimizer_constructor():
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)
 
 
 def test_beit_layer_decay_optimizer_constructor():
@@ -266,10 +284,14 @@ def test_beit_layer_decay_optimizer_constructor():
     # paramwise_cfg with BEiTExampleModel
     backbone = ToyBEiT()
     model = PseudoDataParallel(ToySegmentor(backbone))
-    optimizer_cfg = dict(
-        type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05)
     paramwise_cfg = dict(layer_decay_rate=2, num_layers=3)
-    optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg,
-                                                       paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        constructor='LayerDecayOptimizerConstructor',
+        paramwise_cfg=paramwise_cfg,
+        optimizer=dict(
+            type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    # optimizer = optim_wrapper_builder(model)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)

View File

@@ -1,11 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import pytest
 import torch
 import torch.nn as nn
-from mmengine.optim import DefaultOptimizerConstructor
 
-from mmseg.core.builder import build_optimizer, build_optimizer_constructor
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
+from mmseg.core.builder import build_optimizer
 
 
 class ExampleModel(nn.Module):
@@ -26,34 +23,12 @@ base_wd = 0.0001
 momentum = 0.9
 
 
-def test_build_optimizer_constructor():
-    optimizer_cfg = dict(
-        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
-    optim_constructor_cfg = dict(
-        type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
-    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
-    # Test whether optimizer constructor can be built from parent.
-    assert type(optim_constructor) is DefaultOptimizerConstructor
-
-    @OPTIMIZER_CONSTRUCTORS.register_module()
-    class MyOptimizerConstructor(DefaultOptimizerConstructor):
-        pass
-
-    optim_constructor_cfg = dict(
-        type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
-    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
-    # Test optimizer constructor can be built from child registry.
-    assert type(optim_constructor) is MyOptimizerConstructor
-
-    # Test unregistered constructor cannot be built
-    with pytest.raises(KeyError):
-        build_optimizer_constructor(dict(type='A'))
-
-
 def test_build_optimizer():
     model = ExampleModel()
-    optimizer_cfg = dict(
-        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
-    optimizer = build_optimizer(model, optimizer_cfg)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=dict(
+            type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
     # test whether optimizer is successfully built from parent.
-    assert isinstance(optimizer, torch.optim.SGD)
+    assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
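Reviewer note: the assertion above is the visible contract change — `build_optimizer` now returns an `OptimWrapper` holding the torch optimizer on its `.optimizer` attribute rather than the optimizer itself. A short sketch of how training code is expected to drive the wrapper, assuming MMEngine's `OptimWrapper.update_params` API (which zero-grads, backprops the loss and steps the optimizer in one call):

import torch
import torch.nn as nn

from mmseg.core.builder import build_optimizer

model = nn.Linear(4, 2)
optim_wrapper = build_optimizer(
    model, dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))

# The wrapped torch optimizer stays reachable when needed.
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)

# One call per iteration: zero_grad + backward + step (assumed MMEngine API).
loss = model(torch.randn(8, 4)).sum()
optim_wrapper.update_params(loss)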