# Copyright (c) OpenMMLab. All rights reserved.
import pytest

from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
                      Registry, build_from_cfg, build_model_from_cfg)
from mmengine.utils import is_installed


@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
    """Exercise ``build_from_cfg`` with each supported config container.

    Covers validation of the ``cfg``, ``default_args`` and ``registry``
    arguments, building from a string or class ``type``, and building a
    ``ManagerMixin`` subclass.

    Args:
        cfg_type: Callable used to wrap a plain dict into the config
            container under test (``dict``, ``ConfigDict`` or ``Config``).
    """
    BACKBONES = Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # ---- test `cfg` parameter ----
    # Config objects are always created *outside* the `pytest.raises` blocks
    # so that only the `build_from_cfg` call itself can satisfy the
    # expectation; an error while constructing the config must fail the test.

    # `cfg` should be a dict, ConfigDict or Config object
    # NOTE(review): the pattern ends at "but got " -- a type-repr suffix
    # appears to have been dropped from it; confirm against the library
    # message before tightening.
    cfg = 'ResNet'
    with pytest.raises(
            TypeError,
            match='cfg should be a dict, ConfigDict or Config, but got '):
        build_from_cfg(cfg, BACKBONES)

    # `cfg` is a dict, ConfigDict or Config object
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # `cfg` is a valid container but it does not contain the key "type"
    cfg = cfg_type(dict(depth=50, stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES)

    # cfg['type'] should be a str or class
    # NOTE(review): pattern ends at "but got " -- suffix possibly lost;
    # confirm.
    cfg = cfg_type(dict(type=1000))
    with pytest.raises(
            TypeError, match="type must be a str or valid type, but got "):
        build_from_cfg(cfg, BACKBONES)

    # `type` given as a registered name, with an explicit optional argument
    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # `type` given as the class object itself
    cfg = cfg_type(dict(type=ResNet, depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    cfg = cfg_type(dict(type='VGG'))
    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
        build_from_cfg(cfg, BACKBONES)

    # `cfg` contains unexpected arguments
    cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES)

    # ---- test `default_args` parameter ----
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    # default_args must be a dict or None
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg or default_args should contain the key "type"
    cfg = cfg_type(dict(depth=50))
    default_args = cfg_type(dict(stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES, default_args=default_args)

    # "type" defined using default_args (as a registered name)
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # "type" defined using default_args (as the class object)
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # ---- test `registry` parameter ----
    # incorrect registry type (a str instead of a Registry)
    # NOTE(review): pattern ends at "but got " -- suffix possibly lost;
    # confirm.
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(
            TypeError,
            match='registry must be a mmengine.Registry object, but got '):
        build_from_cfg(cfg, 'BACKBONES')

    # ---- building a `ManagerMixin` subclass ----
    VISUALIZER = Registry('visualizer')

    @VISUALIZER.register_module()
    class Visualizer(ManagerMixin):

        def __init__(self, name):
            super().__init__(name)

    # Nothing has been built yet, so asking for the current instance fails.
    with pytest.raises(RuntimeError):
        Visualizer.get_current_instance()

    cfg = dict(type='Visualizer', name='visualizer')
    build_from_cfg(cfg, VISUALIZER)
    # After building through the registry the lookup must no longer raise.
    Visualizer.get_current_instance()


@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_model_from_cfg():
    """Check ``build_model_from_cfg`` used as a registry ``build_func``.

    Builds single modules from a dict, a ``nn.Sequential`` from a list of
    dicts, and verifies how ``build_func`` is inherited from or overridden
    relative to a parent registry.
    """
    import torch.nn as nn

    BACKBONES = Registry('backbone', build_func=build_model_from_cfg)

    @BACKBONES.register_module()
    class ResNet(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    @BACKBONES.register_module()
    class ResNeXt(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    # a single dict builds a single module
    cfg = dict(type='ResNet', depth=50)
    model = BACKBONES.build(cfg)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = BACKBONES.build(cfg)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # a list of dicts builds an `nn.Sequential` holding the modules in order
    cfg = [
        dict(type='ResNet', depth=50),
        dict(type='ResNeXt', depth=50, stages=3)
    ]
    model = BACKBONES.build(cfg)
    assert isinstance(model, nn.Sequential)
    assert isinstance(model[0], ResNet)
    assert model[0].depth == 50 and model[0].stages == 4
    assert isinstance(model[1], ResNeXt)
    assert model[1].depth == 50 and model[1].stages == 3

    # test inherit `build_func` from parent
    NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
    assert NEW_MODELS.build_func is build_model_from_cfg

    # test specify `build_func`
    def pseudo_build(cfg):
        return cfg

    NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
    assert NEW_MODELS.build_func is pseudo_build


@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_sheduler_from_cfg():
    """Build ``LinearParamScheduler`` objects through ``PARAM_SCHEDULERS``.

    Checks the ``begin``/``end`` attributes of the built scheduler, both for
    a plain config and for one that sets ``convert_to_iter_based=True``
    together with ``epoch_length``.
    """
    import torch.nn as nn
    from torch.optim import SGD

    model = nn.Conv2d(1, 1, 1)
    optimizer = SGD(model.parameters(), lr=0.1)

    # plain config: `begin`/`end` are kept as given
    cfg = dict(
        type='LinearParamScheduler',
        optimizer=optimizer,
        param_name='lr',
        begin=0,
        end=100)
    scheduler = PARAM_SCHEDULERS.build(cfg)
    assert scheduler.begin == 0
    assert scheduler.end == 100

    # with `convert_to_iter_based=True` and `epoch_length=10`, the built
    # scheduler is expected to report `end` as 100 * 10
    cfg = dict(
        type='LinearParamScheduler',
        convert_to_iter_based=True,
        optimizer=optimizer,
        param_name='lr',
        begin=0,
        end=100,
        epoch_length=10)
    scheduler = PARAM_SCHEDULERS.build(cfg)
    assert scheduler.begin == 0
    assert scheduler.end == 1000