190 lines
6.9 KiB
Python
190 lines
6.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import ast
|
|
import copy
|
|
import os
|
|
import os.path as osp
|
|
from importlib import import_module
|
|
from importlib.util import find_spec
|
|
from unittest import TestCase
|
|
|
|
import numpy
|
|
import numpy.compat
|
|
import numpy.linalg as linalg
|
|
from rich.progress import Progress
|
|
|
|
import mmengine
|
|
from mmengine.config import Config
|
|
from mmengine.config.lazy import LazyAttr, LazyObject
|
|
from mmengine.config.utils import ImportTransformer, _gather_abs_import_lazyobj
|
|
from mmengine.fileio import LocalBackend, PetrelBackend
|
|
|
|
|
|
class TestImportTransformer(TestCase):
|
|
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
cls.data_dir = osp.join( # type: ignore
|
|
osp.dirname(__file__), '..', 'data', 'config',
|
|
'lazy_module_config')
|
|
super().setUpClass()
|
|
|
|
def test_lazy_module(self):
|
|
cfg_path = osp.join(self.data_dir, 'test_ast_transform.py')
|
|
with open(cfg_path) as f:
|
|
codestr = f.read()
|
|
codeobj = ast.parse(codestr)
|
|
global_dict = {
|
|
'LazyObject': LazyObject,
|
|
}
|
|
base_dict = {
|
|
'._base_.default_runtime': {
|
|
'default_scope': 'test_config'
|
|
},
|
|
'._base_.scheduler': {
|
|
'val_cfg': {}
|
|
},
|
|
}
|
|
codeobj = ImportTransformer(global_dict, base_dict).visit(codeobj)
|
|
codeobj, _ = _gather_abs_import_lazyobj(codeobj)
|
|
codeobj = ast.fix_missing_locations(codeobj)
|
|
|
|
exec(compile(codeobj, cfg_path, mode='exec'), global_dict, global_dict)
|
|
# 1. absolute import
|
|
# 1.1 import module as LazyObject
|
|
lazy_numpy = global_dict['numpy']
|
|
self.assertIsInstance(lazy_numpy, LazyObject)
|
|
|
|
# 1.2 getattr as LazyAttr
|
|
self.assertIsInstance(lazy_numpy.linalg, LazyAttr)
|
|
self.assertIsInstance(lazy_numpy.compat, LazyAttr)
|
|
|
|
# 1.3 Build module from LazyObject. amp and functional can be accessed
|
|
imported_numpy = lazy_numpy.build()
|
|
self.assertIs(imported_numpy.linalg, linalg)
|
|
self.assertIs(imported_numpy.compat, numpy.compat)
|
|
|
|
# 1.4.1 Build module from LazyAttr
|
|
imported_linalg = lazy_numpy.linalg.build()
|
|
imported_compat = lazy_numpy.compat.build()
|
|
self.assertIs(imported_compat, numpy.compat)
|
|
self.assertIs(imported_linalg, linalg)
|
|
|
|
# 1.4.2 build class method from LazyAttr
|
|
start = global_dict['start']
|
|
self.assertEqual(start.module, 'rich.progress.Progress')
|
|
self.assertEqual(str(start), 'start')
|
|
self.assertIs(start.build(), Progress.start)
|
|
|
|
# 1.5 import ... as, and build module from LazyObject
|
|
lazy_linalg = global_dict['linalg']
|
|
self.assertIsInstance(lazy_linalg, LazyObject)
|
|
self.assertIs(lazy_linalg.build(), linalg)
|
|
self.assertIsInstance(lazy_linalg.norm, LazyAttr)
|
|
self.assertIs(lazy_linalg.norm.build(), linalg.norm)
|
|
|
|
# 1.6 import built in module
|
|
imported_os = global_dict['os']
|
|
self.assertIs(imported_os, os)
|
|
|
|
# 2. Relative import
|
|
# 2.1 from ... import ...
|
|
lazy_local_backend = global_dict['local']
|
|
self.assertIsInstance(lazy_local_backend, LazyObject)
|
|
self.assertIs(lazy_local_backend.build(), LocalBackend)
|
|
|
|
# 2.2 from ... import ... as ...
|
|
lazy_petrel_backend = global_dict['PetrelBackend']
|
|
self.assertIsInstance(lazy_petrel_backend, LazyObject)
|
|
self.assertIs(lazy_petrel_backend.build(), PetrelBackend)
|
|
|
|
# 2.3 from ... import builtin module or obj from `mmengine.Config`
|
|
self.assertIs(global_dict['find_module'], find_spec)
|
|
self.assertIs(global_dict['Config'], Config)
|
|
|
|
# 3 test import base config
|
|
# 3.1 simple from ... import and from ... import ... as
|
|
self.assertEqual(global_dict['scope'], 'test_config')
|
|
self.assertDictEqual(global_dict['val_cfg'], {})
|
|
|
|
# 4. Error catching
|
|
cfg_path = osp.join(self.data_dir,
|
|
'test_ast_transform_error_catching1.py')
|
|
with open(cfg_path) as f:
|
|
codestr = f.read()
|
|
codeobj = ast.parse(codestr)
|
|
global_dict = {'LazyObject': LazyObject}
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r'Illegal syntax in config! `from xxx import \*`'):
|
|
codeobj = ImportTransformer(global_dict).visit(codeobj)
|
|
|
|
|
|
class TestLazyObject(TestCase):
|
|
|
|
def test_init(self):
|
|
LazyObject('mmengine')
|
|
LazyObject('mmengine.fileio')
|
|
LazyObject('mmengine.fileio', 'LocalBackend')
|
|
|
|
# module must be str
|
|
with self.assertRaises(TypeError):
|
|
LazyObject(1)
|
|
|
|
# imported must be a sequence of string or None
|
|
with self.assertRaises(TypeError):
|
|
LazyObject('mmengine', ['error_type'])
|
|
|
|
def test_build(self):
|
|
lazy_mmengine = LazyObject('mmengine')
|
|
self.assertIs(lazy_mmengine.build(), mmengine)
|
|
|
|
lazy_mmengine_fileio = LazyObject('mmengine.fileio')
|
|
self.assertIs(lazy_mmengine_fileio.build(),
|
|
import_module('mmengine.fileio'))
|
|
|
|
lazy_local_backend = LazyObject('mmengine.fileio', 'LocalBackend')
|
|
self.assertIs(lazy_local_backend.build(), LocalBackend)
|
|
|
|
# TODO: The commented test is required, we need to test the built
|
|
# LazyObject can access the `mmengine.dataset`. We need to clean the
|
|
# environment to make sure the `dataset` is not imported before, and
|
|
# it is triggered by lazy_mmengine.build(). However, if we simply
|
|
# pop the `mmengine.dataset` will lead to other tests failed, of which
|
|
# reason is still unknown. We need to figure out the reason and fix it
|
|
# in the latter
|
|
|
|
# sys.modules.pop('mmengine.config')
|
|
# sys.modules.pop('mmengine.fileio')
|
|
# sys.modules.pop('mmengine')
|
|
# lazy_mmengine = LazyObject(['mmengine', 'mmengine.dataset'])
|
|
# self.assertIs(lazy_mmengine.build().dataset,
|
|
# import_module('mmengine.config'))
|
|
copied = copy.deepcopy(lazy_local_backend)
|
|
self.assertDictEqual(copied.__dict__, lazy_local_backend.__dict__)
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
lazy_mmengine()
|
|
|
|
with self.assertRaises(ImportError):
|
|
LazyObject('unknown').build()
|
|
|
|
|
|
class TestLazyAttr(TestCase):
|
|
# Since LazyAttr should only be built from LazyObect, we only test
|
|
# the build method here.
|
|
def test_build(self):
|
|
lazy_mmengine = LazyObject('mmengine')
|
|
local_backend = lazy_mmengine.fileio.LocalBackend
|
|
self.assertIs(local_backend.build(), LocalBackend)
|
|
|
|
copied = copy.deepcopy(local_backend)
|
|
self.assertDictEqual(copied.__dict__, local_backend.__dict__)
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
local_backend()
|
|
|
|
with self.assertRaisesRegex(
|
|
ImportError,
|
|
'Failed to import mmengine.fileio.LocalBackend.unknown'):
|
|
local_backend.unknown.build()
|