From 5b605b086d690123e41dcbda872cb8a53cd48937 Mon Sep 17 00:00:00 2001 From: FangjianLin <93248678+linfangjian01@users.noreply.github.com> Date: Sat, 16 Apr 2022 16:45:17 +0800 Subject: [PATCH] [Fix] Register optimizer constructor with mmseg (#1456) * [fix] register optimizer constructor with mmseg * fix lint * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * fix lint Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> --- mmseg/apis/train.py | 4 +- mmseg/core/__init__.py | 7 +++ mmseg/core/builder.py | 33 +++++++++++ .../core/layer_decay_optimizer_constructor.py | 4 +- .../layer_decay_optimizer_constructor.py | 6 +- tests/test_core/test_optimizer.py | 59 +++++++++++++++++++ 6 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 mmseg/core/builder.py create mode 100644 tests/test_core/test_optimizer.py diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py index bda7b213f..3563e3620 100644 --- a/mmseg/apis/train.py +++ b/mmseg/apis/train.py @@ -8,11 +8,11 @@ import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, - build_optimizer, build_runner, get_dist_info) + build_runner, get_dist_info) from mmcv.utils import build_from_cfg from mmseg import digit_version -from mmseg.core import DistEvalHook, EvalHook +from mmseg.core import DistEvalHook, EvalHook, build_optimizer from mmseg.datasets import build_dataloader, build_dataset from mmseg.utils import find_latest_checkpoint, get_root_logger diff --git a/mmseg/core/__init__.py b/mmseg/core/__init__.py index c60b48c0c..c91334926 100644 --- a/mmseg/core/__init__.py +++ b/mmseg/core/__init__.py @@ -1,6 +1,13 @@ # Copyright (c) OpenMMLab. 
All rights reserved. +from .builder import (OPTIMIZER_BUILDERS, build_optimizer, + build_optimizer_constructor) from .evaluation import * # noqa: F401, F403 from .layer_decay_optimizer_constructor import \ LayerDecayOptimizerConstructor # noqa: F401 from .seg import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 + +__all__ = [ + 'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer', + 'build_optimizer_constructor' +] diff --git a/mmseg/core/builder.py b/mmseg/core/builder.py new file mode 100644 index 000000000..406dd9b4b --- /dev/null +++ b/mmseg/core/builder.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS +from mmcv.utils import Registry, build_from_cfg + +OPTIMIZER_BUILDERS = Registry( + 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) + + +def build_optimizer_constructor(cfg): + constructor_type = cfg.get('type') + if constructor_type in OPTIMIZER_BUILDERS: + return build_from_cfg(cfg, OPTIMIZER_BUILDERS) + elif constructor_type in MMCV_OPTIMIZER_BUILDERS: + return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) + else: + raise KeyError(f'{constructor_type} is not registered ' + 'in the optimizer builder registry.') + + +def build_optimizer(model, cfg): + optimizer_cfg = copy.deepcopy(cfg) + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer diff --git a/mmseg/core/layer_decay_optimizer_constructor.py b/mmseg/core/layer_decay_optimizer_constructor.py index 30a09ba08..bd3db92c5 100644 --- a/mmseg/core/layer_decay_optimizer_constructor.py +++ b/mmseg/core/layer_decay_optimizer_constructor.py @@ -1,8 +1,8 @@ # Copyright 
(c) OpenMMLab. All rights reserved. -from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, - get_dist_info) +from mmcv.runner import DefaultOptimizerConstructor, get_dist_info from mmseg.utils import get_root_logger +from .builder import OPTIMIZER_BUILDERS def get_num_layer_for_vit(var_name, num_max_layer): diff --git a/mmseg/core/utils/layer_decay_optimizer_constructor.py b/mmseg/core/utils/layer_decay_optimizer_constructor.py index ec9dc156d..29804878c 100644 --- a/mmseg/core/utils/layer_decay_optimizer_constructor.py +++ b/mmseg/core/utils/layer_decay_optimizer_constructor.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import json -from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, - get_dist_info) +from mmcv.runner import DefaultOptimizerConstructor, get_dist_info -from ...utils import get_root_logger +from mmseg.utils import get_root_logger +from ..builder import OPTIMIZER_BUILDERS def get_num_layer_layer_wise(var_name, num_max_layer=12): diff --git a/tests/test_core/test_optimizer.py b/tests/test_core/test_optimizer.py new file mode 100644 index 000000000..247f9feb1 --- /dev/null +++ b/tests/test_core/test_optimizer.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import pytest +import torch +import torch.nn as nn +from mmcv.runner import DefaultOptimizerConstructor + +from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer, + build_optimizer_constructor) + + +class ExampleModel(nn.Module): + + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.ones(1)) + self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d(4, 2, kernel_size=1) + self.bn = nn.BatchNorm2d(2) + + def forward(self, x): + return x + + +base_lr = 0.01 +base_wd = 0.0001 +momentum = 0.9 + + +def test_build_optimizer_constructor(): + optimizer_cfg = dict( + type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum) + optim_constructor_cfg = dict( + type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg) + optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + # Test whether optimizer constructor can be built from parent. + assert type(optim_constructor) is DefaultOptimizerConstructor + + @OPTIMIZER_BUILDERS.register_module() + class MyOptimizerConstructor(DefaultOptimizerConstructor): + pass + + optim_constructor_cfg = dict( + type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg) + optim_constructor = build_optimizer_constructor(optim_constructor_cfg) + # Test optimizer constructor can be built from child registry. + assert type(optim_constructor) is MyOptimizerConstructor + + # Test unregistered constructor cannot be built + with pytest.raises(KeyError): + build_optimizer_constructor(dict(type='A')) + + +def test_build_optimizer(): + model = ExampleModel() + optimizer_cfg = dict( + type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum) + optimizer = build_optimizer(model, optimizer_cfg) + # test whether optimizer is successfully built from parent. + assert isinstance(optimizer, torch.optim.SGD)