mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Support variables in base files for configs (#1083)
* Support variables in base files for configs Signed-off-by: lizz <lizz@sensetime.com> * Test json and yaml as well Signed-off-by: lizz <lizz@sensetime.com> * Add test for recusive base Signed-off-by: lizz <lizz@sensetime.com> * Test misleading values Signed-off-by: lizz <lizz@sensetime.com> * Improve comments Signed-off-by: lizz <lizz@sensetime.com> * Add doc Signed-off-by: lizz <lizz@sensetime.com> * Improve doc Signed-off-by: lizz <lizz@sensetime.com> * More tests Signed-off-by: lizz <lizz@sensetime.com> * Harder test case Signed-off-by: lizz <lizz@sensetime.com> * use BASE_KEY instead of base Signed-off-by: lizz <lizz@sensetime.com>
This commit is contained in:
parent
eb08835fa2
commit
d9effbd1d0
@ -154,6 +154,32 @@ _base_ = ['./config_a.py', './config_e.py']
|
||||
... d='string')
|
||||
```
|
||||
|
||||
#### Reference variables from base
|
||||
|
||||
You can reference variables defined in base using the following grammar.
|
||||
|
||||
`base.py`
|
||||
|
||||
```python
|
||||
item1 = 'a'
|
||||
item2 = dict(item3 = 'b')
|
||||
```
|
||||
|
||||
`config_g.py`
|
||||
|
||||
```python
|
||||
_base_ = ['./base.py']
|
||||
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
|
||||
```
|
||||
|
||||
```python
|
||||
>>> cfg = Config.fromfile('./config_g.py')
|
||||
>>> print(cfg.pretty_text)
|
||||
item1 = 'a'
|
||||
item2 = dict(item3='b')
|
||||
item = dict(a='a', b='b')
|
||||
```
|
||||
|
||||
### ProgressBar
|
||||
|
||||
If you want to apply a method to a list of items and track the progress, `track_progress`
|
||||
|
@ -1,11 +1,13 @@
|
||||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
from argparse import Action, ArgumentParser
|
||||
from collections import abc
|
||||
@ -121,6 +123,57 @@ class Config:
|
||||
with open(temp_config_name, 'w') as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
|
||||
@staticmethod
|
||||
def _pre_substitute_base_vars(filename, temp_config_name):
|
||||
"""Substitute base variable placehoders to string, so that parsing
|
||||
would work."""
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
# Setting encoding explicitly to resolve coding issue on windows
|
||||
config_file = f.read()
|
||||
base_var_dict = {}
|
||||
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
|
||||
base_vars = set(re.findall(regexp, config_file))
|
||||
for base_var in base_vars:
|
||||
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
|
||||
base_var_dict[randstr] = base_var
|
||||
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
|
||||
config_file = re.sub(regexp, f'"{randstr}"', config_file)
|
||||
with open(temp_config_name, 'w') as tmp_config_file:
|
||||
tmp_config_file.write(config_file)
|
||||
return base_var_dict
|
||||
|
||||
@staticmethod
|
||||
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
|
||||
"""Substitute variable strings to their actual values."""
|
||||
cfg = copy.deepcopy(cfg)
|
||||
|
||||
if isinstance(cfg, dict):
|
||||
for k, v in cfg.items():
|
||||
if isinstance(v, str) and v in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[v].split('.'):
|
||||
new_v = new_v[new_k]
|
||||
cfg[k] = new_v
|
||||
elif isinstance(v, (list, tuple, dict)):
|
||||
cfg[k] = Config._substitute_base_vars(
|
||||
v, base_var_dict, base_cfg)
|
||||
elif isinstance(cfg, tuple):
|
||||
cfg = tuple(
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg)
|
||||
for c in cfg)
|
||||
elif isinstance(cfg, list):
|
||||
cfg = [
|
||||
Config._substitute_base_vars(c, base_var_dict, base_cfg)
|
||||
for c in cfg
|
||||
]
|
||||
elif isinstance(cfg, str) and cfg in base_var_dict:
|
||||
new_v = base_cfg
|
||||
for new_k in base_var_dict[cfg].split('.'):
|
||||
new_v = new_v[new_k]
|
||||
cfg = new_v
|
||||
|
||||
return cfg
|
||||
|
||||
@staticmethod
|
||||
def _file2dict(filename, use_predefined_variables=True):
|
||||
filename = osp.abspath(osp.expanduser(filename))
|
||||
@ -141,6 +194,9 @@ class Config:
|
||||
temp_config_file.name)
|
||||
else:
|
||||
shutil.copyfile(filename, temp_config_file.name)
|
||||
# Substitute base variables from placeholders to strings
|
||||
base_var_dict = Config._pre_substitute_base_vars(
|
||||
temp_config_file.name, temp_config_file.name)
|
||||
|
||||
if filename.endswith('.py'):
|
||||
temp_module_name = osp.splitext(temp_config_name)[0]
|
||||
@ -185,6 +241,10 @@ class Config:
|
||||
raise KeyError('Duplicate key is not allowed among bases')
|
||||
base_cfg_dict.update(c)
|
||||
|
||||
# Subtitute base variables from strings to their actual values
|
||||
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
|
||||
base_cfg_dict)
|
||||
|
||||
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
|
||||
cfg_dict = base_cfg_dict
|
||||
|
||||
|
13
tests/data/config/t.json
Normal file
13
tests/data/config/t.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"_base_": [
|
||||
"./l1.py",
|
||||
"./l2.yaml",
|
||||
"./l3.json",
|
||||
"./l4.py"
|
||||
],
|
||||
"item3": false,
|
||||
"item4": "test",
|
||||
"item8": "{{fileBasename}}",
|
||||
"item9": {{ _base_.item2 }},
|
||||
"item10": {{ _base_.item7.b.c }}
|
||||
}
|
6
tests/data/config/t.py
Normal file
6
tests/data/config/t.py
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
|
||||
item3 = False
|
||||
item4 = 'test'
|
||||
item8 = '{{fileBasename}}'
|
||||
item9 = {{ _base_.item2 }}
|
||||
item10 = {{ _base_.item7.b.c }}
|
6
tests/data/config/t.yaml
Normal file
6
tests/data/config/t.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py']
|
||||
item3 : False
|
||||
item4 : 'test'
|
||||
item8 : '{{fileBasename}}'
|
||||
item9 : {{ _base_.item2 }}
|
||||
item10 : {{ _base_.item7.b.c }}
|
26
tests/data/config/u.json
Normal file
26
tests/data/config/u.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"_base_": [
|
||||
"./t.py"
|
||||
],
|
||||
"base": "_base_.item8",
|
||||
"item11": {{ _base_.item8 }},
|
||||
"item12": {{ _base_.item9 }},
|
||||
"item13": {{ _base_.item10 }},
|
||||
"item14": {{ _base_.item1 }},
|
||||
"item15": {
|
||||
"a": {
|
||||
"b": {{ _base_.item2 }}
|
||||
},
|
||||
"b": [
|
||||
{{ _base_.item3 }}
|
||||
],
|
||||
"c": [{{ _base_.item4 }}],
|
||||
"d": [[
|
||||
{
|
||||
"e": {{ _base_.item5.a }}
|
||||
}
|
||||
],
|
||||
{{ _base_.item6 }}],
|
||||
"e": {{ _base_.item1 }}
|
||||
}
|
||||
}
|
13
tests/data/config/u.py
Normal file
13
tests/data/config/u.py
Normal file
@ -0,0 +1,13 @@
|
||||
_base_ = ['./t.py']
|
||||
base = '_base_.item8'
|
||||
item11 = {{ _base_.item8 }}
|
||||
item12 = {{ _base_.item9 }}
|
||||
item13 = {{ _base_.item10 }}
|
||||
item14 = {{ _base_.item1 }}
|
||||
item15 = dict(
|
||||
a = dict( b = {{ _base_.item2 }} ),
|
||||
b = [{{ _base_.item3 }}],
|
||||
c = [{{ _base_.item4 }}],
|
||||
d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}],
|
||||
e = {{ _base_.item1 }}
|
||||
)
|
15
tests/data/config/u.yaml
Normal file
15
tests/data/config/u.yaml
Normal file
@ -0,0 +1,15 @@
|
||||
_base_: ["./t.py"]
|
||||
base: "_base_.item8"
|
||||
item11: {{ _base_.item8 }}
|
||||
item12: {{ _base_.item9 }}
|
||||
item13: {{ _base_.item10 }}
|
||||
item14: {{ _base_.item1 }}
|
||||
item15:
|
||||
a:
|
||||
b: {{ _base_.item2 }}
|
||||
b: [{{ _base_.item3 }}]
|
||||
c: [{{ _base_.item4 }}]
|
||||
d:
|
||||
- [e: {{ _base_.item5.a }}]
|
||||
- {{ _base_.item6 }}
|
||||
e: {{ _base_.item1 }}
|
11
tests/data/config/v.py
Normal file
11
tests/data/config/v.py
Normal file
@ -0,0 +1,11 @@
|
||||
_base_ = ['./u.py']
|
||||
item21 = {{ _base_.item11 }}
|
||||
item22 = item21
|
||||
item23 = {{ _base_.item10 }}
|
||||
item24 = item23
|
||||
item25 = dict(
|
||||
a = dict( b = item24 ),
|
||||
b = [item24],
|
||||
c = [[dict(e = item22)],{{ _base_.item6 }}],
|
||||
e = item21
|
||||
)
|
@ -224,6 +224,81 @@ def test_merge_from_multiple_bases():
|
||||
Config.fromfile(osp.join(data_path, 'config/m.py'))
|
||||
|
||||
|
||||
def test_base_variables():
|
||||
for file in ['t.py', 't.json', 't.yaml']:
|
||||
cfg_file = osp.join(data_path, f'config/{file}')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
assert isinstance(cfg, Config)
|
||||
assert cfg.filename == cfg_file
|
||||
# cfg.field
|
||||
assert cfg.item1 == [1, 2]
|
||||
assert cfg.item2.a == 0
|
||||
assert cfg.item3 is False
|
||||
assert cfg.item4 == 'test'
|
||||
assert cfg.item5 == dict(a=0, b=1)
|
||||
assert cfg.item6 == [dict(a=0), dict(b=1)]
|
||||
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
|
||||
assert cfg.item8 == file
|
||||
assert cfg.item9 == dict(a=0)
|
||||
assert cfg.item10 == [3.1, 4.2, 5.3]
|
||||
|
||||
# test nested base
|
||||
for file in ['u.py', 'u.json', 'u.yaml']:
|
||||
cfg_file = osp.join(data_path, f'config/{file}')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
assert isinstance(cfg, Config)
|
||||
assert cfg.filename == cfg_file
|
||||
# cfg.field
|
||||
assert cfg.base == '_base_.item8'
|
||||
assert cfg.item1 == [1, 2]
|
||||
assert cfg.item2.a == 0
|
||||
assert cfg.item3 is False
|
||||
assert cfg.item4 == 'test'
|
||||
assert cfg.item5 == dict(a=0, b=1)
|
||||
assert cfg.item6 == [dict(a=0), dict(b=1)]
|
||||
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
|
||||
assert cfg.item8 == 't.py'
|
||||
assert cfg.item9 == dict(a=0)
|
||||
assert cfg.item10 == [3.1, 4.2, 5.3]
|
||||
assert cfg.item11 == 't.py'
|
||||
assert cfg.item12 == dict(a=0)
|
||||
assert cfg.item13 == [3.1, 4.2, 5.3]
|
||||
assert cfg.item14 == [1, 2]
|
||||
assert cfg.item15 == dict(
|
||||
a=dict(b=dict(a=0)),
|
||||
b=[False],
|
||||
c=['test'],
|
||||
d=[[{
|
||||
'e': 0
|
||||
}], [{
|
||||
'a': 0
|
||||
}, {
|
||||
'b': 1
|
||||
}]],
|
||||
e=[1, 2])
|
||||
|
||||
# test reference assignment for py
|
||||
cfg_file = osp.join(data_path, 'config/v.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
assert isinstance(cfg, Config)
|
||||
assert cfg.filename == cfg_file
|
||||
assert cfg.item21 == 't.py'
|
||||
assert cfg.item22 == 't.py'
|
||||
assert cfg.item23 == [3.1, 4.2, 5.3]
|
||||
assert cfg.item24 == [3.1, 4.2, 5.3]
|
||||
assert cfg.item25 == dict(
|
||||
a=dict(b=[3.1, 4.2, 5.3]),
|
||||
b=[[3.1, 4.2, 5.3]],
|
||||
c=[[{
|
||||
'e': 't.py'
|
||||
}], [{
|
||||
'a': 0
|
||||
}, {
|
||||
'b': 1
|
||||
}]],
|
||||
e='t.py')
|
||||
|
||||
|
||||
def test_merge_recursive_bases():
|
||||
cfg_file = osp.join(data_path, 'config/f.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
|
Loading…
x
Reference in New Issue
Block a user