mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support overwrite default scope with "_scope_". (#275)
* [Feature] Support overwrite default scope with "_scope_". * add ut * add ut
This commit is contained in:
parent
7a5d3c83ea
commit
2f16ec69fb
@ -1,5 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
import copy
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional
|
||||
|
||||
from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
|
||||
|
||||
@ -71,3 +74,17 @@ class DefaultScope(ManagerMixin):
|
||||
instance = None
|
||||
_release_lock()
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
|
||||
"""overwrite the current default scope with `scope_name`"""
|
||||
if scope_name is None:
|
||||
yield
|
||||
else:
|
||||
tmp = copy.deepcopy(cls._instance_dict)
|
||||
cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
cls._instance_dict = tmp
|
||||
|
@ -468,7 +468,7 @@ class Registry:
|
||||
|
||||
return None
|
||||
|
||||
def build(self, *args, **kwargs) -> Any:
|
||||
def build(self, cfg, *args, **kwargs) -> Any:
|
||||
"""Build an instance.
|
||||
|
||||
Build an instance by calling :attr:`build_func`. If the global
|
||||
@ -487,27 +487,28 @@ class Registry:
|
||||
>>> cfg = dict(type='ResNet', depth=50)
|
||||
>>> model = MODELS.build(cfg)
|
||||
"""
|
||||
# get the global default scope
|
||||
default_scope = DefaultScope.get_current_instance()
|
||||
if default_scope is not None:
|
||||
scope_name = default_scope.scope_name
|
||||
root = self._get_root_registry()
|
||||
registry = root._search_child(scope_name)
|
||||
if registry is None:
|
||||
# if `default_scope` can not be found, fallback to use self
|
||||
warnings.warn(
|
||||
f'Failed to search registry with scope "{scope_name}" in '
|
||||
f'the "{root.name}" registry tree. '
|
||||
f'As a workaround, the current "{self.name}" registry in '
|
||||
f'"{self.scope}" is used to build instance. This may '
|
||||
f'cause unexpected failure when running the built '
|
||||
f'modules. Please check whether "{scope_name}" is a '
|
||||
f'correct scope, or whether the registry is initialized.')
|
||||
with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
|
||||
# get the global default scope
|
||||
default_scope = DefaultScope.get_current_instance()
|
||||
if default_scope is not None:
|
||||
scope_name = default_scope.scope_name
|
||||
root = self._get_root_registry()
|
||||
registry = root._search_child(scope_name)
|
||||
if registry is None:
|
||||
# if `default_scope` can not be found, fallback to use self
|
||||
warnings.warn(
|
||||
f'Failed to search registry with scope "{scope_name}" '
|
||||
f'in the "{root.name}" registry tree. '
|
||||
f'As a workaround, the current "{self.name}" registry '
|
||||
f'in "{self.scope}" is used to build instance. This '
|
||||
f'may cause unexpected failure when running the built '
|
||||
f'modules. Please check whether "{scope_name}" is a '
|
||||
f'correct scope, or whether the registry is '
|
||||
f'initialized.')
|
||||
registry = self
|
||||
else:
|
||||
registry = self
|
||||
else:
|
||||
registry = self
|
||||
|
||||
return registry.build_func(*args, **kwargs, registry=registry)
|
||||
return registry.build_func(cfg, *args, **kwargs, registry=registry)
|
||||
|
||||
def _add_child(self, registry: 'Registry') -> None:
|
||||
"""Add a child for a registry.
|
||||
|
@ -21,3 +21,15 @@ class TestDefaultScope:
|
||||
DefaultScope.get_instance('instance_name', scope_name='mmengine')
|
||||
default_scope = DefaultScope.get_current_instance()
|
||||
assert default_scope.scope_name == 'mmengine'
|
||||
|
||||
def test_overwrite_default_scope(self):
|
||||
origin_scope = DefaultScope.get_instance(
|
||||
'test_overwrite_default_scope', scope_name='origin_scope')
|
||||
with DefaultScope.overwrite_default_scope(scope_name=None):
|
||||
assert DefaultScope.get_current_instance(
|
||||
).scope_name == 'origin_scope'
|
||||
with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
|
||||
assert DefaultScope.get_current_instance(
|
||||
).scope_name == 'test_overwrite'
|
||||
assert DefaultScope.get_current_instance(
|
||||
).scope_name == origin_scope.scope_name == 'origin_scope'
|
||||
|
@ -167,16 +167,17 @@ class TestRegistry:
|
||||
registries = []
|
||||
DOGS = Registry('dogs')
|
||||
registries.append(DOGS)
|
||||
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
|
||||
HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
|
||||
registries.append(HOUNDS)
|
||||
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
|
||||
LITTLE_HOUNDS = Registry(
|
||||
'little hounds', parent=HOUNDS, scope='little_hound')
|
||||
registries.append(LITTLE_HOUNDS)
|
||||
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
|
||||
MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
|
||||
registries.append(MID_HOUNDS)
|
||||
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
|
||||
SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
|
||||
registries.append(SAMOYEDS)
|
||||
LITTLE_SAMOYEDS = Registry(
|
||||
'dogs', parent=SAMOYEDS, scope='little_samoyed')
|
||||
'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
|
||||
registries.append(LITTLE_SAMOYEDS)
|
||||
|
||||
return registries
|
||||
@ -323,7 +324,7 @@ class TestRegistry:
|
||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||
# (little_hound) (mid_hound) (little_samoyed)
|
||||
registries = self._build_registry()
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
||||
|
||||
@DOGS.register_module()
|
||||
class GoldenRetriever:
|
||||
@ -367,6 +368,37 @@ class TestRegistry:
|
||||
dog = MID_HOUNDS.build(b_cfg)
|
||||
assert isinstance(dog, Beagle)
|
||||
|
||||
# test overwrite default scope with `_scope_`
|
||||
@SAMOYEDS.register_module()
|
||||
class MySamoyed:
|
||||
|
||||
def __init__(self, friend):
|
||||
self.friend = DOGS.build(friend)
|
||||
|
||||
@SAMOYEDS.register_module()
|
||||
class YourSamoyed:
|
||||
pass
|
||||
|
||||
s_cfg = cfg_type(
|
||||
dict(
|
||||
_scope_='samoyed',
|
||||
type='MySamoyed',
|
||||
friend=dict(type='hound.BloodHound')))
|
||||
dog = DOGS.build(s_cfg)
|
||||
assert isinstance(dog, MySamoyed)
|
||||
assert isinstance(dog.friend, BloodHound)
|
||||
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||
|
||||
s_cfg = cfg_type(
|
||||
dict(
|
||||
_scope_='samoyed',
|
||||
type='MySamoyed',
|
||||
friend=dict(type='YourSamoyed')))
|
||||
dog = DOGS.build(s_cfg)
|
||||
assert isinstance(dog, MySamoyed)
|
||||
assert isinstance(dog.friend, YourSamoyed)
|
||||
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||
|
||||
def test_repr(self):
|
||||
CATS = Registry('cat')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user