[Feature] Support registering partial functions and more (#595)
* support registering partial functions * Update mmengine/registry/build_functions.py Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Update mmengine/registry/registry.py Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> * Revert unit test and refine * add current logger and set log level --------- Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>pull/912/head
parent
f76218a489
commit
b2ad2210b5
|
@ -104,7 +104,8 @@ def build_from_cfg(
|
|||
'can be found at '
|
||||
'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
|
||||
)
|
||||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
||||
# this will include classes, functions, partial functions and more
|
||||
elif callable(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
|
@ -120,12 +121,20 @@ def build_from_cfg(
|
|||
else:
|
||||
obj = obj_cls(**args) # type: ignore
|
||||
|
||||
print_log(
|
||||
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
||||
'registry, its implementation can be found in '
|
||||
f'{obj_cls.__module__}', # type: ignore
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
|
||||
or inspect.ismethod(obj_cls)):
|
||||
print_log(
|
||||
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
|
||||
'registry, and its implementation can be found in '
|
||||
f'{obj_cls.__module__}', # type: ignore
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
else:
|
||||
print_log(
|
||||
'An instance is built from registry, and its constructor '
|
||||
f'is {obj_cls}',
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
return obj
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -487,8 +487,11 @@ class Registry:
|
|||
obj_cls = root.get(key)
|
||||
|
||||
if obj_cls is not None:
|
||||
# For some rare cases (e.g. obj_cls is a partial function), obj_cls
|
||||
# doesn't have `__name__`. Use default value to prevent error
|
||||
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
|
||||
print_log(
|
||||
f'Get class `{obj_cls.__name__}` from "{registry_name}"'
|
||||
f'Get class `{cls_name}` from "{registry_name}"'
|
||||
f' registry in "{scope_name}"',
|
||||
logger='current',
|
||||
level=logging.DEBUG)
|
||||
|
@ -565,16 +568,16 @@ class Registry:
|
|||
"""Register a module.
|
||||
|
||||
Args:
|
||||
module (type): Module class or function to be registered.
|
||||
module (type): Module to be registered. Typically a class or a
|
||||
function, but generally all ``Callable`` are acceptable.
|
||||
module_name (str or list of str, optional): The module name to be
|
||||
registered. If not specified, the class name will be used.
|
||||
Defaults to None.
|
||||
force (bool): Whether to override an existing class with the same
|
||||
name. Defaults to False.
|
||||
"""
|
||||
if not inspect.isclass(module) and not inspect.isfunction(module):
|
||||
raise TypeError('module must be a class or a function, '
|
||||
f'but got {type(module)}')
|
||||
if not callable(module):
|
||||
raise TypeError(f'module must be Callable, but got {type(module)}')
|
||||
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
@ -59,23 +60,12 @@ class TestRegistry:
|
|||
CATS = Registry('cat')
|
||||
|
||||
@CATS.register_module()
|
||||
def muchkin():
|
||||
def muchkin(size):
|
||||
pass
|
||||
|
||||
assert CATS.get('muchkin') is muchkin
|
||||
assert 'muchkin' in CATS
|
||||
|
||||
# can only decorate a class or a function
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class Demo:
|
||||
|
||||
def some_method(self):
|
||||
pass
|
||||
|
||||
method = Demo().some_method
|
||||
CATS.register_module(name='some_method', module=method)
|
||||
|
||||
# test `name` parameter which must be either of None, a string or a
|
||||
# sequence of string
|
||||
# `name` is None
|
||||
|
@ -146,7 +136,7 @@ class TestRegistry:
|
|||
# decorator, which must be a class
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match='module must be a class or a function,'
|
||||
match='module must be Callable,'
|
||||
" but got <class 'str'>"):
|
||||
CATS.register_module(module='string')
|
||||
|
||||
|
@ -166,6 +156,17 @@ class TestRegistry:
|
|||
assert CATS.get('Sphynx3') is SphynxCat
|
||||
assert len(CATS) == 9
|
||||
|
||||
# partial functions can be registered
|
||||
muchkin0 = functools.partial(muchkin, size=0)
|
||||
CATS.register_module('muchkin0', False, muchkin0)
|
||||
# lambda functions can be registered
|
||||
CATS.register_module(name='unknown cat', module=lambda: 'unknown')
|
||||
|
||||
assert CATS.get('muchkin0') is muchkin0
|
||||
assert 'unknown cat' in CATS
|
||||
assert 'muchkin0' in CATS
|
||||
assert len(CATS) == 11
|
||||
|
||||
def _build_registry(self):
|
||||
"""A helper function to build a Hierarchical Registry."""
|
||||
# Hierarchical Registry
|
||||
|
@ -227,12 +228,21 @@ class TestRegistry:
|
|||
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
|
||||
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
|
||||
|
||||
@DOGS.register_module()
|
||||
def bark(word, times):
|
||||
return [word] * times
|
||||
|
||||
dog_bark = functools.partial(bark, 'woof')
|
||||
DOGS.register_module('dog_bark', False, dog_bark)
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
pass
|
||||
|
||||
assert len(DOGS) == 1
|
||||
assert len(DOGS) == 3
|
||||
assert DOGS.get('GoldenRetriever') is GoldenRetriever
|
||||
assert DOGS.get('bark') is bark
|
||||
assert DOGS.get('dog_bark') is dog_bark
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
|
@ -249,6 +259,8 @@ class TestRegistry:
|
|||
# If the key is not found in the current registry, then look for its
|
||||
# parent
|
||||
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
|
||||
assert HOUNDS.get('bark') is bark
|
||||
assert HOUNDS.get('dog_bark') is dog_bark
|
||||
|
||||
@LITTLE_HOUNDS.register_module()
|
||||
class Dachshund:
|
||||
|
@ -340,11 +352,14 @@ class TestRegistry:
|
|||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
||||
|
||||
@DOGS.register_module()
|
||||
def bark(times=1):
|
||||
return ' '.join(['woof'] * times)
|
||||
def bark(word, times):
|
||||
return ' '.join([word] * times)
|
||||
|
||||
bark_cfg = cfg_type(dict(type='bark', times=3))
|
||||
assert DOGS.build(bark_cfg) == 'woof woof woof'
|
||||
dog_bark = functools.partial(bark, word='woof')
|
||||
DOGS.register_module('dog_bark', False, dog_bark)
|
||||
|
||||
bark_cfg = cfg_type(dict(type='bark', word='meow', times=3))
|
||||
dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3))
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
|
@ -352,6 +367,8 @@ class TestRegistry:
|
|||
|
||||
gr_cfg = cfg_type(dict(type='GoldenRetriever'))
|
||||
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
|
||||
assert DOGS.build(bark_cfg) == 'meow meow meow'
|
||||
assert DOGS.build(dog_bark_cfg) == 'woof woof woof'
|
||||
|
||||
@HOUNDS.register_module()
|
||||
class BloodHound:
|
||||
|
@ -360,6 +377,8 @@ class TestRegistry:
|
|||
bh_cfg = cfg_type(dict(type='BloodHound'))
|
||||
assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
|
||||
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
|
||||
assert HOUNDS.build(bark_cfg) == 'meow meow meow'
|
||||
assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof'
|
||||
|
||||
@LITTLE_HOUNDS.register_module()
|
||||
class Dachshund:
|
||||
|
@ -419,6 +438,18 @@ class TestRegistry:
|
|||
assert isinstance(dog.friend, YourSamoyed)
|
||||
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||
|
||||
# build an instance by lambda or partial function.
|
||||
lambda_dog = lambda name: name # noqa: E731
|
||||
DOGS.register_module(name='lambda_dog', module=lambda_dog)
|
||||
lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown'))
|
||||
assert DOGS.build(lambda_cfg) == 'unknown'
|
||||
|
||||
DOGS.register_module(
|
||||
name='patial dog',
|
||||
module=functools.partial(lambda_dog, name='patial'))
|
||||
unknown_cfg = cfg_type(dict(type='patial dog'))
|
||||
assert DOGS.build(unknown_cfg) == 'patial'
|
||||
|
||||
def test_switch_scope_and_registry(self):
|
||||
DOGS = Registry('dogs')
|
||||
HOUNDS = Registry('hounds', scope='hound', parent=DOGS)
|
||||
|
|
Loading…
Reference in New Issue