Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Fix] Register optimizer constructor with mmseg (#1456)
* [fix] register optimizer constructor with mmseg
* fix lint
* fix (a series of small follow-up fixes from review)
* Update tests/test_core/test_optimizer.py (several review iterations)
* fix lint

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
parent: 737e7e6e6c
commit: 5b605b086d
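With this change, optimizers are built through mmseg's own build_optimizer, so a config can pick any constructor registered with either mmseg or MMCV via the 'constructor' key. A minimal usage sketch (the optimizer settings and paramwise_cfg keys below are illustrative, not taken from this commit):

import torch.nn as nn

from mmseg.core import build_optimizer

model = nn.Conv2d(3, 8, kernel_size=3)  # stand-in for a real segmentor

# Default path: behaves like mmcv.runner.build_optimizer.
optimizer = build_optimizer(
    model, dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005))

# A non-default constructor is selected with the 'constructor' key; the
# paramwise_cfg values are hypothetical, and this constructor expects a
# real segmentor with a backbone rather than the stand-in module above.
layer_decay_cfg = dict(
    type='AdamW',
    lr=6e-5,
    weight_decay=0.01,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))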
Training code: build_optimizer is now imported from mmseg.core instead of mmcv.runner.

@@ -8,11 +8,11 @@ import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
-                         build_optimizer, build_runner, get_dist_info)
+                         build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
 
 from mmseg import digit_version
-from mmseg.core import DistEvalHook, EvalHook
+from mmseg.core import DistEvalHook, EvalHook, build_optimizer
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.utils import find_latest_checkpoint, get_root_logger
 
Core package exports: the new builder helpers are re-exported and added to __all__.

@@ -1,6 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
+                      build_optimizer_constructor)
 from .evaluation import *  # noqa: F401, F403
 from .layer_decay_optimizer_constructor import \
     LayerDecayOptimizerConstructor  # noqa: F401
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
+
+__all__ = [
+    'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
+    'build_optimizer_constructor'
+]
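With the updated __all__, downstream code can import every optimizer-building helper directly from mmseg.core; build_optimizer keeps the (model, cfg) signature of mmcv.runner.build_optimizer, so callers only need to switch the import, as the training-code hunk above does:

from mmseg.core import (OPTIMIZER_BUILDERS, LayerDecayOptimizerConstructor,
                        build_optimizer, build_optimizer_constructor)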
mmseg/core/builder.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
+from mmcv.utils import Registry, build_from_cfg
+
+OPTIMIZER_BUILDERS = Registry(
+    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+
+
+def build_optimizer_constructor(cfg):
+    constructor_type = cfg.get('type')
+    if constructor_type in OPTIMIZER_BUILDERS:
+        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
+        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
+    else:
+        raise KeyError(f'{constructor_type} is not registered '
+                       'in the optimizer builder registry.')
+
+
+def build_optimizer(model, cfg):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
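Since OPTIMIZER_BUILDERS is created with MMCV's registry as its parent, build_optimizer_constructor checks mmseg's registry first and then falls back to MMCV's, so both built-in and project-level constructors resolve. A minimal sketch of registering a custom constructor (MyConstructor is a hypothetical name):

from mmcv.runner import DefaultOptimizerConstructor

from mmseg.core.builder import OPTIMIZER_BUILDERS, build_optimizer_constructor


@OPTIMIZER_BUILDERS.register_module()
class MyConstructor(DefaultOptimizerConstructor):
    """Hypothetical constructor; override add_params() to customize
    per-parameter options such as learning-rate multipliers."""


# Resolved from mmseg's own registry.
custom = build_optimizer_constructor(
    dict(type='MyConstructor', optimizer_cfg=dict(type='SGD', lr=0.01)))

# Resolved through the fallback to MMCV's registry.
default = build_optimizer_constructor(
    dict(type='DefaultOptimizerConstructor',
         optimizer_cfg=dict(type='SGD', lr=0.01)))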
Layer-decay optimizer constructor (ViT helper): register against mmseg's OPTIMIZER_BUILDERS instead of MMCV's.

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
-                         get_dist_info)
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
 
 from mmseg.utils import get_root_logger
+from .builder import OPTIMIZER_BUILDERS
 
 
 def get_num_layer_for_vit(var_name, num_max_layer):
Layer-wise decay optimizer constructor: the same registry switch, plus an absolute import of get_root_logger.

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 
-from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
-                         get_dist_info)
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
 
-from ...utils import get_root_logger
+from mmseg.utils import get_root_logger
+from ..builder import OPTIMIZER_BUILDERS
 
 
 def get_num_layer_layer_wise(var_name, num_max_layer=12):
tests/test_core/test_optimizer.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn as nn
+from mmcv.runner import DefaultOptimizerConstructor
+
+from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
+                                build_optimizer_constructor)
+
+
+class ExampleModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.param1 = nn.Parameter(torch.ones(1))
+        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
+        self.bn = nn.BatchNorm2d(2)
+
+    def forward(self, x):
+        return x
+
+
+base_lr = 0.01
+base_wd = 0.0001
+momentum = 0.9
+
+
+def test_build_optimizer_constructor():
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    optim_constructor_cfg = dict(
+        type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
+    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+    # Test whether optimizer constructor can be built from parent.
+    assert type(optim_constructor) is DefaultOptimizerConstructor
+
+    @OPTIMIZER_BUILDERS.register_module()
+    class MyOptimizerConstructor(DefaultOptimizerConstructor):
+        pass
+
+    optim_constructor_cfg = dict(
+        type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
+    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
+    # Test optimizer constructor can be built from child registry.
+    assert type(optim_constructor) is MyOptimizerConstructor
+
+    # Test unregistered constructor cannot be built
+    with pytest.raises(KeyError):
+        build_optimizer_constructor(dict(type='A'))
+
+
+def test_build_optimizer():
+    model = ExampleModel()
+    optimizer_cfg = dict(
+        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
+    optimizer = build_optimizer(model, optimizer_cfg)
+    # test whether optimizer is successfully built from parent.
+    assert isinstance(optimizer, torch.optim.SGD)