mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Register optimizer constructor with mmseg (#1456)
* [fix] register optimizer constructor with mmseg * fix lint * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * fix lint Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
This commit is contained in:
parent
737e7e6e6c
commit
5b605b086d
@ -8,11 +8,11 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
||||||
build_optimizer, build_runner, get_dist_info)
|
build_runner, get_dist_info)
|
||||||
from mmcv.utils import build_from_cfg
|
from mmcv.utils import build_from_cfg
|
||||||
|
|
||||||
from mmseg import digit_version
|
from mmseg import digit_version
|
||||||
from mmseg.core import DistEvalHook, EvalHook
|
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
from mmseg.utils import find_latest_checkpoint, get_root_logger
|
from mmseg.utils import find_latest_checkpoint, get_root_logger
|
||||||
|
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
|
||||||
|
build_optimizer_constructor)
|
||||||
from .evaluation import * # noqa: F401, F403
|
from .evaluation import * # noqa: F401, F403
|
||||||
from .layer_decay_optimizer_constructor import \
|
from .layer_decay_optimizer_constructor import \
|
||||||
LayerDecayOptimizerConstructor # noqa: F401
|
LayerDecayOptimizerConstructor # noqa: F401
|
||||||
from .seg import * # noqa: F401, F403
|
from .seg import * # noqa: F401, F403
|
||||||
from .utils import * # noqa: F401, F403
|
from .utils import * # noqa: F401, F403
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'LayerDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS', 'build_optimizer',
|
||||||
|
'build_optimizer_constructor'
|
||||||
|
]
|
||||||
|
33
mmseg/core/builder.py
Normal file
33
mmseg/core/builder.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
|
||||||
|
from mmcv.utils import Registry, build_from_cfg
|
||||||
|
|
||||||
|
# Child registry of mmcv's OPTIMIZER_BUILDERS.  Constructors registered here
# are private to mmseg, while the explicit ``parent=`` link lets
# ``build_optimizer_constructor`` below fall back to everything that is
# already registered upstream in mmcv.
OPTIMIZER_BUILDERS = Registry(
    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
|
||||||
|
|
||||||
|
|
||||||
|
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from a config dict.

    The constructor type named by ``cfg['type']`` is looked up first in
    mmseg's own ``OPTIMIZER_BUILDERS`` registry and then in mmcv's parent
    registry, so both locally and upstream registered constructors can be
    built.

    Args:
        cfg (dict): Config with a ``type`` key naming the constructor.

    Returns:
        The instantiated optimizer constructor.

    Raises:
        KeyError: If ``type`` is registered in neither registry.
    """
    constructor_type = cfg.get('type')
    # Child registry first, then mmcv's parent registry — same precedence
    # as an explicit if/elif chain over the two registries.
    for registry in (OPTIMIZER_BUILDERS, MMCV_OPTIMIZER_BUILDERS):
        if constructor_type in registry:
            return build_from_cfg(cfg, registry)
    raise KeyError(f'{constructor_type} is not registered '
                   'in the optimizer builder registry.')
|
||||||
|
|
||||||
|
|
||||||
|
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from a config dict.

    ``cfg`` may carry two special keys: ``constructor`` (defaults to
    ``'DefaultOptimizerConstructor'``) and ``paramwise_cfg``; every
    remaining key is forwarded to the constructor as the optimizer config.

    Args:
        model (nn.Module): Model whose parameters will be optimized.
        cfg (dict): Optimizer config, e.g. ``dict(type='SGD', lr=0.01)``.

    Returns:
        The optimizer produced by the chosen constructor.
    """
    # Deep-copy so the caller's config dict is left untouched by the pops.
    cfg_copy = copy.deepcopy(cfg)
    constructor_type = cfg_copy.pop('constructor',
                                    'DefaultOptimizerConstructor')
    paramwise_cfg = cfg_copy.pop('paramwise_cfg', None)
    constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=cfg_copy,
            paramwise_cfg=paramwise_cfg))
    return constructor(model)
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
|
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
||||||
get_dist_info)
|
|
||||||
|
|
||||||
from mmseg.utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
|
from .builder import OPTIMIZER_BUILDERS
|
||||||
|
|
||||||
|
|
||||||
def get_num_layer_for_vit(var_name, num_max_layer):
|
def get_num_layer_for_vit(var_name, num_max_layer):
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
|
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
||||||
get_dist_info)
|
|
||||||
|
|
||||||
from ...utils import get_root_logger
|
from mmseg.utils import get_root_logger
|
||||||
|
from ..builder import OPTIMIZER_BUILDERS
|
||||||
|
|
||||||
|
|
||||||
def get_num_layer_layer_wise(var_name, num_max_layer=12):
|
def get_num_layer_layer_wise(var_name, num_max_layer=12):
|
||||||
|
59
tests/test_core/test_optimizer.py
Normal file
59
tests/test_core/test_optimizer.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.runner import DefaultOptimizerConstructor
|
||||||
|
|
||||||
|
from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
|
||||||
|
build_optimizer_constructor)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleModel(nn.Module):
    """Tiny fixture network for the optimizer-builder tests.

    It deliberately mixes a bare parameter, conv layers with and without
    bias, and a batch-norm layer so that optimizer construction sees a
    variety of parameter kinds.
    """

    def __init__(self):
        super().__init__()
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(
            in_channels=4, out_channels=2, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=2)

    def forward(self, x):
        # The tests never run a real forward pass; identity is enough.
        return x
|
||||||
|
|
||||||
|
|
||||||
|
# Shared hyper-parameters for the optimizer configs built in the tests below.
base_lr = 0.01
base_wd = 0.0001
momentum = 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_optimizer_constructor():
    """Constructors resolve through both the parent and child registries."""
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)

    # A type registered only with mmcv is found via the parent registry.
    constructor = build_optimizer_constructor(
        dict(type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg))
    assert type(constructor) is DefaultOptimizerConstructor

    @OPTIMIZER_BUILDERS.register_module()
    class MyOptimizerConstructor(DefaultOptimizerConstructor):
        pass

    # A type registered with mmseg is found via the child registry.
    constructor = build_optimizer_constructor(
        dict(type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg))
    assert type(constructor) is MyOptimizerConstructor

    # An unregistered type must raise KeyError.
    with pytest.raises(KeyError):
        build_optimizer_constructor(dict(type='A'))
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_optimizer():
    """``build_optimizer`` yields a real torch optimizer via the parent."""
    model = ExampleModel()
    cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    optimizer = build_optimizer(model, cfg)
    # Built through mmcv's DefaultOptimizerConstructor by default.
    assert isinstance(optimizer, torch.optim.SGD)
|
Loading…
x
Reference in New Issue
Block a user