[Fix] Register optimizer constructor with mmseg (#1456)

* [fix] register optimizer onstructor with mmseg

* fix lint

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* Update tests/test_core/test_optimizer.py

* Update tests/test_core/test_optimizer.py

* Update tests/test_core/test_optimizer.py

* Update tests/test_core/test_optimizer.py

* fix lint

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
This commit is contained in:
FangjianLin 2022-04-16 16:45:17 +08:00 committed by GitHub
parent 737e7e6e6c
commit 5b605b086d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 106 additions and 7 deletions

View File

@ -8,11 +8,11 @@ import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
build_optimizer, build_runner, get_dist_info)
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg
from mmseg import digit_version
from mmseg.core import DistEvalHook, EvalHook
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import find_latest_checkpoint, get_root_logger

View File

@ -1,6 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
from .evaluation import * # noqa: F401, F403
from .layer_decay_optimizer_constructor import \
LayerDecayOptimizerConstructor # noqa: F401
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
__all__ = [
'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
'build_optimizer_constructor'
]

33
mmseg/core/builder.py Normal file
View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
from mmcv.utils import Registry, build_from_cfg
OPTIMIZER_BUILDERS = Registry(
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
def build_optimizer_constructor(cfg):
constructor_type = cfg.get('type')
if constructor_type in OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
else:
raise KeyError(f'{constructor_type} is not registered '
'in the optimizer builder registry.')
def build_optimizer(model, cfg):
optimizer_cfg = copy.deepcopy(cfg)
constructor_type = optimizer_cfg.pop('constructor',
'DefaultOptimizerConstructor')
paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
optim_constructor = build_optimizer_constructor(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
get_dist_info)
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from mmseg.utils import get_root_logger
from .builder import OPTIMIZER_BUILDERS
def get_num_layer_for_vit(var_name, num_max_layer):

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
get_dist_info)
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from ...utils import get_root_logger
from mmseg.utils import get_root_logger
from ..builder import OPTIMIZER_BUILDERS
def get_num_layer_layer_wise(var_name, num_max_layer=12):

View File

@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor
from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
def forward(self, x):
return x
base_lr = 0.01
base_wd = 0.0001
momentum = 0.9
def test_build_optimizer_constructor():
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
optim_constructor_cfg = dict(
type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
# Test whether optimizer constructor can be built from parent.
assert type(optim_constructor) is DefaultOptimizerConstructor
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
pass
optim_constructor_cfg = dict(
type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
# Test optimizer constructor can be built from child registry.
assert type(optim_constructor) is MyOptimizerConstructor
# Test unregistered constructor cannot be built
with pytest.raises(KeyError):
build_optimizer_constructor(dict(type='A'))
def test_build_optimizer():
model = ExampleModel()
optimizer_cfg = dict(
type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
optimizer = build_optimizer(model, optimizer_cfg)
# test whether optimizer is successfully built from parent.
assert isinstance(optimizer, torch.optim.SGD)