mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support overwrite default scope with "_scope_". (#275)
* [Feature] Support overwrite default scope with "_scope_". * add ut * add ut
This commit is contained in:
parent
7a5d3c83ea
commit
2f16ec69fb
@ -1,5 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Optional
|
import copy
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Generator, Optional
|
||||||
|
|
||||||
from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
|
from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
|
||||||
|
|
||||||
@ -71,3 +74,17 @@ class DefaultScope(ManagerMixin):
|
|||||||
instance = None
|
instance = None
|
||||||
_release_lock()
|
_release_lock()
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
|
||||||
|
"""overwrite the current default scope with `scope_name`"""
|
||||||
|
if scope_name is None:
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
tmp = copy.deepcopy(cls._instance_dict)
|
||||||
|
cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
cls._instance_dict = tmp
|
||||||
|
@ -468,7 +468,7 @@ class Registry:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def build(self, *args, **kwargs) -> Any:
|
def build(self, cfg, *args, **kwargs) -> Any:
|
||||||
"""Build an instance.
|
"""Build an instance.
|
||||||
|
|
||||||
Build an instance by calling :attr:`build_func`. If the global
|
Build an instance by calling :attr:`build_func`. If the global
|
||||||
@ -487,6 +487,7 @@ class Registry:
|
|||||||
>>> cfg = dict(type='ResNet', depth=50)
|
>>> cfg = dict(type='ResNet', depth=50)
|
||||||
>>> model = MODELS.build(cfg)
|
>>> model = MODELS.build(cfg)
|
||||||
"""
|
"""
|
||||||
|
with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
|
||||||
# get the global default scope
|
# get the global default scope
|
||||||
default_scope = DefaultScope.get_current_instance()
|
default_scope = DefaultScope.get_current_instance()
|
||||||
if default_scope is not None:
|
if default_scope is not None:
|
||||||
@ -496,18 +497,18 @@ class Registry:
|
|||||||
if registry is None:
|
if registry is None:
|
||||||
# if `default_scope` can not be found, fallback to use self
|
# if `default_scope` can not be found, fallback to use self
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f'Failed to search registry with scope "{scope_name}" in '
|
f'Failed to search registry with scope "{scope_name}" '
|
||||||
f'the "{root.name}" registry tree. '
|
f'in the "{root.name}" registry tree. '
|
||||||
f'As a workaround, the current "{self.name}" registry in '
|
f'As a workaround, the current "{self.name}" registry '
|
||||||
f'"{self.scope}" is used to build instance. This may '
|
f'in "{self.scope}" is used to build instance. This '
|
||||||
f'cause unexpected failure when running the built '
|
f'may cause unexpected failure when running the built '
|
||||||
f'modules. Please check whether "{scope_name}" is a '
|
f'modules. Please check whether "{scope_name}" is a '
|
||||||
f'correct scope, or whether the registry is initialized.')
|
f'correct scope, or whether the registry is '
|
||||||
|
f'initialized.')
|
||||||
registry = self
|
registry = self
|
||||||
else:
|
else:
|
||||||
registry = self
|
registry = self
|
||||||
|
return registry.build_func(cfg, *args, **kwargs, registry=registry)
|
||||||
return registry.build_func(*args, **kwargs, registry=registry)
|
|
||||||
|
|
||||||
def _add_child(self, registry: 'Registry') -> None:
|
def _add_child(self, registry: 'Registry') -> None:
|
||||||
"""Add a child for a registry.
|
"""Add a child for a registry.
|
||||||
|
@ -21,3 +21,15 @@ class TestDefaultScope:
|
|||||||
DefaultScope.get_instance('instance_name', scope_name='mmengine')
|
DefaultScope.get_instance('instance_name', scope_name='mmengine')
|
||||||
default_scope = DefaultScope.get_current_instance()
|
default_scope = DefaultScope.get_current_instance()
|
||||||
assert default_scope.scope_name == 'mmengine'
|
assert default_scope.scope_name == 'mmengine'
|
||||||
|
|
||||||
|
def test_overwrite_default_scope(self):
|
||||||
|
origin_scope = DefaultScope.get_instance(
|
||||||
|
'test_overwrite_default_scope', scope_name='origin_scope')
|
||||||
|
with DefaultScope.overwrite_default_scope(scope_name=None):
|
||||||
|
assert DefaultScope.get_current_instance(
|
||||||
|
).scope_name == 'origin_scope'
|
||||||
|
with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
|
||||||
|
assert DefaultScope.get_current_instance(
|
||||||
|
).scope_name == 'test_overwrite'
|
||||||
|
assert DefaultScope.get_current_instance(
|
||||||
|
).scope_name == origin_scope.scope_name == 'origin_scope'
|
||||||
|
@ -167,16 +167,17 @@ class TestRegistry:
|
|||||||
registries = []
|
registries = []
|
||||||
DOGS = Registry('dogs')
|
DOGS = Registry('dogs')
|
||||||
registries.append(DOGS)
|
registries.append(DOGS)
|
||||||
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
|
HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
|
||||||
registries.append(HOUNDS)
|
registries.append(HOUNDS)
|
||||||
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
|
LITTLE_HOUNDS = Registry(
|
||||||
|
'little hounds', parent=HOUNDS, scope='little_hound')
|
||||||
registries.append(LITTLE_HOUNDS)
|
registries.append(LITTLE_HOUNDS)
|
||||||
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
|
MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
|
||||||
registries.append(MID_HOUNDS)
|
registries.append(MID_HOUNDS)
|
||||||
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
|
SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
|
||||||
registries.append(SAMOYEDS)
|
registries.append(SAMOYEDS)
|
||||||
LITTLE_SAMOYEDS = Registry(
|
LITTLE_SAMOYEDS = Registry(
|
||||||
'dogs', parent=SAMOYEDS, scope='little_samoyed')
|
'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
|
||||||
registries.append(LITTLE_SAMOYEDS)
|
registries.append(LITTLE_SAMOYEDS)
|
||||||
|
|
||||||
return registries
|
return registries
|
||||||
@ -323,7 +324,7 @@ class TestRegistry:
|
|||||||
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
|
||||||
# (little_hound) (mid_hound) (little_samoyed)
|
# (little_hound) (mid_hound) (little_samoyed)
|
||||||
registries = self._build_registry()
|
registries = self._build_registry()
|
||||||
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
|
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
|
||||||
|
|
||||||
@DOGS.register_module()
|
@DOGS.register_module()
|
||||||
class GoldenRetriever:
|
class GoldenRetriever:
|
||||||
@ -367,6 +368,37 @@ class TestRegistry:
|
|||||||
dog = MID_HOUNDS.build(b_cfg)
|
dog = MID_HOUNDS.build(b_cfg)
|
||||||
assert isinstance(dog, Beagle)
|
assert isinstance(dog, Beagle)
|
||||||
|
|
||||||
|
# test overwrite default scope with `_scope_`
|
||||||
|
@SAMOYEDS.register_module()
|
||||||
|
class MySamoyed:
|
||||||
|
|
||||||
|
def __init__(self, friend):
|
||||||
|
self.friend = DOGS.build(friend)
|
||||||
|
|
||||||
|
@SAMOYEDS.register_module()
|
||||||
|
class YourSamoyed:
|
||||||
|
pass
|
||||||
|
|
||||||
|
s_cfg = cfg_type(
|
||||||
|
dict(
|
||||||
|
_scope_='samoyed',
|
||||||
|
type='MySamoyed',
|
||||||
|
friend=dict(type='hound.BloodHound')))
|
||||||
|
dog = DOGS.build(s_cfg)
|
||||||
|
assert isinstance(dog, MySamoyed)
|
||||||
|
assert isinstance(dog.friend, BloodHound)
|
||||||
|
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||||
|
|
||||||
|
s_cfg = cfg_type(
|
||||||
|
dict(
|
||||||
|
_scope_='samoyed',
|
||||||
|
type='MySamoyed',
|
||||||
|
friend=dict(type='YourSamoyed')))
|
||||||
|
dog = DOGS.build(s_cfg)
|
||||||
|
assert isinstance(dog, MySamoyed)
|
||||||
|
assert isinstance(dog.friend, YourSamoyed)
|
||||||
|
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
CATS = Registry('cat')
|
CATS = Registry('cat')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user