mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* [fix] register optimizer constructor with mmseg * fix lint * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * Update tests/test_core/test_optimizer.py * fix lint Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
|
|
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
|
|
from mmcv.utils import Registry, build_from_cfg
|
|
|
|
# Registry holding this package's optimizer constructors. It is created with
# MMCV's optimizer-builder registry as its parent.
OPTIMIZER_BUILDERS = Registry(
    name='optimizer builder',
    parent=MMCV_OPTIMIZER_BUILDERS,
)
|
|
|
|
|
|
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from a config.

    The value stored under ``'type'`` in ``cfg`` is looked up first in this
    module's ``OPTIMIZER_BUILDERS`` and then in ``MMCV_OPTIMIZER_BUILDERS``.
    The first registry that contains it is used to build the object.

    Args:
        cfg: Config object providing ``get('type')``; it is forwarded
            unchanged to ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` produces for ``cfg`` with the matching
        registry.

    Raises:
        KeyError: If the requested type is found in neither registry.
    """
    requested = cfg.get('type')
    # Order matters: the local registry takes precedence over MMCV's.
    for registry in (OPTIMIZER_BUILDERS, MMCV_OPTIMIZER_BUILDERS):
        if requested in registry:
            return build_from_cfg(cfg, registry)
    raise KeyError(f'{requested} is not registered '
                   'in the optimizer builder registry.')
|
|
|
|
|
|
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    ``cfg`` is deep-copied, so the caller's object is never modified. Two
    optional keys are removed from the copy:

    * ``'constructor'``: name of the optimizer constructor to use, defaulting
      to ``'DefaultOptimizerConstructor'``.
    * ``'paramwise_cfg'``: forwarded to the constructor, defaulting to
      ``None``.

    Everything left in the copy is passed to the constructor as
    ``optimizer_cfg``.

    Args:
        model: Object handed to the built optimizer constructor.
        cfg: Optimizer config supporting ``copy.deepcopy`` and ``pop``.

    Returns:
        The result of calling the built optimizer constructor on ``model``.
    """
    settings = copy.deepcopy(cfg)
    constructor_name = settings.pop(
        'constructor', 'DefaultOptimizerConstructor')
    per_param_settings = settings.pop('paramwise_cfg', None)

    constructor_cfg = dict(
        type=constructor_name,
        optimizer_cfg=settings,
        paramwise_cfg=per_param_settings)
    return build_optimizer_constructor(constructor_cfg)(model)
|