Merge branch 'zhengmiao/fix-optim-constructor' into 'refactor_dev'

[Fix] Optimizer -> OptimWrapper

Migrates mmsegmentation's optimizer building from the old `OPTIMIZER_CONSTRUCTORS` registry and bare optimizers to MMEngine's `OptimWrapper` / `OPTIM_WRAPPER_CONSTRUCTORS` interface; the docs, registries, builder, and tests are updated accordingly.

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!25

commit 8c4e35304f (pull/1801/head)

@@ -61,7 +61,7 @@ from .cocktail_optimizer import CocktailOptimizer
 @OPTIMIZER_BUILDERS.register_module
 class CocktailOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
     def __call__(self, model):
 

@@ -91,7 +91,7 @@ from .my_optimizer import MyOptimizer
 @OPTIMIZER_BUILDERS.register_module()
 class MyOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
     def __call__(self, model):
 

@@ -59,7 +59,7 @@ from .cocktail_optimizer import CocktailOptimizer
 @OPTIMIZER_BUILDERS.register_module
 class CocktailOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
     def __call__(self, model):
 

@@ -93,7 +93,7 @@ from .my_optimizer import MyOptimizer
 @OPTIMIZER_BUILDERS.register_module()
 class MyOptimizerConstructor(object):
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
 
     def __call__(self, model):
 

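The four documentation hunks above only rename the constructor's first argument. As a hedged illustration of what a complete constructor could look like after the rename, here is a minimal sketch: the class name and body are hypothetical, and it registers against the new mmseg OPTIM_WRAPPER_CONSTRUCTORS registry introduced further down rather than the OPTIMIZER_BUILDERS alias shown in the docs snippets.

from mmengine.optim import OptimWrapper
from torch.optim import SGD

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MySGDConstructor:
    """Hypothetical constructor; only the renamed signature comes from this commit."""

    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
        # The whole wrapper config (type, optimizer, ...) arrives here,
        # not just a bare optimizer config as before this change.
        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.paramwise_cfg = paramwise_cfg or {}

    def __call__(self, model):
        optimizer_cfg = dict(self.optim_wrapper_cfg['optimizer'])
        optimizer_cfg.pop('type', None)  # we build SGD directly in this sketch
        optimizer = SGD(model.parameters(), **optimizer_cfg)
        return OptimWrapper(optimizer=optimizer)
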
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import build_optimizer, build_optimizer_constructor
+from .builder import build_optimizer
 from .data_structures import *  # noqa: F401, F403
 from .evaluation import *  # noqa: F401, F403
 from .optimizers import *  # noqa: F401, F403
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
 
-__all__ = ['build_optimizer', 'build_optimizer_constructor']
+__all__ = ['build_optimizer']

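`build_optimizer_constructor` no longer ships with `mmseg.core`. As a hedged sketch of what code that previously built a constructor by hand could do instead, mirroring what the refactored builder below does internally (the config values are illustrative):

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS

constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
    dict(
        type='DefaultOptimWrapperConstructor',
        optim_wrapper_cfg=dict(
            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
        paramwise_cfg=None))
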
@@ -1,27 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
-
-
-def build_optimizer_constructor(cfg):
-    constructor_type = cfg.get('type')
-    if constructor_type in OPTIMIZER_CONSTRUCTORS:
-        return OPTIMIZER_CONSTRUCTORS.build(cfg)
-    else:
-        raise KeyError(f'{constructor_type} is not registered '
-                       'in the optimizer builder registry.')
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
 
 
 def build_optimizer(model, cfg):
-    optimizer_cfg = copy.deepcopy(cfg)
-    constructor_type = optimizer_cfg.pop('constructor',
-                                         'DefaultOptimizerConstructor')
-    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
-    optim_constructor = build_optimizer_constructor(
+    optim_wrapper_cfg = copy.deepcopy(cfg)
+    constructor_type = optim_wrapper_cfg.pop('constructor',
+                                             'DefaultOptimWrapperConstructor')
+    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
+    optim_wrapper_builder = OPTIM_WRAPPER_CONSTRUCTORS.build(
         dict(
             type=constructor_type,
-            optimizer_cfg=optimizer_cfg,
+            optim_wrapper_cfg=optim_wrapper_cfg,
             paramwise_cfg=paramwise_cfg))
-    optimizer = optim_constructor(model)
-    return optimizer
+    optim_wrapper = optim_wrapper_builder(model)
+    return optim_wrapper

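Putting the refactored builder to work, a minimal sketch of the call site: the toy module and hyper-parameter values are illustrative, while the config shape and the final assertion mirror the updated test_build_optimizer further down in this diff.

import torch
import torch.nn as nn

from mmseg.core import build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)  # any nn.Module works for this sketch
cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
optim_wrapper = build_optimizer(model, cfg)
# The returned OptimWrapper exposes the underlying torch optimizer.
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
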
@@ -3,9 +3,9 @@ import json
 import warnings
 
 from mmengine.dist import get_dist_info
-from mmengine.optim import DefaultOptimizerConstructor
+from mmengine.optim import DefaultOptimWrapperConstructor
 
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
 from mmseg.utils import get_root_logger
 
 

@@ -100,8 +100,8 @@ def get_layer_id_for_vit(var_name, max_layer_id):
     return max_layer_id - 1
 
 
-@OPTIMIZER_CONSTRUCTORS.register_module()
-class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
     """Different learning rates are set for different layers of backbone.
 
     Note: Currently, this optimizer constructor is built for ConvNeXt,

@@ -186,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
         params.extend(parameter_groups.values())
 
 
-@OPTIMIZER_CONSTRUCTORS.register_module()
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
 class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.
 

@@ -195,7 +195,7 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
     Please use ``LearningRateDecayOptimizerConstructor`` instead.
     """
 
-    def __init__(self, optimizer_cfg, paramwise_cfg):
+    def __init__(self, optim_wrapper_cfg, paramwise_cfg):
         warnings.warn('DeprecationWarning: Original '
                       'LayerDecayOptimizerConstructor of BEiT '
                       'will be deprecated. Please use '

@@ -206,4 +206,4 @@ class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
                           'be deleted, please use decay_rate instead.')
             paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
         super(LayerDecayOptimizerConstructor,
-              self).__init__(optimizer_cfg, paramwise_cfg)
+              self).__init__(optim_wrapper_cfg, paramwise_cfg)

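With the constructors registered under the renamed registry, configs select them through the `constructor` key of the wrapper config. A hedged sketch follows, assuming `model` wraps one of the backbones this constructor supports (ConvNeXt, BEiT or MAE, per the docstring note above); the values are illustrative and the shape follows the updated tests later in this diff.

from mmseg.core import build_optimizer

optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05),
    paramwise_cfg=dict(decay_rate=0.9, decay_type='layer_wise', num_layers=12),
    constructor='LearningRateDecayOptimizerConstructor')
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
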
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
-                       MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
+                       MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
                        OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
                        RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
                        VISUALIZERS, WEIGHT_INITIALIZERS)

@@ -8,6 +8,6 @@ from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
 __all__ = [
     'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
     'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
-    'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
+    'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
     'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
 ]

@@ -14,7 +14,7 @@ from mmengine.registry import METRICS as MMENGINE_METRICS
 from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
 from mmengine.registry import MODELS as MMENGINE_MODELS
 from mmengine.registry import \
-    OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
+    OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
 from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
 from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
 from mmengine.registry import \

@@ -54,8 +54,8 @@ WEIGHT_INITIALIZERS = Registry(
 # mangage all kinds of optimizers like `SGD` and `Adam`
 OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
 # manage constructors that customize the optimization hyperparameters.
-OPTIMIZER_CONSTRUCTORS = Registry(
-    'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
+OPTIM_WRAPPER_CONSTRUCTORS = Registry(
+    'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
 # mangage all kinds of parameter schedulers like `MultiStepLR`
 PARAM_SCHEDULERS = Registry(
     'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)

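Because the new mmseg registry declares the mmengine registry as its parent, constructors registered upstream remain resolvable from mmseg. A small sketch of the lookup the refactored builder relies on; this is not code from this commit and assumes mmengine's usual parent fallback in Registry.get:

from mmengine.optim import DefaultOptimWrapperConstructor

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS

# Resolved via the MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS parent declared above.
cls = OPTIM_WRAPPER_CONSTRUCTORS.get('DefaultOptimWrapperConstructor')
assert cls is DefaultOptimWrapperConstructor
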
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# from copyreg import constructor
 import pytest
 import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule
 
-from mmseg.core.optimizers.layer_decay_optimizer_constructor import (
-    LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
+from mmseg.core.builder import build_optimizer
+from mmseg.core.optimizers.layer_decay_optimizer_constructor import \
+    LearningRateDecayOptimizerConstructor
 
 base_lr = 1
 decay_rate = 2

@@ -209,22 +211,30 @@ def test_learning_rate_decay_optimizer_constructor():
     # Test lr wd for ConvNeXT
     backbone = ToyConvNeXt()
     model = PseudoDataParallel(ToySegmentor(backbone))
-    optimizer_cfg = dict(
-        type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
     # stagewise decay
     stagewise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='stage_wise', num_layers=6)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, stagewise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_stage_wise_lr_wd_convnext)
+    optimizer_cfg = dict(
+        type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=stagewise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_stage_wise_lr_wd_convnext)
     # layerwise decay
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=6)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_lr_wd_convnext)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_lr_wd_convnext)
 
     # Test lr wd for BEiT
     backbone = ToyBEiT()

@@ -232,22 +242,26 @@ def test_learning_rate_decay_optimizer_constructor():
 
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)
 
     # Test invalidation of lr wd for Vit
     backbone = ToyViT()
     model = PseudoDataParallel(ToySegmentor(backbone))
     with pytest.raises(NotImplementedError):
         optim_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg, layerwise_paramwise_cfg)
-        optimizer = optim_constructor(model)
+            optim_wrapper_cfg, layerwise_paramwise_cfg)
+        optim_constructor(model)
     with pytest.raises(NotImplementedError):
         optim_constructor = LearningRateDecayOptimizerConstructor(
-            optimizer_cfg, stagewise_paramwise_cfg)
-        optimizer = optim_constructor(model)
+            optim_wrapper_cfg, stagewise_paramwise_cfg)
+        optim_constructor(model)
 
     # Test lr wd for MAE
     backbone = ToyMAE()

@@ -255,10 +269,14 @@ def test_learning_rate_decay_optimizer_constructor():
 
     layerwise_paramwise_cfg = dict(
         decay_rate=decay_rate, decay_type='layer_wise', num_layers=3)
-    optim_constructor = LearningRateDecayOptimizerConstructor(
-        optimizer_cfg, layerwise_paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        optimizer=optimizer_cfg,
+        paramwise_cfg=layerwise_paramwise_cfg,
+        constructor='LearningRateDecayOptimizerConstructor')
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)
 
 
 def test_beit_layer_decay_optimizer_constructor():

@@ -266,10 +284,14 @@ def test_beit_layer_decay_optimizer_constructor():
     # paramwise_cfg with BEiTExampleModel
     backbone = ToyBEiT()
     model = PseudoDataParallel(ToySegmentor(backbone))
-    optimizer_cfg = dict(
-        type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05)
     paramwise_cfg = dict(layer_decay_rate=2, num_layers=3)
-    optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg,
-                                                       paramwise_cfg)
-    optimizer = optim_constructor(model)
-    check_optimizer_lr_wd(optimizer, expected_layer_wise_wd_lr_beit)
+    optim_wrapper_cfg = dict(
+        type='OptimWrapper',
+        constructor='LayerDecayOptimizerConstructor',
+        paramwise_cfg=paramwise_cfg,
+        optimizer=dict(
+            type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05))
+    optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
+    # optimizer = optim_wrapper_builder(model)
+    check_optimizer_lr_wd(optim_wrapper.optimizer,
+                          expected_layer_wise_wd_lr_beit)

@@ -1,11 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import pytest
 import torch
 import torch.nn as nn
-from mmengine.optim import DefaultOptimizerConstructor
 
-from mmseg.core.builder import build_optimizer, build_optimizer_constructor
-from mmseg.registry import OPTIMIZER_CONSTRUCTORS
+from mmseg.core.builder import build_optimizer
 
 
 class ExampleModel(nn.Module):

|
|||
momentum = 0.9
|
||||
|
||||
|
||||
def test_build_optimizer_constructor():
|
||||
optimizer_cfg = dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
|
||||
optim_constructor_cfg = dict(
|
||||
type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
|
||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
||||
# Test whether optimizer constructor can be built from parent.
|
||||
assert type(optim_constructor) is DefaultOptimizerConstructor
|
||||
|
||||
@OPTIMIZER_CONSTRUCTORS.register_module()
|
||||
class MyOptimizerConstructor(DefaultOptimizerConstructor):
|
||||
pass
|
||||
|
||||
optim_constructor_cfg = dict(
|
||||
type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
|
||||
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
|
||||
# Test optimizer constructor can be built from child registry.
|
||||
assert type(optim_constructor) is MyOptimizerConstructor
|
||||
|
||||
# Test unregistered constructor cannot be built
|
||||
with pytest.raises(KeyError):
|
||||
build_optimizer_constructor(dict(type='A'))
|
||||
|
||||
|
||||
def test_build_optimizer():
|
||||
model = ExampleModel()
|
||||
optimizer_cfg = dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
|
||||
optimizer = build_optimizer(model, optimizer_cfg)
|
||||
optim_wrapper_cfg = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(
|
||||
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum))
|
||||
optim_wrapper = build_optimizer(model, optim_wrapper_cfg)
|
||||
# test whether optimizer is successfully built from parent.
|
||||
assert isinstance(optimizer, torch.optim.SGD)
|
||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
||||
|