662 lines
23 KiB
Python
662 lines
23 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from mmengine.config import Config, ConfigDict # type: ignore
|
|
from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
|
|
build_model_from_cfg)
|
|
from mmengine.utils import ManagerMixin
|
|
|
|
|
|
class TestRegistry:
|
|
|
|
def test_init(self):
|
|
CATS = Registry('cat')
|
|
assert CATS.name == 'cat'
|
|
assert CATS.module_dict == {}
|
|
assert CATS.build_func is build_from_cfg
|
|
assert len(CATS) == 0
|
|
|
|
# test `build_func` parameter
|
|
def build_func(cfg, registry, default_args):
|
|
pass
|
|
|
|
CATS = Registry('cat', build_func=build_func)
|
|
assert CATS.build_func is build_func
|
|
|
|
# test `parent` parameter
|
|
# `parent` is either None or a `Registry` instance
|
|
with pytest.raises(AssertionError):
|
|
CATS = Registry('little_cat', parent='cat', scope='little_cat')
|
|
|
|
LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
|
|
assert LITTLECATS.parent is CATS
|
|
assert CATS._children.get('little_cat') is LITTLECATS
|
|
|
|
# test `scope` parameter
|
|
# `scope` is either None or a string
|
|
with pytest.raises(AssertionError):
|
|
CATS = Registry('cat', scope=1)
|
|
|
|
CATS = Registry('cat')
|
|
assert CATS.scope == 'test_registry'
|
|
|
|
CATS = Registry('cat', scope='cat')
|
|
assert CATS.scope == 'cat'
|
|
|
|
def test_split_scope_key(self):
|
|
DOGS = Registry('dogs')
|
|
|
|
scope, key = DOGS.split_scope_key('BloodHound')
|
|
assert scope is None and key == 'BloodHound'
|
|
scope, key = DOGS.split_scope_key('hound.BloodHound')
|
|
assert scope == 'hound' and key == 'BloodHound'
|
|
scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
|
|
assert scope == 'hound' and key == 'little_hound.Dachshund'
|
|
|
|
def test_register_module(self):
|
|
CATS = Registry('cat')
|
|
|
|
@CATS.register_module()
|
|
def muchkin():
|
|
pass
|
|
|
|
assert CATS.get('muchkin') is muchkin
|
|
assert 'muchkin' in CATS
|
|
|
|
# can only decorate a class or a function
|
|
with pytest.raises(TypeError):
|
|
|
|
class Demo:
|
|
|
|
def some_method(self):
|
|
pass
|
|
|
|
method = Demo().some_method
|
|
CATS.register_module(name='some_method', module=method)
|
|
|
|
# test `name` parameter which must be either of None, a string or a
|
|
# sequence of string
|
|
# `name` is None
|
|
@CATS.register_module()
|
|
class BritishShorthair:
|
|
pass
|
|
|
|
assert len(CATS) == 2
|
|
assert CATS.get('BritishShorthair') is BritishShorthair
|
|
|
|
# `name` is a string
|
|
@CATS.register_module(name='Munchkin')
|
|
class Munchkin:
|
|
pass
|
|
|
|
assert len(CATS) == 3
|
|
assert CATS.get('Munchkin') is Munchkin
|
|
assert 'Munchkin' in CATS
|
|
|
|
# `name` is a sequence of string
|
|
@CATS.register_module(name=['Siamese', 'Siamese2'])
|
|
class SiameseCat:
|
|
pass
|
|
|
|
assert CATS.get('Siamese') is SiameseCat
|
|
assert CATS.get('Siamese2') is SiameseCat
|
|
assert len(CATS) == 5
|
|
|
|
# `name` is an invalid type
|
|
with pytest.raises(
|
|
TypeError,
|
|
match=('name must be None, an instance of str, or a sequence '
|
|
"of str, but got <class 'int'>")):
|
|
|
|
@CATS.register_module(name=7474741)
|
|
class SiameseCat:
|
|
pass
|
|
|
|
# test `force` parameter, which must be a boolean
|
|
# force is not a boolean
|
|
with pytest.raises(
|
|
TypeError,
|
|
match="force must be a boolean, but got <class 'int'>"):
|
|
|
|
@CATS.register_module(force=1)
|
|
class BritishShorthair:
|
|
pass
|
|
|
|
# force=False
|
|
with pytest.raises(
|
|
KeyError,
|
|
match='BritishShorthair is already registered in cat '
|
|
'at test_registry'):
|
|
|
|
@CATS.register_module()
|
|
class BritishShorthair:
|
|
pass
|
|
|
|
# force=True
|
|
@CATS.register_module(force=True)
|
|
class BritishShorthair:
|
|
pass
|
|
|
|
assert len(CATS) == 5
|
|
|
|
# test `module` parameter, which is either None or a class
|
|
# when the `register_module`` is called as a method rather than a
|
|
# decorator, which must be a class
|
|
with pytest.raises(
|
|
TypeError,
|
|
match='module must be a class or a function,'
|
|
" but got <class 'str'>"):
|
|
CATS.register_module(module='string')
|
|
|
|
class SphynxCat:
|
|
pass
|
|
|
|
CATS.register_module(module=SphynxCat)
|
|
assert CATS.get('SphynxCat') is SphynxCat
|
|
assert len(CATS) == 6
|
|
|
|
CATS.register_module(name='Sphynx1', module=SphynxCat)
|
|
assert CATS.get('Sphynx1') is SphynxCat
|
|
assert len(CATS) == 7
|
|
|
|
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
|
|
assert CATS.get('Sphynx2') is SphynxCat
|
|
assert CATS.get('Sphynx3') is SphynxCat
|
|
assert len(CATS) == 9
|
|
|
|
def _build_registry(self):
|
|
"""A helper function to build a Hierarchical Registry."""
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# _______|_______
|
|
# | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed)
|
|
# _______|_______ |
|
|
# | | |
|
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
|
# (little_hound) (mid_hound) (little_samoyed)
|
|
registries = []
|
|
DOGS = Registry('dogs')
|
|
registries.append(DOGS)
|
|
HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
|
|
registries.append(HOUNDS)
|
|
LITTLE_HOUNDS = Registry(
|
|
'little hounds', parent=HOUNDS, scope='little_hound')
|
|
registries.append(LITTLE_HOUNDS)
|
|
MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
|
|
registries.append(MID_HOUNDS)
|
|
SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
|
|
registries.append(SAMOYEDS)
|
|
LITTLE_SAMOYEDS = Registry(
|
|
'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
|
|
registries.append(LITTLE_SAMOYEDS)
|
|
|
|
return registries
|
|
|
|
def test__get_root_registry(self):
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# _______|_______
|
|
# | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed)
|
|
# _______|_______ |
|
|
# | | |
|
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
|
# (little_hound) (mid_hound) (little_samoyed)
|
|
registries = self._build_registry()
|
|
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
|
|
|
assert DOGS._get_root_registry() is DOGS
|
|
assert HOUNDS._get_root_registry() is DOGS
|
|
assert LITTLE_HOUNDS._get_root_registry() is DOGS
|
|
assert MID_HOUNDS._get_root_registry() is DOGS
|
|
|
|
def test_get(self):
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# _______|_______
|
|
# | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed)
|
|
# _______|_______ |
|
|
# | | |
|
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
|
# (little_hound) (mid_hound) (little_samoyed)
|
|
registries = self._build_registry()
|
|
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
|
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
|
|
|
|
@DOGS.register_module()
|
|
class GoldenRetriever:
|
|
pass
|
|
|
|
assert len(DOGS) == 1
|
|
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
|
|
|
@HOUNDS.register_module()
|
|
class BloodHound:
|
|
pass
|
|
|
|
assert len(HOUNDS) == 1
|
|
# get key from current registry
|
|
assert HOUNDS.get('BloodHound') is BloodHound
|
|
# get key from its children
|
|
assert DOGS.get('hound.BloodHound') is BloodHound
|
|
# get key from current registry
|
|
assert HOUNDS.get('hound.BloodHound') is BloodHound
|
|
|
|
# If the key is not found in the current registry, then look for its
|
|
# parent
|
|
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
|
|
|
@LITTLE_HOUNDS.register_module()
|
|
class Dachshund:
|
|
pass
|
|
|
|
assert len(LITTLE_HOUNDS) == 1
|
|
# get key from current registry
|
|
assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
|
|
# get key from its parent
|
|
assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
|
|
# get key from its children
|
|
assert HOUNDS.get('little_hound.Dachshund') is Dachshund
|
|
# get key from its descendants
|
|
assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
|
|
|
|
# If the key is not found in the current registry, then look for its
|
|
# parent
|
|
assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
|
|
assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
|
|
|
@MID_HOUNDS.register_module()
|
|
class Beagle:
|
|
pass
|
|
|
|
# get key from its sibling registries
|
|
assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
|
|
|
|
@SAMOYEDS.register_module()
|
|
class PedigreeSamoyed:
|
|
pass
|
|
|
|
assert len(SAMOYEDS) == 1
|
|
# get key from its uncle
|
|
assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed
|
|
|
|
@LITTLE_SAMOYEDS.register_module()
|
|
class LittlePedigreeSamoyed:
|
|
pass
|
|
|
|
# get key from its cousin
|
|
assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
|
|
) is LittlePedigreeSamoyed
|
|
|
|
# get key from its nephews
|
|
assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
|
|
) is LittlePedigreeSamoyed
|
|
|
|
# invalid keys
|
|
# GoldenRetrieverererer can not be found at LITTLE_HOUNDS modules
|
|
assert LITTLE_HOUNDS.get('GoldenRetrieverererer') is None
|
|
# samoyedddd is not a child of DOGS
|
|
assert DOGS.get('samoyedddd.PedigreeSamoyed') is None
|
|
# samoyed is a child of DOGS but LittlePedigreeSamoyed can not be found
|
|
# at SAMOYEDS modules
|
|
assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None
|
|
assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None
|
|
|
|
def test__search_child(self):
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# _______|_______
|
|
# | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed)
|
|
# _______|_______ |
|
|
# | | |
|
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
|
# (little_hound) (mid_hound) (little_samoyed)
|
|
registries = self._build_registry()
|
|
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
|
|
|
assert DOGS._search_child('hound') is HOUNDS
|
|
assert DOGS._search_child('not a child') is None
|
|
assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
|
|
assert LITTLE_HOUNDS._search_child('hound') is None
|
|
assert LITTLE_HOUNDS._search_child('mid_hound') is None
|
|
|
|
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
|
def test_build(self, cfg_type):
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# _______|_______
|
|
# | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed)
|
|
# _______|_______ |
|
|
# | | |
|
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
|
# (little_hound) (mid_hound) (little_samoyed)
|
|
registries = self._build_registry()
|
|
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
|
|
|
@DOGS.register_module()
|
|
class GoldenRetriever:
|
|
pass
|
|
|
|
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
|
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
|
|
|
@HOUNDS.register_module()
|
|
class BloodHound:
|
|
pass
|
|
|
|
bh_cfg = cfg_type(dict(type='BloodHound'))
|
|
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
|
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
|
|
|
@LITTLE_HOUNDS.register_module()
|
|
class Dachshund:
|
|
pass
|
|
|
|
d_cfg = cfg_type(dict(type='Dachshund'))
|
|
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
|
|
|
|
@MID_HOUNDS.register_module()
|
|
class Beagle:
|
|
pass
|
|
|
|
b_cfg = cfg_type(dict(type='Beagle'))
|
|
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
|
|
|
|
# test `default_scope`
|
|
# switch the current registry to another registry
|
|
DefaultScope.get_instance(
|
|
f'test-{time.time()}', scope_name='mid_hound')
|
|
dog = LITTLE_HOUNDS.build(b_cfg)
|
|
assert isinstance(dog, Beagle)
|
|
|
|
# `default_scope` can not be found
|
|
DefaultScope.get_instance(
|
|
f'test2-{time.time()}', scope_name='scope-not-found')
|
|
dog = MID_HOUNDS.build(b_cfg)
|
|
assert isinstance(dog, Beagle)
|
|
|
|
# test overwrite default scope with `_scope_`
|
|
@SAMOYEDS.register_module()
|
|
class MySamoyed:
|
|
|
|
def __init__(self, friend):
|
|
self.friend = DOGS.build(friend)
|
|
|
|
@SAMOYEDS.register_module()
|
|
class YourSamoyed:
|
|
pass
|
|
|
|
s_cfg = cfg_type(
|
|
dict(
|
|
_scope_='samoyed',
|
|
type='MySamoyed',
|
|
friend=dict(type='hound.BloodHound')))
|
|
dog = DOGS.build(s_cfg)
|
|
assert isinstance(dog, MySamoyed)
|
|
assert isinstance(dog.friend, BloodHound)
|
|
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
|
|
|
s_cfg = cfg_type(
|
|
dict(
|
|
_scope_='samoyed',
|
|
type='MySamoyed',
|
|
friend=dict(type='YourSamoyed')))
|
|
dog = DOGS.build(s_cfg)
|
|
assert isinstance(dog, MySamoyed)
|
|
assert isinstance(dog.friend, YourSamoyed)
|
|
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
|
|
|
def test_get_registry_by_scope(self):
|
|
DOGS = Registry('dogs')
|
|
HOUNDS = Registry('hounds', scope='hound', parent=DOGS)
|
|
SAMOYEDS = Registry('samoyeds', scope='samoyed', parent=DOGS)
|
|
CHIHUAHUA = Registry('chihuahuas', scope='chihuahua', parent=DOGS)
|
|
|
|
# Hierarchical Registry
|
|
# DOGS
|
|
# ___________________|___________________
|
|
# | | |
|
|
# HOUNDS (hound) SAMOYEDS (samoyed) CHIHUAHUA (chihuahua)
|
|
|
|
DefaultScope.get_instance(
|
|
f'scope_{time.time()}', scope_name='chihuahua')
|
|
assert DefaultScope.get_current_instance().scope_name == 'chihuahua'
|
|
|
|
# Test switch scope and get target registry.
|
|
with CHIHUAHUA.switch_scope_and_registry(scope='hound') as \
|
|
registry:
|
|
assert DefaultScope.get_current_instance().scope_name == 'hound'
|
|
assert id(registry) == id(HOUNDS)
|
|
|
|
# Test nested-ly switch scope.
|
|
with CHIHUAHUA.switch_scope_and_registry(scope='samoyed') as \
|
|
samoyed_registry:
|
|
assert DefaultScope.get_current_instance().scope_name == 'samoyed'
|
|
assert id(samoyed_registry) == id(SAMOYEDS)
|
|
|
|
with CHIHUAHUA.switch_scope_and_registry(scope='hound') as \
|
|
hound_registry:
|
|
assert DefaultScope.get_current_instance().scope_name == \
|
|
'hound'
|
|
assert id(hound_registry) == id(HOUNDS)
|
|
|
|
# Test switch to original scope
|
|
assert DefaultScope.get_current_instance().scope_name == 'chihuahua'
|
|
|
|
# Test get an unknown registry.
|
|
with CHIHUAHUA.switch_scope_and_registry(scope='unknown') as \
|
|
registry:
|
|
assert id(registry) == id(CHIHUAHUA)
|
|
assert DefaultScope.get_current_instance().scope_name == 'unknown'
|
|
|
|
def test_repr(self):
|
|
CATS = Registry('cat')
|
|
|
|
@CATS.register_module()
|
|
class BritishShorthair:
|
|
pass
|
|
|
|
@CATS.register_module()
|
|
class Munchkin:
|
|
pass
|
|
|
|
repr_str = 'Registry(name=cat, items={'
|
|
repr_str += (
|
|
"'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
|
|
"<locals>.BritishShorthair'>, ")
|
|
repr_str += (
|
|
"'Munchkin': <class 'test_registry.TestRegistry.test_repr."
|
|
"<locals>.Munchkin'>")
|
|
repr_str += '})'
|
|
assert repr(CATS) == repr_str
|
|
|
|
|
|
@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
|
|
def test_build_from_cfg(cfg_type):
|
|
BACKBONES = Registry('backbone')
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNet:
|
|
|
|
def __init__(self, depth, stages=4):
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNeXt:
|
|
|
|
def __init__(self, depth, stages=4):
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
# test `cfg` parameter
|
|
# `cfg` should be a dict, ConfigDict or Config object
|
|
with pytest.raises(
|
|
TypeError,
|
|
match=('cfg should be a dict, ConfigDict or Config, but got '
|
|
"<class 'str'>")):
|
|
cfg = 'ResNet'
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
|
|
# `cfg` is a dict, ConfigDict or Config object
|
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
# `cfg` is a dict but it does not contain the key "type"
|
|
with pytest.raises(KeyError, match='must contain the key "type"'):
|
|
cfg = dict(depth=50, stages=4)
|
|
cfg = cfg_type(cfg)
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
|
|
# cfg['type'] should be a str or class
|
|
with pytest.raises(
|
|
TypeError,
|
|
match="type must be a str or valid type, but got <class 'int'>"):
|
|
cfg = dict(type=1000)
|
|
cfg = cfg_type(cfg)
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
|
|
cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
assert isinstance(model, ResNeXt)
|
|
assert model.depth == 50 and model.stages == 3
|
|
|
|
cfg = cfg_type(dict(type=ResNet, depth=50))
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
# non-registered class
|
|
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
|
cfg = cfg_type(dict(type='VGG'))
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
|
|
# `cfg` contains unexpected arguments
|
|
with pytest.raises(TypeError):
|
|
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
|
|
model = build_from_cfg(cfg, BACKBONES)
|
|
|
|
# test `default_args` parameter
|
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
|
model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 3
|
|
|
|
# default_args must be a dict or None
|
|
with pytest.raises(TypeError):
|
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
|
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
|
|
|
# cfg or default_args should contain the key "type"
|
|
with pytest.raises(KeyError, match='must contain the key "type"'):
|
|
cfg = cfg_type(dict(depth=50))
|
|
model = build_from_cfg(
|
|
cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
|
|
|
|
# "type" defined using default_args
|
|
cfg = cfg_type(dict(depth=50))
|
|
model = build_from_cfg(
|
|
cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
cfg = cfg_type(dict(depth=50))
|
|
model = build_from_cfg(
|
|
cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
# test `registry` parameter
|
|
# incorrect registry type
|
|
with pytest.raises(
|
|
TypeError,
|
|
match=('registry must be a mmengine.Registry object, but got '
|
|
"<class 'str'>")):
|
|
cfg = cfg_type(dict(type='ResNet', depth=50))
|
|
model = build_from_cfg(cfg, 'BACKBONES')
|
|
|
|
VISUALIZER = Registry('visualizer')
|
|
|
|
@VISUALIZER.register_module()
|
|
class Visualizer(ManagerMixin):
|
|
|
|
def __init__(self, name):
|
|
super().__init__(name)
|
|
|
|
with pytest.raises(RuntimeError):
|
|
Visualizer.get_current_instance()
|
|
cfg = dict(type='Visualizer', name='visualizer')
|
|
build_from_cfg(cfg, VISUALIZER)
|
|
Visualizer.get_current_instance()
|
|
|
|
|
|
def test_build_model_from_cfg():
|
|
try:
|
|
import torch.nn as nn
|
|
except ImportError:
|
|
pytest.skip('require torch')
|
|
|
|
BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNet(nn.Module):
|
|
|
|
def __init__(self, depth, stages=4):
|
|
super().__init__()
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNeXt(nn.Module):
|
|
|
|
def __init__(self, depth, stages=4):
|
|
super().__init__()
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
cfg = dict(type='ResNet', depth=50)
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, ResNeXt)
|
|
assert model.depth == 50 and model.stages == 3
|
|
|
|
cfg = [
|
|
dict(type='ResNet', depth=50),
|
|
dict(type='ResNeXt', depth=50, stages=3)
|
|
]
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, nn.Sequential)
|
|
assert isinstance(model[0], ResNet)
|
|
assert model[0].depth == 50 and model[0].stages == 4
|
|
assert isinstance(model[1], ResNeXt)
|
|
assert model[1].depth == 50 and model[1].stages == 3
|
|
|
|
# test inherit `build_func` from parent
|
|
NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
|
|
assert NEW_MODELS.build_func is build_model_from_cfg
|
|
|
|
# test specify `build_func`
|
|
def pseudo_build(cfg):
|
|
return cfg
|
|
|
|
NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
|
|
assert NEW_MODELS.build_func is pseudo_build
|