mirror of https://github.com/open-mmlab/mmcv.git
Use mapping rather than dict for special keys (#304)
* Support path as a key in dict of config * reformat test case * update pre-commit version and fix format * fix bug * clean code * reformat * fix missing partspull/311/head
parent
b63d880af5
commit
8ceb404ea6
|
@ -1,7 +1,7 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.7.9
|
||||
rev: 3.8.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/asottile/seed-isort-config
|
||||
|
@ -13,7 +13,7 @@ repos:
|
|||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.29.0
|
||||
rev: v0.30.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
|
|
|
@ -219,46 +219,72 @@ class Config(object):
|
|||
s = first + '\n' + s
|
||||
return s
|
||||
|
||||
def _format_basic_types(k, v):
|
||||
def _format_basic_types(k, v, use_mapping=False):
|
||||
if isinstance(v, str):
|
||||
v_str = f"'{v}'"
|
||||
else:
|
||||
v_str = str(v)
|
||||
attr_str = f'{str(k)}={v_str}'
|
||||
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f'{k_str}: {v_str}'
|
||||
else:
|
||||
attr_str = f'{str(k)}={v_str}'
|
||||
attr_str = _indent(attr_str, indent)
|
||||
|
||||
return attr_str
|
||||
|
||||
def _format_list(k, v):
|
||||
def _format_list(k, v, use_mapping=False):
|
||||
# check if all items in the list are dict
|
||||
if all(isinstance(_, dict) for _ in v):
|
||||
v_str = '[\n'
|
||||
v_str += '\n'.join(
|
||||
f'dict({_indent(_format_dict(v_), indent)}),'
|
||||
for v_ in v).rstrip(',')
|
||||
attr_str = f'{str(k)}={v_str}'
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f'{k_str}: {v_str}'
|
||||
else:
|
||||
attr_str = f'{str(k)}={v_str}'
|
||||
attr_str = _indent(attr_str, indent) + ']'
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v)
|
||||
attr_str = _format_basic_types(k, v, use_mapping)
|
||||
return attr_str
|
||||
|
||||
def _format_dict(d, outest_level=False):
|
||||
def _contain_invalid_identifier(dict_str):
|
||||
contain_invalid_identifier = False
|
||||
for key_name in dict_str:
|
||||
contain_invalid_identifier |= \
|
||||
(not str(key_name).isidentifier())
|
||||
return contain_invalid_identifier
|
||||
|
||||
def _format_dict(input_dict, outest_level=False):
|
||||
r = ''
|
||||
s = []
|
||||
for idx, (k, v) in enumerate(d.items()):
|
||||
is_last = idx >= len(d) - 1
|
||||
|
||||
use_mapping = _contain_invalid_identifier(input_dict)
|
||||
if use_mapping:
|
||||
r += '{'
|
||||
for idx, (k, v) in enumerate(input_dict.items()):
|
||||
is_last = idx >= len(input_dict) - 1
|
||||
end = '' if outest_level or is_last else ','
|
||||
if isinstance(v, dict):
|
||||
v_str = '\n' + _format_dict(v)
|
||||
attr_str = f'{str(k)}=dict({v_str}'
|
||||
if use_mapping:
|
||||
k_str = f"'{k}'" if isinstance(k, str) else str(k)
|
||||
attr_str = f'{k_str}: dict({v_str}'
|
||||
else:
|
||||
attr_str = f'{str(k)}=dict({v_str}'
|
||||
attr_str = _indent(attr_str, indent) + ')' + end
|
||||
elif isinstance(v, list):
|
||||
attr_str = _format_list(k, v) + end
|
||||
attr_str = _format_list(k, v, use_mapping) + end
|
||||
else:
|
||||
attr_str = _format_basic_types(k, v) + end
|
||||
attr_str = _format_basic_types(k, v, use_mapping) + end
|
||||
|
||||
s.append(attr_str)
|
||||
r += '\n'.join(s)
|
||||
if use_mapping:
|
||||
r += '}'
|
||||
return r
|
||||
|
||||
cfg_dict = self._cfg_dict.to_dict()
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
test_item1 = [1, 2]
|
||||
bool_item2 = True
|
||||
str_item3 = 'test'
|
||||
dict_item4 = dict(
|
||||
a={
|
||||
'c/d': 'path/d',
|
||||
'f': 's3//f',
|
||||
6: '2333',
|
||||
'2333': 'number'
|
||||
},
|
||||
b={'8': 543},
|
||||
c={9: 678},
|
||||
d={'a': 0},
|
||||
f=dict(a='69'))
|
||||
dict_item5 = {'x/x': {'a.0': 233}}
|
||||
dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]}
|
|
@ -254,3 +254,15 @@ def test_dict_action():
|
|||
cfg.merge_from_dict(args.options)
|
||||
assert cfg.item2 == dict(a=1, b=0.1, c='x')
|
||||
assert cfg.item3 is False
|
||||
|
||||
|
||||
def test_dump_mapping():
|
||||
cfg_file = osp.join(osp.dirname(__file__), 'data/config/n.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
text_cfg_filename = osp.join(temp_config_dir, '_text_config.py')
|
||||
cfg.dump(text_cfg_filename)
|
||||
text_cfg = Config.fromfile(text_cfg_filename)
|
||||
|
||||
assert text_cfg._cfg_dict == cfg._cfg_dict
|
||||
|
|
Loading…
Reference in New Issue