mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support registering partial functions and more (#595)
* support registering partial functions * Update mmengine/registry/build_functions.py Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update mmengine/registry/registry.py Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Revert unit test and refine * add current logger and set log level --------- Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
parent
f76218a489
commit
b2ad2210b5
@ -104,7 +104,8 @@ def build_from_cfg(
|
|||||||
'can be found at '
|
'can be found at '
|
||||||
'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
|
'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
|
||||||
)
|
)
|
||||||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
# this will include classes, functions, partial functions and more
|
||||||
|
elif callable(obj_type):
|
||||||
obj_cls = obj_type
|
obj_cls = obj_type
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -120,12 +121,20 @@ def build_from_cfg(
|
|||||||
else:
|
else:
|
||||||
obj = obj_cls(**args) # type: ignore
|
obj = obj_cls(**args) # type: ignore
|
||||||
|
|
||||||
print_log(
|
if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
|
||||||
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
or inspect.ismethod(obj_cls)):
|
||||||
'registry, its implementation can be found in '
|
print_log(
|
||||||
f'{obj_cls.__module__}', # type: ignore
|
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
||||||
logger='current',
|
'registry, and its implementation can be found in '
|
||||||
level=logging.DEBUG)
|
f'{obj_cls.__module__}', # type: ignore
|
||||||
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
print_log(
|
||||||
|
'An instance is built from registry, and its constructor '
|
||||||
|
f'is {obj_cls}',
|
||||||
|
logger='current',
|
||||||
|
level=logging.DEBUG)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -487,8 +487,11 @@ class Registry:
|
|||||||
obj_cls = root.get(key)
|
obj_cls = root.get(key)
|
||||||
|
|
||||||
if obj_cls is not None:
|
if obj_cls is not None:
|
||||||
|
# For some rare cases (e.g. obj_cls is a partial function), obj_cls
|
||||||
|
# doesn't have `__name__`. Use default value to prevent error
|
||||||
|
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
|
||||||
print_log(
|
print_log(
|
||||||
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
|
f'Get class `{cls_name}` from "{registry_name}"'
|
||||||
f' registry in "{scope_name}"',
|
f' registry in "{scope_name}"',
|
||||||
logger='current',
|
logger='current',
|
||||||
level=logging.DEBUG)
|
level=logging.DEBUG)
|
||||||
@ -565,16 +568,16 @@ class Registry:
|
|||||||
"""Register a module.
|
"""Register a module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module (type): Module class or function to be registered.
|
module (type): Module to be registered. Typically a class or a
|
||||||
|
function, but generally all ``Callable`` are acceptable.
|
||||||
module_name (str or list of str, optional): The module name to be
|
module_name (str or list of str, optional): The module name to be
|
||||||
registered. If not specified, the class name will be used.
|
registered. If not specified, the class name will be used.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
force (bool): Whether to override an existing class with the same
|
force (bool): Whether to override an existing class with the same
|
||||||
name. Defaults to False.
|
name. Defaults to False.
|
||||||
"""
|
"""
|
||||||
if not inspect.isclass(module) and not inspect.isfunction(module):
|
if not callable(module):
|
||||||
raise TypeError('module must be a class or a function, '
|
raise TypeError(f'module must be Callable, but got {type(module)}')
|
||||||
f'but got {type(module)}')
|
|
||||||
|
|
||||||
if module_name is None:
|
if module_name is None:
|
||||||
module_name = module.__name__
|
module_name = module.__name__
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import functools
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -59,23 +60,12 @@ class TestRegistry:
|
|||||||
CATS = Registry('cat')
|
CATS = Registry('cat')
|
||||||
|
|
||||||
@CATS.register_module()
|
@CATS.register_module()
|
||||||
def muchkin():
|
def muchkin(size):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert CATS.get('muchkin') is muchkin
|
assert CATS.get('muchkin') is muchkin
|
||||||
assert 'muchkin' in CATS
|
assert 'muchkin' in CATS
|
||||||
|
|
||||||
# can only decorate a class or a function
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
|
|
||||||
class Demo:
|
|
||||||
|
|
||||||
def some_method(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
method = Demo().some_method
|
|
||||||
CATS.register_module(name='some_method', module=method)
|
|
||||||
|
|
||||||
# test `name` parameter which must be either of None, a string or a
|
# test `name` parameter which must be either of None, a string or a
|
||||||
# sequence of string
|
# sequence of string
|
||||||
# `name` is None
|
# `name` is None
|
||||||
@ -146,7 +136,7 @@ class TestRegistry:
|
|||||||
# decorator, which must be a class
|
# decorator, which must be a class
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
TypeError,
|
TypeError,
|
||||||
match='module must be a class or a function,'
|
match='module must be Callable,'
|
||||||
" but got <class 'str'>"):
|
" but got <class 'str'>"):
|
||||||
CATS.register_module(module='string')
|
CATS.register_module(module='string')
|
||||||
|
|
||||||
@ -166,6 +156,17 @@ class TestRegistry:
|
|||||||
assert CATS.get('Sphynx3') is SphynxCat
|
assert CATS.get('Sphynx3') is SphynxCat
|
||||||
assert len(CATS) == 9
|
assert len(CATS) == 9
|
||||||
|
|
||||||
|
# partial functions can be registered
|
||||||
|
muchkin0 = functools.partial(muchkin, size=0)
|
||||||
|
CATS.register_module('muchkin0', False, muchkin0)
|
||||||
|
# lambda functions can be registered
|
||||||
|
CATS.register_module(name='unknown cat', module=lambda: 'unknown')
|
||||||
|
|
||||||
|
assert CATS.get('muchkin0') is muchkin0
|
||||||
|
assert 'unknown cat' in CATS
|
||||||
|
assert 'muchkin0' in CATS
|
||||||
|
assert len(CATS) == 11
|
||||||
|
|
||||||
def _build_registry(self):
|
def _build_registry(self):
|
||||||
"""A helper function to build a Hierarchical Registry."""
|
"""A helper function to build a Hierarchical Registry."""
|
||||||
# Hierarchical Registry
|
# Hierarchical Registry
|
||||||
@ -227,12 +228,21 @@ class TestRegistry:
|
|||||||
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
||||||
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
|
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
|
||||||
|
|
||||||
|
@DOGS.register_module()
|
||||||
|
def bark(word, times):
|
||||||
|
return [word] * times
|
||||||
|
|
||||||
|
dog_bark = functools.partial(bark, 'woof')
|
||||||
|
DOGS.register_module('dog_bark', False, dog_bark)
|
||||||
|
|
||||||
@DOGS.register_module()
|
@DOGS.register_module()
|
||||||
class GoldenRetriever:
|
class GoldenRetriever:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert len(DOGS) == 1
|
assert len(DOGS) == 3
|
||||||
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
||||||
|
assert DOGS.get('bark') is bark
|
||||||
|
assert DOGS.get('dog_bark') is dog_bark
|
||||||
|
|
||||||
@HOUNDS.register_module()
|
@HOUNDS.register_module()
|
||||||
class BloodHound:
|
class BloodHound:
|
||||||
@ -249,6 +259,8 @@ class TestRegistry:
|
|||||||
# If the key is not found in the current registry, then look for its
|
# If the key is not found in the current registry, then look for its
|
||||||
# parent
|
# parent
|
||||||
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
||||||
|
assert HOUNDS.get('bark') is bark
|
||||||
|
assert HOUNDS.get('dog_bark') is dog_bark
|
||||||
|
|
||||||
@LITTLE_HOUNDS.register_module()
|
@LITTLE_HOUNDS.register_module()
|
||||||
class Dachshund:
|
class Dachshund:
|
||||||
@ -340,11 +352,14 @@ class TestRegistry:
|
|||||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
||||||
|
|
||||||
@DOGS.register_module()
|
@DOGS.register_module()
|
||||||
def bark(times=1):
|
def bark(word, times):
|
||||||
return ' '.join(['woof'] * times)
|
return ' '.join([word] * times)
|
||||||
|
|
||||||
bark_cfg = cfg_type(dict(type='bark', times=3))
|
dog_bark = functools.partial(bark, word='woof')
|
||||||
assert DOGS.build(bark_cfg) == 'woof woof woof'
|
DOGS.register_module('dog_bark', False, dog_bark)
|
||||||
|
|
||||||
|
bark_cfg = cfg_type(dict(type='bark', word='meow', times=3))
|
||||||
|
dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3))
|
||||||
|
|
||||||
@DOGS.register_module()
|
@DOGS.register_module()
|
||||||
class GoldenRetriever:
|
class GoldenRetriever:
|
||||||
@ -352,6 +367,8 @@ class TestRegistry:
|
|||||||
|
|
||||||
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
||||||
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
||||||
|
assert DOGS.build(bark_cfg) == 'meow meow meow'
|
||||||
|
assert DOGS.build(dog_bark_cfg) == 'woof woof woof'
|
||||||
|
|
||||||
@HOUNDS.register_module()
|
@HOUNDS.register_module()
|
||||||
class BloodHound:
|
class BloodHound:
|
||||||
@ -360,6 +377,8 @@ class TestRegistry:
|
|||||||
bh_cfg = cfg_type(dict(type='BloodHound'))
|
bh_cfg = cfg_type(dict(type='BloodHound'))
|
||||||
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
||||||
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
||||||
|
assert HOUNDS.build(bark_cfg) == 'meow meow meow'
|
||||||
|
assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof'
|
||||||
|
|
||||||
@LITTLE_HOUNDS.register_module()
|
@LITTLE_HOUNDS.register_module()
|
||||||
class Dachshund:
|
class Dachshund:
|
||||||
@ -419,6 +438,18 @@ class TestRegistry:
|
|||||||
assert isinstance(dog.friend, YourSamoyed)
|
assert isinstance(dog.friend, YourSamoyed)
|
||||||
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||||
|
|
||||||
|
# build an instance by lambda or partial function.
|
||||||
|
lambda_dog = lambda name: name # noqa: E731
|
||||||
|
DOGS.register_module(name='lambda_dog', module=lambda_dog)
|
||||||
|
lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown'))
|
||||||
|
assert DOGS.build(lambda_cfg) == 'unknown'
|
||||||
|
|
||||||
|
DOGS.register_module(
|
||||||
|
name='patial dog',
|
||||||
|
module=functools.partial(lambda_dog, name='patial'))
|
||||||
|
unknown_cfg = cfg_type(dict(type='patial dog'))
|
||||||
|
assert DOGS.build(unknown_cfg) == 'patial'
|
||||||
|
|
||||||
def test_switch_scope_and_registry(self):
|
def test_switch_scope_and_registry(self):
|
||||||
DOGS = Registry('dogs')
|
DOGS = Registry('dogs')
|
||||||
HOUNDS = Registry('hounds', scope='hound', parent=DOGS)
|
HOUNDS = Registry('hounds', scope='hound', parent=DOGS)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user