[Feature] Support registering partial functions and more (#595)

* support registering partial functions

* Update mmengine/registry/build_functions.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update mmengine/registry/registry.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Revert unit test and refine

* add current logger and set log level

---------

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
Qian Zhao 2023-04-10 19:42:04 +08:00 committed by GitHub
parent f76218a489
commit b2ad2210b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 30 deletions

View File

@ -104,7 +104,8 @@ def build_from_cfg(
'can be found at ' 'can be found at '
'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501
) )
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): # this will include classes, functions, partial functions and more
elif callable(obj_type):
obj_cls = obj_type obj_cls = obj_type
else: else:
raise TypeError( raise TypeError(
@ -120,12 +121,20 @@ def build_from_cfg(
else: else:
obj = obj_cls(**args) # type: ignore obj = obj_cls(**args) # type: ignore
print_log( if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 or inspect.ismethod(obj_cls)):
'registry, its implementation can be found in ' print_log(
f'{obj_cls.__module__}', # type: ignore f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
logger='current', 'registry, and its implementation can be found in '
level=logging.DEBUG) f'{obj_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
else:
print_log(
'An instance is built from registry, and its constructor '
f'is {obj_cls}',
logger='current',
level=logging.DEBUG)
return obj return obj
except Exception as e: except Exception as e:

View File

@ -487,8 +487,11 @@ class Registry:
obj_cls = root.get(key) obj_cls = root.get(key)
if obj_cls is not None: if obj_cls is not None:
# For some rare cases (e.g. obj_cls is a partial function), obj_cls
# doesn't have `__name__`. Use default value to prevent error
cls_name = getattr(obj_cls, '__name__', str(obj_cls))
print_log( print_log(
f'Get class `{obj_cls.__name__}` from "{registry_name}"' f'Get class `{cls_name}` from "{registry_name}"'
f' registry in "{scope_name}"', f' registry in "{scope_name}"',
logger='current', logger='current',
level=logging.DEBUG) level=logging.DEBUG)
@ -565,16 +568,16 @@ class Registry:
"""Register a module. """Register a module.
Args: Args:
module (type): Module class or function to be registered. module (type): Module to be registered. Typically a class or a
function, but generally all ``Callable`` are acceptable.
module_name (str or list of str, optional): The module name to be module_name (str or list of str, optional): The module name to be
registered. If not specified, the class name will be used. registered. If not specified, the class name will be used.
Defaults to None. Defaults to None.
force (bool): Whether to override an existing class with the same force (bool): Whether to override an existing class with the same
name. Defaults to False. name. Defaults to False.
""" """
if not inspect.isclass(module) and not inspect.isfunction(module): if not callable(module):
raise TypeError('module must be a class or a function, ' raise TypeError(f'module must be Callable, but got {type(module)}')
f'but got {type(module)}')
if module_name is None: if module_name is None:
module_name = module.__name__ module_name = module.__name__

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import functools
import time import time
import pytest import pytest
@ -59,23 +60,12 @@ class TestRegistry:
CATS = Registry('cat') CATS = Registry('cat')
@CATS.register_module() @CATS.register_module()
def muchkin(): def muchkin(size):
pass pass
assert CATS.get('muchkin') is muchkin assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS assert 'muchkin' in CATS
# can only decorate a class or a function
with pytest.raises(TypeError):
class Demo:
def some_method(self):
pass
method = Demo().some_method
CATS.register_module(name='some_method', module=method)
# test `name` parameter which must be either of None, a string or a # test `name` parameter which must be either of None, a string or a
# sequence of string # sequence of string
# `name` is None # `name` is None
@ -146,7 +136,7 @@ class TestRegistry:
# decorator, which must be a class # decorator, which must be a class
with pytest.raises( with pytest.raises(
TypeError, TypeError,
match='module must be a class or a function,' match='module must be Callable,'
" but got <class 'str'>"): " but got <class 'str'>"):
CATS.register_module(module='string') CATS.register_module(module='string')
@ -166,6 +156,17 @@ class TestRegistry:
assert CATS.get('Sphynx3') is SphynxCat assert CATS.get('Sphynx3') is SphynxCat
assert len(CATS) == 9 assert len(CATS) == 9
# partial functions can be registered
muchkin0 = functools.partial(muchkin, size=0)
CATS.register_module('muchkin0', False, muchkin0)
# lambda functions can be registered
CATS.register_module(name='unknown cat', module=lambda: 'unknown')
assert CATS.get('muchkin0') is muchkin0
assert 'unknown cat' in CATS
assert 'muchkin0' in CATS
assert len(CATS) == 11
def _build_registry(self): def _build_registry(self):
"""A helper function to build a Hierarchical Registry.""" """A helper function to build a Hierarchical Registry."""
# Hierarchical Registry # Hierarchical Registry
@ -227,12 +228,21 @@ class TestRegistry:
DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3]
MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:] MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:]
@DOGS.register_module()
def bark(word, times):
return [word] * times
dog_bark = functools.partial(bark, 'woof')
DOGS.register_module('dog_bark', False, dog_bark)
@DOGS.register_module() @DOGS.register_module()
class GoldenRetriever: class GoldenRetriever:
pass pass
assert len(DOGS) == 1 assert len(DOGS) == 3
assert DOGS.get('GoldenRetriever') is GoldenRetriever assert DOGS.get('GoldenRetriever') is GoldenRetriever
assert DOGS.get('bark') is bark
assert DOGS.get('dog_bark') is dog_bark
@HOUNDS.register_module() @HOUNDS.register_module()
class BloodHound: class BloodHound:
@ -249,6 +259,8 @@ class TestRegistry:
# If the key is not found in the current registry, then look for its # If the key is not found in the current registry, then look for its
# parent # parent
assert HOUNDS.get('GoldenRetriever') is GoldenRetriever assert HOUNDS.get('GoldenRetriever') is GoldenRetriever
assert HOUNDS.get('bark') is bark
assert HOUNDS.get('dog_bark') is dog_bark
@LITTLE_HOUNDS.register_module() @LITTLE_HOUNDS.register_module()
class Dachshund: class Dachshund:
@ -340,11 +352,14 @@ class TestRegistry:
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5] DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
@DOGS.register_module() @DOGS.register_module()
def bark(times=1): def bark(word, times):
return ' '.join(['woof'] * times) return ' '.join([word] * times)
bark_cfg = cfg_type(dict(type='bark', times=3)) dog_bark = functools.partial(bark, word='woof')
assert DOGS.build(bark_cfg) == 'woof woof woof' DOGS.register_module('dog_bark', False, dog_bark)
bark_cfg = cfg_type(dict(type='bark', word='meow', times=3))
dog_bark_cfg = cfg_type(dict(type='dog_bark', times=3))
@DOGS.register_module() @DOGS.register_module()
class GoldenRetriever: class GoldenRetriever:
@ -352,6 +367,8 @@ class TestRegistry:
gr_cfg = cfg_type(dict(type='GoldenRetriever')) gr_cfg = cfg_type(dict(type='GoldenRetriever'))
assert isinstance(DOGS.build(gr_cfg), GoldenRetriever) assert isinstance(DOGS.build(gr_cfg), GoldenRetriever)
assert DOGS.build(bark_cfg) == 'meow meow meow'
assert DOGS.build(dog_bark_cfg) == 'woof woof woof'
@HOUNDS.register_module() @HOUNDS.register_module()
class BloodHound: class BloodHound:
@ -360,6 +377,8 @@ class TestRegistry:
bh_cfg = cfg_type(dict(type='BloodHound')) bh_cfg = cfg_type(dict(type='BloodHound'))
assert isinstance(HOUNDS.build(bh_cfg), BloodHound) assert isinstance(HOUNDS.build(bh_cfg), BloodHound)
assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever) assert isinstance(HOUNDS.build(gr_cfg), GoldenRetriever)
assert HOUNDS.build(bark_cfg) == 'meow meow meow'
assert HOUNDS.build(dog_bark_cfg) == 'woof woof woof'
@LITTLE_HOUNDS.register_module() @LITTLE_HOUNDS.register_module()
class Dachshund: class Dachshund:
@ -419,6 +438,18 @@ class TestRegistry:
assert isinstance(dog.friend, YourSamoyed) assert isinstance(dog.friend, YourSamoyed)
assert DefaultScope.get_current_instance().scope_name != 'samoyed' assert DefaultScope.get_current_instance().scope_name != 'samoyed'
# build an instance by lambda or partial function.
lambda_dog = lambda name: name # noqa: E731
DOGS.register_module(name='lambda_dog', module=lambda_dog)
lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown'))
assert DOGS.build(lambda_cfg) == 'unknown'
DOGS.register_module(
name='patial dog',
module=functools.partial(lambda_dog, name='patial'))
unknown_cfg = cfg_type(dict(type='patial dog'))
assert DOGS.build(unknown_cfg) == 'patial'
def test_switch_scope_and_registry(self): def test_switch_scope_and_registry(self):
DOGS = Registry('dogs') DOGS = Registry('dogs')
HOUNDS = Registry('hounds', scope='hound', parent=DOGS) HOUNDS = Registry('hounds', scope='hound', parent=DOGS)