49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from collections import OrderedDict
|
|
|
|
import pytest
|
|
|
|
from mmengine.registry import DefaultScope
|
|
|
|
|
|
class TestDefaultScope:
    """Tests for creating, looking up and temporarily overwriting
    ``DefaultScope`` instances."""

    def test_scope(self):
        """``get_instance`` stores ``scope_name`` and requires it."""
        default_scope = DefaultScope.get_instance('name1', scope_name='mmdet')
        assert default_scope.scope_name == 'mmdet'
        # `DefaultScope.get_instance` must have `scope_name` argument.
        with pytest.raises(TypeError):
            DefaultScope.get_instance('name2')

    def test_get_current_instance(self):
        """``get_current_instance`` returns None when no instance exists and
        the just-created instance otherwise."""
        # Swap in an empty registry so the "no instance" case can be tested,
        # then put the original registry back so instances created elsewhere
        # (e.g. by other tests) are not wiped for the rest of the session.
        origin_instance_dict = DefaultScope._instance_dict
        DefaultScope._instance_dict = OrderedDict()
        try:
            assert DefaultScope.get_current_instance() is None
            DefaultScope.get_instance('instance_name', scope_name='mmengine')
            default_scope = DefaultScope.get_current_instance()
            assert default_scope.scope_name == 'mmengine'
        finally:
            DefaultScope._instance_dict = origin_instance_dict

    def test_overwrite_default_scope(self):
        """``overwrite_default_scope`` switches the current scope inside the
        ``with`` block and restores the previous one on exit."""
        origin_scope = DefaultScope.get_instance(
            'test_overwrite_default_scope', scope_name='origin_scope')
        # `scope_name=None` must leave the current scope untouched.
        with DefaultScope.overwrite_default_scope(scope_name=None):
            assert DefaultScope.get_current_instance(
            ).scope_name == 'origin_scope'
        with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
            assert DefaultScope.get_current_instance(
            ).scope_name == 'test_overwrite'
        # Compared by `scope_name` rather than identity on purpose.
        assert DefaultScope.get_current_instance(
        ).scope_name == origin_scope.scope_name == 'origin_scope'

        # Overwrite the default scope twice in immediate succession; each
        # overwrite must take effect while active.
        with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
            assert DefaultScope.get_current_instance(
            ).scope_name == 'test_overwrite'
        with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
            assert DefaultScope.get_current_instance(
            ).scope_name == 'test_overwrite'
        assert DefaultScope.get_current_instance(
        ).scope_name == 'origin_scope'

        # Nested overwrite: leaving the inner block restores the outer
        # overwrite, and leaving the outer block restores the origin scope.
        with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
            with DefaultScope.overwrite_default_scope(
                    scope_name='test_overwrite'):
                assert DefaultScope.get_current_instance(
                ).scope_name == 'test_overwrite'
            assert DefaultScope.get_current_instance(
            ).scope_name == 'test_overwrite'
        assert DefaultScope.get_current_instance(
        ).scope_name == 'origin_scope'
|