mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Add the unittest of registry (#5)
* Add the unittest of registry * improve the format * fix the test_repr * refactor the test_init * add the copyright for tests directory * add more unit tests * fix error * add unit tests for _search_child and _get_root_registry * build_from_cfg supports dict, ConfictDict and Config * improve docstring
This commit is contained in:
parent
27dd617532
commit
6ba7ac3a8e
@ -48,7 +48,7 @@ repos:
|
||||
rev: v0.2.0
|
||||
hooks:
|
||||
- id: check-copyright
|
||||
args: ["mmengine"]
|
||||
args: ["mmengine", "tests"]
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.812
|
||||
hooks:
|
||||
|
478
tests/test_registry/test_registry.py
Normal file
478
tests/test_registry/test_registry.py
Normal file
@ -0,0 +1,478 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
|
||||
from mmengine import Config, ConfigDict, Registry, build_from_cfg
|
||||
|
||||
|
||||
class TestRegistry:
|
||||
|
||||
def test_init(self):
|
||||
CATS = Registry('cat')
|
||||
assert CATS.name == 'cat'
|
||||
assert CATS.module_dict == {}
|
||||
assert CATS.build_func is build_from_cfg
|
||||
assert len(CATS) == 0
|
||||
|
||||
# test `build_func` parameter
|
||||
def build_func(cfg, registry, default_args):
|
||||
pass
|
||||
|
||||
CATS = Registry('cat', build_func=build_func)
|
||||
assert CATS.build_func is build_func
|
||||
|
||||
# test `parent` parameter
|
||||
# `parent` is either None or a `Registry` instance
|
||||
with pytest.raises(AssertionError):
|
||||
CATS = Registry('little_cat', parent='cat', scope='little_cat')
|
||||
|
||||
LITTLECATS = Registry('little_cat', parent=CATS, scope='little_cat')
|
||||
assert LITTLECATS.parent is CATS
|
||||
assert CATS._children.get('little_cat') is LITTLECATS
|
||||
|
||||
# test `scope` parameter
|
||||
# `scope` is either None or a string
|
||||
with pytest.raises(AssertionError):
|
||||
CATS = Registry('cat', scope=1)
|
||||
|
||||
CATS = Registry('cat')
|
||||
assert CATS.scope == 'test_registry'
|
||||
|
||||
CATS = Registry('cat', scope='cat')
|
||||
assert CATS.scope == 'cat'
|
||||
|
||||
def test_split_scope_key(self):
|
||||
DOGS = Registry('dogs')
|
||||
|
||||
scope, key = DOGS.split_scope_key('BloodHound')
|
||||
assert scope is None and key == 'BloodHound'
|
||||
scope, key = DOGS.split_scope_key('hound.BloodHound')
|
||||
assert scope == 'hound' and key == 'BloodHound'
|
||||
scope, key = DOGS.split_scope_key('hound.little_hound.Dachshund')
|
||||
assert scope == 'hound' and key == 'little_hound.Dachshund'
|
||||
|
||||
def test_register_module(self):
|
||||
CATS = Registry('cat')
|
||||
|
||||
# can only decorate a class
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
@CATS.register_module()
|
||||
def some_method():
|
||||
pass
|
||||
|
||||
# test `name` parameter which must be either of None, a string or a
|
||||
# sequence of string
|
||||
# `name` is None
|
||||
@CATS.register_module()
|
||||
class BritishShorthair:
|
||||
pass
|
||||
|
||||
assert len(CATS) == 1
|
||||
assert CATS.get('BritishShorthair') is BritishShorthair
|
||||
|
||||
# `name` is a string
|
||||
@CATS.register_module(name='Munchkin')
|
||||
class Munchkin:
|
||||
pass
|
||||
|
||||
assert len(CATS) == 2
|
||||
assert CATS.get('Munchkin') is Munchkin
|
||||
assert 'Munchkin' in CATS
|
||||
|
||||
# `name` is a sequence of string
|
||||
@CATS.register_module(name=['Siamese', 'Siamese2'])
|
||||
class SiameseCat:
|
||||
pass
|
||||
|
||||
assert CATS.get('Siamese') is SiameseCat
|
||||
assert CATS.get('Siamese2') is SiameseCat
|
||||
assert len(CATS) == 4
|
||||
|
||||
# `name` is an invalid type
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('name must be either of None, an instance of str or a '
|
||||
"sequence of str, but got <class 'int'>")):
|
||||
|
||||
@CATS.register_module(name=7474741)
|
||||
class SiameseCat:
|
||||
pass
|
||||
|
||||
# test `force` parameter, which must be a boolean
|
||||
# force is not a boolean
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="force must be a boolean, but got <class 'int'>"):
|
||||
|
||||
@CATS.register_module(force=1)
|
||||
class BritishShorthair:
|
||||
pass
|
||||
|
||||
# force=False
|
||||
with pytest.raises(
|
||||
KeyError,
|
||||
match='BritishShorthair is already registered in cat'):
|
||||
|
||||
@CATS.register_module()
|
||||
class BritishShorthair:
|
||||
pass
|
||||
|
||||
# force=True
|
||||
@CATS.register_module(force=True)
|
||||
class BritishShorthair:
|
||||
pass
|
||||
|
||||
assert len(CATS) == 4
|
||||
|
||||
# test `module` parameter, which is either None or a class
|
||||
# when the `register_module`` is called as a method rather than a
|
||||
# decorator, which must be a class
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="module must be a class, but got <class 'str'>"):
|
||||
CATS.register_module(module='string')
|
||||
|
||||
class SphynxCat:
|
||||
pass
|
||||
|
||||
CATS.register_module(module=SphynxCat)
|
||||
assert CATS.get('SphynxCat') is SphynxCat
|
||||
assert len(CATS) == 5
|
||||
|
||||
CATS.register_module(name='Sphynx1', module=SphynxCat)
|
||||
assert CATS.get('Sphynx1') is SphynxCat
|
||||
assert len(CATS) == 6
|
||||
|
||||
CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
|
||||
assert CATS.get('Sphynx2') is SphynxCat
|
||||
assert CATS.get('Sphynx3') is SphynxCat
|
||||
assert len(CATS) == 8
|
||||
|
||||
def _build_registry(self):
|
||||
"""A helper function to build a hierarchy registry."""
|
||||
# Hierarchy Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
# HOUNDS (hound) SAMOYEDS (samoyed)
|
||||
# _______|_______ |
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
registries = []
|
||||
DOGS = Registry('dogs')
|
||||
registries.append(DOGS)
|
||||
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
|
||||
registries.append(HOUNDS)
|
||||
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
|
||||
registries.append(LITTLE_HOUNDS)
|
||||
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
|
||||
registries.append(MID_HOUNDS)
|
||||
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
|
||||
registries.append(SAMOYEDS)
|
||||
LITTLE_SAMOYEDS = Registry(
|
||||
'dogs', parent=SAMOYEDS, scope='little_samoyed')
|
||||
registries.append(LITTLE_SAMOYEDS)
|
||||
|
||||
return registries
|
||||
|
||||
def test_get_root_registry(self):
|
||||
# Hierarchy Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
# HOUNDS (hound) SAMOYEDS (samoyed)
|
||||
# _______|_______ |
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
||||
|
||||
assert DOGS._get_root_registry() is DOGS
|
||||
assert HOUNDS._get_root_registry() is DOGS
|
||||
assert LITTLE_HOUNDS._get_root_registry() is DOGS
|
||||
assert MID_HOUNDS._get_root_registry() is DOGS
|
||||
|
||||
def test_get(self):
|
||||
# Hierarchy Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
# HOUNDS (hound) SAMOYEDS (samoyed)
|
||||
# _______|_______ |
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
registries = self._build_registry()
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
||||
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
pass
|
||||
|
||||
assert len(DOGS) == 1
|
||||
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
pass
|
||||
|
||||
assert len(HOUNDS) == 1
|
||||
# get key from current registry
|
||||
assert HOUNDS.get('BloodHound') is BloodHound
|
||||
# get key from its children
|
||||
assert DOGS.get('hound.BloodHound') is BloodHound
|
||||
# get key from current registry
|
||||
assert HOUNDS.get('hound.BloodHound') is BloodHound
|
||||
|
||||
# If the key is not found in the current registry, then look for its
|
||||
# parent
|
||||
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
||||
|
||||
@LITTLE_HOUNDS.register_module()
|
||||
class Dachshund:
|
||||
pass
|
||||
|
||||
assert len(LITTLE_HOUNDS) == 1
|
||||
# get key from current registry
|
||||
assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
|
||||
# get key from its parent
|
||||
assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
|
||||
# get key from its children
|
||||
assert HOUNDS.get('little_hound.Dachshund') is Dachshund
|
||||
# get key from its descendants
|
||||
assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
|
||||
|
||||
# If the key is not found in the current registry, then look for its
|
||||
# parent
|
||||
assert LITTLE_HOUNDS.get('BloodHound') is BloodHound
|
||||
assert LITTLE_HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
||||
|
||||
@MID_HOUNDS.register_module()
|
||||
class Beagle:
|
||||
pass
|
||||
|
||||
# get key from its sibling registries
|
||||
assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
|
||||
|
||||
@SAMOYEDS.register_module()
|
||||
class PedigreeSamoyed:
|
||||
pass
|
||||
|
||||
assert len(SAMOYEDS) == 1
|
||||
# get key from its uncle
|
||||
assert LITTLE_HOUNDS.get('samoyed.PedigreeSamoyed') is PedigreeSamoyed
|
||||
|
||||
@LITTLE_SAMOYEDS.register_module()
|
||||
class LittlePedigreeSamoyed:
|
||||
pass
|
||||
|
||||
# get key from its cousin
|
||||
assert LITTLE_HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
|
||||
) is LittlePedigreeSamoyed
|
||||
|
||||
# get key from its nephews
|
||||
assert HOUNDS.get('samoyed.little_samoyed.LittlePedigreeSamoyed'
|
||||
) is LittlePedigreeSamoyed
|
||||
|
||||
def test_search_child(self):
|
||||
# Hierarchy Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
# HOUNDS (hound) SAMOYEDS (samoyed)
|
||||
# _______|_______ |
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = self._build_registry()
|
||||
|
||||
assert DOGS._search_child('hound') is HOUNDS
|
||||
assert DOGS._search_child('not a child') is None
|
||||
assert DOGS._search_child('little_hound') is LITTLE_HOUNDS
|
||||
assert LITTLE_HOUNDS._search_child('hound') is None
|
||||
assert LITTLE_HOUNDS._search_child('mid_hound') is None
|
||||
|
||||
def test_build(self):
|
||||
# Hierarchy Registry
|
||||
# DOGS
|
||||
# _______|_______
|
||||
# | |
|
||||
# HOUNDS (hound) SAMOYEDS (samoyed)
|
||||
# _______|_______ |
|
||||
# | | |
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
registries = self._build_registry()
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
pass
|
||||
|
||||
gr_cfg = dict(type='GoldenRetriever')
|
||||
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
pass
|
||||
|
||||
bh_cfg = dict(type='BloodHound')
|
||||
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
||||
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
||||
|
||||
@LITTLE_HOUNDS.register_module()
|
||||
class Dachshund:
|
||||
pass
|
||||
|
||||
d_cfg = dict(type='Dachshund')
|
||||
assert isinstance(LITTLE_HOUNDS.build(d_cfg), Dachshund)
|
||||
|
||||
@MID_HOUNDS.register_module()
|
||||
class Beagle:
|
||||
pass
|
||||
|
||||
b_cfg = dict(type='Beagle')
|
||||
assert isinstance(MID_HOUNDS.build(b_cfg), Beagle)
|
||||
|
||||
# test `default_scope`
|
||||
# `default_scope` is an invalid scope
|
||||
with pytest.raises(KeyError):
|
||||
LITTLE_HOUNDS.build(b_cfg, default_scope='invalid_mid_hound')
|
||||
|
||||
# switch the current registry to another registry
|
||||
dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
|
||||
assert isinstance(dog, Beagle)
|
||||
|
||||
def test_repr(self):
|
||||
CATS = Registry('cat')
|
||||
|
||||
@CATS.register_module()
|
||||
class BritishShorthair:
|
||||
pass
|
||||
|
||||
@CATS.register_module()
|
||||
class Munchkin:
|
||||
pass
|
||||
|
||||
repr_str = 'Registry(name=cat, items={'
|
||||
repr_str += (
|
||||
"'BritishShorthair': <class 'test_registry.TestRegistry.test_repr."
|
||||
"<locals>.BritishShorthair'>, ")
|
||||
repr_str += (
|
||||
"'Munchkin': <class 'test_registry.TestRegistry.test_repr."
|
||||
"<locals>.Munchkin'>")
|
||||
repr_str += '})'
|
||||
assert repr(CATS) == repr_str
|
||||
|
||||
|
||||
def test_build_from_cfg():
|
||||
BACKBONES = Registry('backbone')
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNet:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ResNeXt:
|
||||
|
||||
def __init__(self, depth, stages=4):
|
||||
self.depth = depth
|
||||
self.stages = stages
|
||||
|
||||
# test `cfg` parameter
|
||||
# `cfg` should be a dict, ConfigDict or Config object
|
||||
with pytest.raises(
|
||||
TypeError, match="cfg must be a dict, but got <class 'str'>"):
|
||||
cfg = 'ResNet'
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` is a dict
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a ConfigDict object
|
||||
cfg = ConfigDict(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a Config object
|
||||
cfg = Config(dict(type='ResNet', depth=50))
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# `cfg` is a dict but it does not contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50, stages=4)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# cfg['type'] should be a str or class
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="type must be a str or valid type, but got <class 'int'>"):
|
||||
cfg = dict(type=1000)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNeXt)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
cfg = dict(type=ResNet, depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# non-registered class
|
||||
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
|
||||
cfg = dict(type='VGG')
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# `cfg` contains unexpected arguments
|
||||
with pytest.raises(TypeError):
|
||||
cfg = dict(type='ResNet', non_existing_arg=50)
|
||||
model = build_from_cfg(cfg, BACKBONES)
|
||||
|
||||
# test `default_args` parameter
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 3
|
||||
|
||||
# default_args must be a dict or None
|
||||
with pytest.raises(TypeError):
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=1)
|
||||
|
||||
# cfg or default_args should contain the key "type"
|
||||
with pytest.raises(KeyError, match='must contain the key "type"'):
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))
|
||||
|
||||
# "type" defined using default_args
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type='ResNet'))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
cfg = dict(depth=50)
|
||||
model = build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
|
||||
assert isinstance(model, ResNet)
|
||||
assert model.depth == 50 and model.stages == 4
|
||||
|
||||
# test `registry` parameter
|
||||
# incorrect registry type
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=('registry must be a mmengine.Registry object, but got '
|
||||
"<class 'str'>")):
|
||||
cfg = dict(type='ResNet', depth=50)
|
||||
model = build_from_cfg(cfg, 'BACKBONES')
|
Loading…
x
Reference in New Issue
Block a user