[Fix] Fix Config cannot parse base config when there is . in tmp path (#856)

* [Fix] Fix config cannot parse tmp path like

* Add comments

* Add comments

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Mashiro 2022-12-30 14:56:14 +08:00 committed by GitHub
parent 6af88783fb
commit ad1b43faf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -550,8 +550,8 @@ class Config:
Returns:
list: A list of base config.
"""
file_format = filename.partition('.')[-1]
if file_format == 'py':
file_format = osp.splitext(filename)[1]
if file_format == '.py':
Config._validate_py_syntax(filename)
with open(filename, encoding='utf-8') as f:
codes = ast.parse(f.read()).body
@ -568,7 +568,7 @@ class Config:
base_files = eval(compile(base_code, '', mode='eval'))
else:
base_files = []
elif file_format in ('yml', 'yaml', 'json'):
elif file_format in ('.yml', '.yaml', '.json'):
import mmengine
cfg_dict = mmengine.load(filename)
base_files = cfg_dict.get(BASE_KEY, [])

View File

@ -5,8 +5,10 @@ import os
import os.path as osp
import platform
import sys
import tempfile
from importlib import import_module
from pathlib import Path
from unittest.mock import patch
import pytest
@ -715,6 +717,21 @@ class TestConfig:
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))
# Simulate the case that the temporary directory includes `.`, etc.
# /tmp/test.axsgr12/. This patch is to check the issue
# https://github.com/open-mmlab/mmengine/issues/788 has been solved.
class PatchedTempDirectory(tempfile.TemporaryDirectory):
def __init__(self, *args, prefix='test.', **kwargs):
super().__init__(*args, prefix=prefix, **kwargs)
with patch('mmengine.config.config.tempfile.TemporaryDirectory',
PatchedTempDirectory):
cfg_file = osp.join(self.data_path,
'config/py_config/test_py_modify_key.py')
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))
def _merge_recursive_bases(self):
cfg_file = osp.join(self.data_path,
'config/py_config/test_merge_recursive_bases.py')