mmsegmentation/mmseg/core/builder.py

28 lines
935 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
def build_optimizer_constructor(cfg):
    """Build an optimizer constructor from ``cfg`` through the registry.

    Args:
        cfg: Config object supporting ``.get``; its ``'type'`` entry names
            the constructor to look up in ``OPTIMIZER_CONSTRUCTORS``.

    Returns:
        The result of ``OPTIMIZER_CONSTRUCTORS.build(cfg)``.

    Raises:
        KeyError: If the requested type is not contained in
            ``OPTIMIZER_CONSTRUCTORS``.
    """
    requested_type = cfg.get('type')
    # Guard clause: reject unknown constructor types before building.
    if requested_type not in OPTIMIZER_CONSTRUCTORS:
        raise KeyError(f'{requested_type} is not registered '
                       'in the optimizer builder registry.')
    return OPTIMIZER_CONSTRUCTORS.build(cfg)
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    The config is deep-copied first, so the ``pop`` calls below never
    mutate the caller's ``cfg``.

    Args:
        model: Object handed to the optimizer constructor when it is called.
        cfg: Optimizer config. The optional ``'constructor'`` key selects the
            constructor type (``'DefaultOptimizerConstructor'`` when absent)
            and the optional ``'paramwise_cfg'`` key is forwarded to it
            (``None`` when absent). Everything left over is passed on as
            ``optimizer_cfg``.

    Returns:
        Whatever the built optimizer constructor returns when called with
        ``model``.
    """
    remaining_cfg = copy.deepcopy(cfg)
    # Strip the builder-level keys; what remains is the optimizer's own cfg.
    constructor_name = remaining_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    per_param_cfg = remaining_cfg.pop('paramwise_cfg', None)

    constructor_cfg = {
        'type': constructor_name,
        'optimizer_cfg': remaining_cfg,
        'paramwise_cfg': per_param_cfg,
    }
    return build_optimizer_constructor(constructor_cfg)(model)