[Feature] Support overwrite default scope with "_scope_". (#275)

* [Feature] Support overwrite default scope with "_scope_".

* add ut

* add ut
This commit is contained in:
RangiLyu 2022-06-09 20:16:31 +08:00 committed by GitHub
parent 7a5d3c83ea
commit 2f16ec69fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 28 deletions

View File

@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional import copy
import time
from contextlib import contextmanager
from typing import Generator, Optional
from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
@ -71,3 +74,17 @@ class DefaultScope(ManagerMixin):
instance = None instance = None
_release_lock() _release_lock()
return instance return instance
@classmethod
@contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`"""
if scope_name is None:
yield
else:
tmp = copy.deepcopy(cls._instance_dict)
cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name)
try:
yield
finally:
cls._instance_dict = tmp

View File

@ -468,7 +468,7 @@ class Registry:
return None return None
def build(self, *args, **kwargs) -> Any: def build(self, cfg, *args, **kwargs) -> Any:
"""Build an instance. """Build an instance.
Build an instance by calling :attr:`build_func`. If the global Build an instance by calling :attr:`build_func`. If the global
@ -487,27 +487,28 @@ class Registry:
>>> cfg = dict(type='ResNet', depth=50) >>> cfg = dict(type='ResNet', depth=50)
>>> model = MODELS.build(cfg) >>> model = MODELS.build(cfg)
""" """
# get the global default scope with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
default_scope = DefaultScope.get_current_instance() # get the global default scope
if default_scope is not None: default_scope = DefaultScope.get_current_instance()
scope_name = default_scope.scope_name if default_scope is not None:
root = self._get_root_registry() scope_name = default_scope.scope_name
registry = root._search_child(scope_name) root = self._get_root_registry()
if registry is None: registry = root._search_child(scope_name)
# if `default_scope` can not be found, fallback to use self if registry is None:
warnings.warn( # if `default_scope` can not be found, fallback to use self
f'Failed to search registry with scope "{scope_name}" in ' warnings.warn(
f'the "{root.name}" registry tree. ' f'Failed to search registry with scope "{scope_name}" '
f'As a workaround, the current "{self.name}" registry in ' f'in the "{root.name}" registry tree. '
f'"{self.scope}" is used to build instance. This may ' f'As a workaround, the current "{self.name}" registry '
f'cause unexpected failure when running the built ' f'in "{self.scope}" is used to build instance. This '
f'modules. Please check whether "{scope_name}" is a ' f'may cause unexpected failure when running the built '
f'correct scope, or whether the registry is initialized.') f'modules. Please check whether "{scope_name}" is a '
f'correct scope, or whether the registry is '
f'initialized.')
registry = self
else:
registry = self registry = self
else: return registry.build_func(cfg, *args, **kwargs, registry=registry)
registry = self
return registry.build_func(*args, **kwargs, registry=registry)
def _add_child(self, registry: 'Registry') -> None: def _add_child(self, registry: 'Registry') -> None:
"""Add a child for a registry. """Add a child for a registry.

View File

@ -21,3 +21,15 @@ class TestDefaultScope:
DefaultScope.get_instance('instance_name', scope_name='mmengine') DefaultScope.get_instance('instance_name', scope_name='mmengine')
default_scope = DefaultScope.get_current_instance() default_scope = DefaultScope.get_current_instance()
assert default_scope.scope_name == 'mmengine' assert default_scope.scope_name == 'mmengine'
def test_overwrite_default_scope(self):
origin_scope = DefaultScope.get_instance(
'test_overwrite_default_scope', scope_name='origin_scope')
with DefaultScope.overwrite_default_scope(scope_name=None):
assert DefaultScope.get_current_instance(
).scope_name == 'origin_scope'
with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
assert DefaultScope.get_current_instance(
).scope_name == 'test_overwrite'
assert DefaultScope.get_current_instance(
).scope_name == origin_scope.scope_name == 'origin_scope'

View File

@ -167,16 +167,17 @@ class TestRegistry:
registries = [] registries = []
DOGS = Registry('dogs') DOGS = Registry('dogs')
registries.append(DOGS) registries.append(DOGS)
HOUNDS = Registry('dogs', parent=DOGS, scope='hound') HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
registries.append(HOUNDS) registries.append(HOUNDS)
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound') LITTLE_HOUNDS = Registry(
'little hounds', parent=HOUNDS, scope='little_hound')
registries.append(LITTLE_HOUNDS) registries.append(LITTLE_HOUNDS)
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound') MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
registries.append(MID_HOUNDS) registries.append(MID_HOUNDS)
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed') SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
registries.append(SAMOYEDS) registries.append(SAMOYEDS)
LITTLE_SAMOYEDS = Registry( LITTLE_SAMOYEDS = Registry(
'dogs', parent=SAMOYEDS, scope='little_samoyed') 'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
registries.append(LITTLE_SAMOYEDS) registries.append(LITTLE_SAMOYEDS)
return registries return registries
@ -323,7 +324,7 @@ class TestRegistry:
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed) # (little_hound) (mid_hound) (little_samoyed)
registries = self._build_registry() registries = self._build_registry()
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4] DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
@DOGS.register_module() @DOGS.register_module()
class GoldenRetriever: class GoldenRetriever:
@ -367,6 +368,37 @@ class TestRegistry:
dog = MID_HOUNDS.build(b_cfg) dog = MID_HOUNDS.build(b_cfg)
assert isinstance(dog, Beagle) assert isinstance(dog, Beagle)
# test overwrite default scope with `_scope_`
@SAMOYEDS.register_module()
class MySamoyed:
def __init__(self, friend):
self.friend = DOGS.build(friend)
@SAMOYEDS.register_module()
class YourSamoyed:
pass
s_cfg = cfg_type(
dict(
_scope_='samoyed',
type='MySamoyed',
friend=dict(type='hound.BloodHound')))
dog = DOGS.build(s_cfg)
assert isinstance(dog, MySamoyed)
assert isinstance(dog.friend, BloodHound)
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
s_cfg = cfg_type(
dict(
_scope_='samoyed',
type='MySamoyed',
friend=dict(type='YourSamoyed')))
dog = DOGS.build(s_cfg)
assert isinstance(dog, MySamoyed)
assert isinstance(dog.friend, YourSamoyed)
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
def test_repr(self): def test_repr(self):
CATS = Registry('cat') CATS = Registry('cat')