[Fix] Register optimizer constructor with mmseg (#1456)

* [fix] register optimizer constructor with mmseg

* fix lint

* fix

* Update tests/test_core/test_optimizer.py

* fix lint

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Author: FangjianLin, 2022-04-16 16:45:17 +08:00 (committed by GitHub)
parent 737e7e6e6c
commit 5b605b086d
6 changed files with 106 additions and 7 deletions

mmseg/apis/train.py

@@ -8,11 +8,11 @@
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
-                         build_optimizer, build_runner, get_dist_info)
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg

 from mmseg import digit_version
-from mmseg.core import DistEvalHook, EvalHook
+from mmseg.core import DistEvalHook, EvalHook, build_optimizer
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.utils import find_latest_checkpoint, get_root_logger
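This import move is the user-visible fix: mmcv's build_optimizer only searches MMCV's own OPTIMIZER_BUILDERS registry, so a constructor registered only with mmseg raised a KeyError at training time. With build_optimizer now coming from mmseg.core, a hypothetical BEiT-style config fragment like the following (field values illustrative, not part of this commit) resolves correctly:

optimizer = dict(
    type='AdamW',
    lr=3e-4,
    weight_decay=0.05,
    # Registered in mmseg's registry, invisible to mmcv's builder.
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))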

mmseg/core/__init__.py

@@ -1,6 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
+                      build_optimizer_constructor)
 from .evaluation import *  # noqa: F401, F403
 from .layer_decay_optimizer_constructor import \
     LayerDecayOptimizerConstructor  # noqa: F401
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
+
+__all__ = [
+    'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
+    'build_optimizer_constructor'
+]
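With these exports in place the helpers are importable from the package root; a quick smoke check, assuming mmseg is installed at this revision:

from mmseg.core import (OPTIMIZER_BUILDERS, build_optimizer,
                        build_optimizer_constructor)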

mmseg/core/builder.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
+from mmcv.utils import Registry, build_from_cfg
+
+OPTIMIZER_BUILDERS = Registry(
+    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+
+
+def build_optimizer_constructor(cfg):
+    constructor_type = cfg.get('type')
+    if constructor_type in OPTIMIZER_BUILDERS:
+        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
+        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
+    else:
+        raise KeyError(f'{constructor_type} is not registered '
+                       'in the optimizer builder registry.')
+
+
+def build_optimizer(model, cfg):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
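The registry is created with parent=MMCV_OPTIMIZER_BUILDERS, and build_optimizer_constructor tries mmseg's child registry before falling back to the parent, so constructors from either library resolve. A minimal sketch of both paths (MyConstructor is an illustrative name, not part of the commit):

from mmcv.runner import DefaultOptimizerConstructor
from mmseg.core import OPTIMIZER_BUILDERS, build_optimizer_constructor

# A constructor known only to mmseg's child registry; mmcv's own
# builder would not be able to find it.
@OPTIMIZER_BUILDERS.register_module()
class MyConstructor(DefaultOptimizerConstructor):
    pass

optimizer_cfg = dict(type='SGD', lr=0.01)
# Hits the `if` branch: resolved from mmseg's child registry.
c1 = build_optimizer_constructor(
    dict(type='MyConstructor', optimizer_cfg=optimizer_cfg))
# Hits the `elif` branch: falls back to MMCV's parent registry.
c2 = build_optimizer_constructor(
    dict(type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg))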

mmseg/core/layer_decay_optimizer_constructor.py

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
-                         get_dist_info)
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info

 from mmseg.utils import get_root_logger
+from .builder import OPTIMIZER_BUILDERS


 def get_num_layer_for_vit(var_name, num_max_layer):


@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json

-from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
-                         get_dist_info)
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info

-from ...utils import get_root_logger
+from mmseg.utils import get_root_logger
+from ..builder import OPTIMIZER_BUILDERS


 def get_num_layer_layer_wise(var_name, num_max_layer=12):

tests/test_core/test_optimizer.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn as nn
+from mmcv.runner import DefaultOptimizerConstructor
+
+from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
+                                build_optimizer_constructor)
+
+
+class ExampleModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.param1 = nn.Parameter(torch.ones(1))
+        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
+        self.bn = nn.BatchNorm2d(2)
+
+    def forward(self, x):
+        return x
+
+
+base_lr = 0.01
+base_wd = 0.0001
+momentum = 0.9
+
+
+def test_build_optimizer_constructor():
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    optim_constructor_cfg = dict(
+        type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
+    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+    # Test whether optimizer constructor can be built from parent.
+    assert type(optim_constructor) is DefaultOptimizerConstructor
+
+    @OPTIMIZER_BUILDERS.register_module()
+    class MyOptimizerConstructor(DefaultOptimizerConstructor):
+        pass
+
+    optim_constructor_cfg = dict(
+        type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
+    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+    # Test whether optimizer constructor can be built from child registry.
+    assert type(optim_constructor) is MyOptimizerConstructor
+
+    # Test that an unregistered constructor cannot be built.
+    with pytest.raises(KeyError):
+        build_optimizer_constructor(dict(type='A'))
+
+
+def test_build_optimizer():
+    model = ExampleModel()
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    optimizer = build_optimizer(model, optimizer_cfg)
+    # Test whether optimizer is successfully built from parent.
+    assert isinstance(optimizer, torch.optim.SGD)
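For completeness, a minimal end-to-end sketch of the new entry point outside the test file: the constructor and paramwise_cfg keys are popped from the cfg before the remaining fields ('SGD', lr, momentum) are handed to torch.optim.SGD.

import torch
import torch.nn as nn

from mmseg.core import build_optimizer

model = nn.Conv2d(3, 4, kernel_size=1)
optimizer = build_optimizer(
    model,
    dict(
        type='SGD',
        lr=0.01,
        momentum=0.9,
        # Popped by build_optimizer; this is also its default value.
        constructor='DefaultOptimizerConstructor'))
assert isinstance(optimizer, torch.optim.SGD)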