mirror of https://github.com/open-mmlab/mmcv.git
parent
45111e193d
commit
a5b5193767
|
@ -85,10 +85,14 @@ class Config(object):
|
|||
check_file_exist(filename)
|
||||
if filename.endswith('.py'):
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
temp_config_file = tempfile.NamedTemporaryFile(
|
||||
dir=temp_config_dir, suffix='.py')
|
||||
temp_config_name = osp.basename(temp_config_file.name)
|
||||
shutil.copyfile(filename,
|
||||
osp.join(temp_config_dir, '_tempconfig.py'))
|
||||
osp.join(temp_config_dir, temp_config_name))
|
||||
temp_module_name = osp.splitext(temp_config_name)[0]
|
||||
sys.path.insert(0, temp_config_dir)
|
||||
mod = import_module('_tempconfig')
|
||||
mod = import_module(temp_module_name)
|
||||
sys.path.pop(0)
|
||||
cfg_dict = {
|
||||
name: value
|
||||
|
@ -96,7 +100,9 @@ class Config(object):
|
|||
if not name.startswith('__')
|
||||
}
|
||||
# delete imported module
|
||||
del sys.modules['_tempconfig']
|
||||
del sys.modules[temp_module_name]
|
||||
# close temp file
|
||||
temp_config_file.close()
|
||||
elif filename.endswith(('.yml', '.yaml', '.json')):
|
||||
import mmcv
|
||||
cfg_dict = mmcv.load(filename)
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from mmcv import Config # isort:skip
|
||||
cfg = Config.fromfile('./tests/data/config/a.py')
|
||||
item5 = cfg.item1[0] + cfg.item2.a
|
|
@ -134,6 +134,17 @@ def test_merge_intermediate_variable():
|
|||
assert cfg.item6 == dict(cfg=dict(b=2))
|
||||
|
||||
|
||||
def test_fromfile_in_config():
|
||||
cfg_file = osp.join(osp.dirname(__file__), 'data/config/code.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
# cfg.field
|
||||
assert cfg.cfg.item1 == [1, 2]
|
||||
assert cfg.cfg.item2 == dict(a=0)
|
||||
assert cfg.cfg.item3 is True
|
||||
assert cfg.cfg.item4 == 'test'
|
||||
assert cfg.item5 == 1
|
||||
|
||||
|
||||
def test_dict():
|
||||
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
|
||||
|
||||
|
|
Loading…
Reference in New Issue