mmengine/tests/test_registry/test_build_functions.py

219 lines
6.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
Registry, build_from_cfg, build_model_from_cfg)
from mmengine.utils import is_installed
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
def test_build_from_cfg(cfg_type):
    """Check ``build_from_cfg`` with configs wrapped in ``cfg_type``.

    Covers validation of the ``cfg``, ``default_args`` and ``registry``
    arguments, building by registered name and by class object, taking
    ``type`` from ``default_args``, and building a ``ManagerMixin``
    subclass.

    Inputs are always constructed *outside* the ``pytest.raises`` contexts
    so that an unexpected error while preparing a config cannot satisfy the
    expected exception by accident; only the ``build_from_cfg`` call itself
    is guarded.
    """
    BACKBONES = Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # test `cfg` parameter
    # `cfg` should be a dict, ConfigDict or Config object
    with pytest.raises(
            TypeError,
            match=('cfg should be a dict, ConfigDict or Config, but got '
                   "<class 'str'>")):
        build_from_cfg('ResNet', BACKBONES)

    # `cfg` is a dict, ConfigDict or Config object
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # `cfg` is a dict but it does not contain the key "type"
    cfg = cfg_type(dict(depth=50, stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES)

    # cfg['type'] should be a str or class
    cfg = cfg_type(dict(type=1000))
    with pytest.raises(
            TypeError,
            match="type must be a str or valid type, but got <class 'int'>"):
        build_from_cfg(cfg, BACKBONES)

    # `type` given as a registered name with explicit keyword arguments
    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # `type` given as the class object itself
    cfg = cfg_type(dict(type=ResNet, depth=50))
    model = build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # non-registered class
    cfg = cfg_type(dict(type='VGG'))
    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
        build_from_cfg(cfg, BACKBONES)

    # `cfg` contains unexpected arguments
    cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES)

    # test `default_args` parameter
    cfg = cfg_type(dict(type='ResNet', depth=50))
    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    # default_args must be a dict or None
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(TypeError):
        build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg or default_args should contain the key "type"
    cfg = cfg_type(dict(depth=50))
    default_args = cfg_type(dict(stages=4))
    with pytest.raises(KeyError, match='must contain the key "type"'):
        build_from_cfg(cfg, BACKBONES, default_args=default_args)

    # "type" defined using default_args
    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = cfg_type(dict(depth=50))
    model = build_from_cfg(
        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # test `registry` parameter
    # incorrect registry type
    cfg = cfg_type(dict(type='ResNet', depth=50))
    with pytest.raises(
            TypeError,
            match=('registry must be a mmengine.Registry object, but got '
                   "<class 'str'>")):
        build_from_cfg(cfg, 'BACKBONES')

    # Building a `ManagerMixin` subclass must make the built object the
    # class's current instance.
    VISUALIZER = Registry('visualizer')

    @VISUALIZER.register_module()
    class Visualizer(ManagerMixin):

        def __init__(self, name):
            super().__init__(name)

    # No instance exists yet for this freshly defined class.
    with pytest.raises(RuntimeError):
        Visualizer.get_current_instance()

    cfg = cfg_type(dict(type='Visualizer', name='visualizer'))
    visualizer = build_from_cfg(cfg, VISUALIZER)
    # The original only called `get_current_instance()`; also verify it
    # returns the very object that was built, under the requested name.
    assert Visualizer.get_current_instance() is visualizer
    assert visualizer.instance_name == 'visualizer'
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_model_from_cfg():
    """Exercise ``build_model_from_cfg`` used as a registry ``build_func``.

    A single config builds one module. A list of configs builds an
    ``nn.Sequential`` whose items follow the list order. A child registry
    inherits ``build_func`` from its parent unless one is given explicitly.
    """
    import torch.nn as nn

    BACKBONES = Registry('backbone', build_func=build_model_from_cfg)

    class _IdentityBackbone(nn.Module):
        # Shared body for the two registered backbones below. Registration
        # uses each subclass's own name, so they remain distinct types.

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    @BACKBONES.register_module()
    class ResNet(_IdentityBackbone):
        pass

    @BACKBONES.register_module()
    class ResNeXt(_IdentityBackbone):
        pass

    def make_cfgs():
        # Fresh dicts on every call so no build can observe another's input.
        return [
            dict(type='ResNet', depth=50),
            dict(type='ResNeXt', depth=50, stages=3),
        ]

    # (expected class, expected `stages`) for each entry of `make_cfgs()`.
    expectations = [(ResNet, 4), (ResNeXt, 3)]

    # One config in, one module out.
    for single_cfg, (expected_cls, expected_stages) in zip(
            make_cfgs(), expectations):
        net = BACKBONES.build(single_cfg)
        assert isinstance(net, expected_cls)
        assert net.depth == 50 and net.stages == expected_stages

    # A list of configs becomes an `nn.Sequential` in the same order.
    stacked = BACKBONES.build(make_cfgs())
    assert isinstance(stacked, nn.Sequential)
    for idx, (expected_cls, expected_stages) in enumerate(expectations):
        assert isinstance(stacked[idx], expected_cls)
        assert stacked[idx].depth == 50
        assert stacked[idx].stages == expected_stages

    # test inherit `build_func` from parent
    inheriting_child = Registry('models', parent=BACKBONES, scope='new')
    assert inheriting_child.build_func is build_model_from_cfg

    # test specify `build_func`
    def passthrough_build(cfg):
        return cfg

    overriding_child = Registry(
        'models', parent=BACKBONES, build_func=passthrough_build)
    assert overriding_child.build_func is passthrough_build
@pytest.mark.skipif(not is_installed('torch'), reason='tests requires torch')
def test_build_scheduler_from_cfg():
    """Build a ``LinearParamScheduler`` through ``PARAM_SCHEDULERS``.

    Without conversion the scheduler keeps ``begin=0`` / ``end=100``.
    With ``convert_to_iter_based=True`` and ``epoch_length=10`` the test
    expects ``end`` to become ``1000``.
    """
    import torch.nn as nn
    from torch.optim import SGD

    conv = nn.Conv2d(1, 1, 1)
    optimizer = SGD(conv.parameters(), lr=0.1)

    def linear_lr_cfg(**extra):
        # A new dict per call; `extra` adds conversion-related keys.
        return dict(
            type='LinearParamScheduler',
            optimizer=optimizer,
            param_name='lr',
            begin=0,
            end=100,
            **extra)

    epoch_based = PARAM_SCHEDULERS.build(linear_lr_cfg())
    assert (epoch_based.begin, epoch_based.end) == (0, 100)

    iter_based = PARAM_SCHEDULERS.build(
        linear_lr_cfg(convert_to_iter_based=True, epoch_length=10))
    assert (iter_based.begin, iter_based.end) == (0, 1000)