[Fix] Fix new config (#1227)

pull/1231/head
Mashiro 2023-07-01 22:35:11 +08:00 committed by GitHub
parent eea8a7135c
commit b638d3b1fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 27 deletions

View File

@ -147,19 +147,24 @@ class LazyAttr:
if isinstance(self.source, LazyObject): if isinstance(self.source, LazyObject):
if isinstance(self.source._module, str): if isinstance(self.source._module, str):
# In this case, the source code of LazyObject could be one of if self.source._imported is None:
# the following: # source code:
# 1. import xxx.yyy as zzz # from xxx.yyy import zzz
# 2. from xxx.yyy import zzz # equivalent code:
# zzz = LazyObject('xxx.yyy', 'zzz')
# The equivalent code of LazyObject is: # The source code of get attribute:
# 1. zzz = LazyObject('xxx.yyy') # eee = zzz.eee
# 2. zzz = LazyObject('xxx.yyy', 'zzz') # Then, `eee._module` should be "xxx.yyy.zzz"
self._module = self.source._module
# The source code of LazyAttr will be: else:
# eee = zzz.eee # source code:
# Then, eee._module = xxx.yyy # import xxx.yyy as zzz
self._module = self.source._module # equivalent code:
# zzz = LazyObject('xxx.yyy')
# The source code of get attribute:
# eee = zzz.eee
# Then, `eee._module` should be "xxx.yyy"
self._module = f'{self.source._module}.{self.source}'
else: else:
# The source code of LazyObject should be # The source code of LazyObject should be
# 1. import xxx.yyy # 1. import xxx.yyy

View File

@ -442,19 +442,6 @@ class Registry:
'The key argument of `Registry.get` must be a str, ' 'The key argument of `Registry.get` must be a str, '
f'got {type(key)}') f'got {type(key)}')
# Actually, it's strange to implement this `try ... except` to get the
# object by its name in `Registry.get`. However, If we want to build
# the model using a configuration like
# `dict(type='mmengine.model.BaseModel')`, which can
# be dumped by lazy import config, we need this code snippet
# for `Registry.get` to work.
try:
obj_cls = get_object_from_string(key)
except Exception:
raise RuntimeError(f'Failed to get {key}')
if obj_cls is not None:
return obj_cls
scope, real_key = self.split_scope_key(key) scope, real_key = self.split_scope_key(key)
obj_cls = None obj_cls = None
registry_name = self.name registry_name = self.name
@ -508,6 +495,18 @@ class Registry:
else: else:
obj_cls = root.get(key) obj_cls = root.get(key)
if obj_cls is None:
# Actually, it's strange to implement this `try ... except` to
# get the object by its name in `Registry.get`. However, If we
# want to build the model using a configuration like
# `dict(type='mmengine.model.BaseModel')`, which can
# be dumped by lazy import config, we need this code snippet
# for `Registry.get` to work.
try:
obj_cls = get_object_from_string(key)
except Exception:
raise RuntimeError(f'Failed to get {key}')
if obj_cls is not None: if obj_cls is not None:
# For some rare cases (e.g. obj_cls is a partial function), obj_cls # For some rare cases (e.g. obj_cls is a partial function), obj_cls
# doesn't have `__name__`. Use default value to prevent error # doesn't have `__name__`. Use default value to prevent error
@ -517,6 +516,7 @@ class Registry:
f' registry in "{scope_name}"', f' registry in "{scope_name}"',
logger='current', logger='current',
level=logging.DEBUG) level=logging.DEBUG)
return obj_cls return obj_cls
def _search_child(self, scope: str) -> Optional['Registry']: def _search_child(self, scope: str) -> Optional['Registry']:

View File

@ -11,3 +11,5 @@ from mmengine.fileio import LocalBackend as local
from mmengine.fileio import PetrelBackend from mmengine.fileio import PetrelBackend
from ._base_.default_runtime import default_scope as scope from ._base_.default_runtime import default_scope as scope
from ._base_.scheduler import val_cfg from ._base_.scheduler import val_cfg
from rich.progress import Progress
start = Progress.start

View File

@ -10,6 +10,7 @@ from unittest import TestCase
import numpy import numpy
import numpy.compat import numpy.compat
import numpy.linalg as linalg import numpy.linalg as linalg
from rich.progress import Progress
import mmengine import mmengine
from mmengine.config import Config from mmengine.config import Config
@ -62,12 +63,18 @@ class TestImportTransformer(TestCase):
self.assertIs(imported_numpy.linalg, linalg) self.assertIs(imported_numpy.linalg, linalg)
self.assertIs(imported_numpy.compat, numpy.compat) self.assertIs(imported_numpy.compat, numpy.compat)
# 1.4 Build module from LazyAttr # 1.4.1 Build module from LazyAttr
imported_linalg = lazy_numpy.linalg.build() imported_linalg = lazy_numpy.linalg.build()
imported_compat = lazy_numpy.compat.build() imported_compat = lazy_numpy.compat.build()
self.assertIs(imported_compat, numpy.compat) self.assertIs(imported_compat, numpy.compat)
self.assertIs(imported_linalg, linalg) self.assertIs(imported_linalg, linalg)
# 1.4.2 build class method from LazyAttr
start = global_dict['start']
self.assertEqual(start.module, 'rich.progress.Progress')
self.assertEqual(str(start), 'start')
self.assertIs(start.build(), Progress.start)
# 1.5 import ... as, and build module from LazyObject # 1.5 import ... as, and build module from LazyObject
lazy_linalg = global_dict['linalg'] lazy_linalg = global_dict['linalg']
self.assertIsInstance(lazy_linalg, LazyObject) self.assertIsInstance(lazy_linalg, LazyObject)