diff --git a/docs/utils.md b/docs/utils.md index 991a862f9..84d6b06e2 100644 --- a/docs/utils.md +++ b/docs/utils.md @@ -154,6 +154,32 @@ _base_ = ['./config_a.py', './config_e.py'] ... d='string') ``` +#### Reference variables from base + +You can reference variables defined in base using the following grammar. + +`base.py` + +```python +item1 = 'a' +item2 = dict(item3 = 'b') +``` + +`config_g.py` + +```python +_base_ = ['./base.py'] +item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }}) +``` + +```python +>>> cfg = Config.fromfile('./config_g.py') +>>> print(cfg.pretty_text) +item1 = 'a' +item2 = dict(item3='b') +item = dict(a='a', b='b') +``` + ### ProgressBar If you want to apply a method to a list of items and track the progress, `track_progress` diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py index fbdfe656b..d10696a6a 100644 --- a/mmcv/utils/config.py +++ b/mmcv/utils/config.py @@ -1,11 +1,13 @@ # Copyright (c) Open-MMLab. All rights reserved. import ast +import copy import os import os.path as osp import platform import shutil import sys import tempfile +import uuid import warnings from argparse import Action, ArgumentParser from collections import abc @@ -121,6 +123,57 @@ class Config: with open(temp_config_name, 'w') as tmp_config_file: tmp_config_file.write(config_file) + @staticmethod + def _pre_substitute_base_vars(filename, temp_config_name): + """Substitute base variable placehoders to string, so that parsing + would work.""" + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + config_file = f.read() + base_var_dict = {} + regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' + base_var_dict[randstr] = base_var + regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, 'w') as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars(cfg, base_var_dict, base_cfg): + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split('.'): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars( + v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._substitute_base_vars(c, base_var_dict, base_cfg) + for c in cfg + ] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split('.'): + new_v = new_v[new_k] + cfg = new_v + + return cfg + @staticmethod def _file2dict(filename, use_predefined_variables=True): filename = osp.abspath(osp.expanduser(filename)) @@ -141,6 +194,9 @@ class Config: temp_config_file.name) else: shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) if filename.endswith('.py'): temp_module_name = osp.splitext(temp_config_name)[0] @@ -185,6 +241,10 @@ class Config: raise KeyError('Duplicate key is not allowed among bases') base_cfg_dict.update(c) + # Subtitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) cfg_dict = base_cfg_dict diff --git a/tests/data/config/t.json b/tests/data/config/t.json new file mode 100644 index 000000000..8f7b9b4a1 --- /dev/null +++ b/tests/data/config/t.json @@ -0,0 +1,13 @@ +{ + "_base_": [ + "./l1.py", + "./l2.yaml", + "./l3.json", + "./l4.py" + ], + "item3": false, + "item4": "test", + "item8": "{{fileBasename}}", + "item9": {{ _base_.item2 }}, + "item10": {{ _base_.item7.b.c }} +} diff --git a/tests/data/config/t.py b/tests/data/config/t.py new file mode 100644 index 000000000..9f085ae67 --- /dev/null +++ b/tests/data/config/t.py @@ -0,0 +1,6 @@ +_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py'] +item3 = False +item4 = 'test' +item8 = '{{fileBasename}}' +item9 = {{ _base_.item2 }} +item10 = {{ _base_.item7.b.c }} diff --git a/tests/data/config/t.yaml b/tests/data/config/t.yaml new file mode 100644 index 000000000..ab42859ec --- /dev/null +++ b/tests/data/config/t.yaml @@ -0,0 +1,6 @@ +_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py'] +item3 : False +item4 : 'test' +item8 : '{{fileBasename}}' +item9 : {{ _base_.item2 }} +item10 : {{ _base_.item7.b.c }} diff --git a/tests/data/config/u.json b/tests/data/config/u.json new file mode 100644 index 000000000..f6a01e3c0 --- /dev/null +++ b/tests/data/config/u.json @@ -0,0 +1,26 @@ +{ + "_base_": [ + "./t.py" + ], + "base": "_base_.item8", + "item11": {{ _base_.item8 }}, + "item12": {{ _base_.item9 }}, + "item13": {{ _base_.item10 }}, + "item14": {{ _base_.item1 }}, + "item15": { + "a": { + "b": {{ _base_.item2 }} + }, + "b": [ + {{ _base_.item3 }} + ], + "c": [{{ _base_.item4 }}], + "d": [[ + { + "e": {{ _base_.item5.a }} + } + ], + {{ _base_.item6 }}], + "e": {{ _base_.item1 }} + } +} diff --git a/tests/data/config/u.py b/tests/data/config/u.py new file mode 100644 index 000000000..bdd96a7e4 --- /dev/null +++ b/tests/data/config/u.py @@ -0,0 +1,13 @@ +_base_ = ['./t.py'] +base = '_base_.item8' +item11 = {{ _base_.item8 }} +item12 = {{ _base_.item9 }} +item13 = {{ _base_.item10 }} +item14 = {{ _base_.item1 }} +item15 = dict( + a = dict( b = {{ _base_.item2 }} ), + b = [{{ _base_.item3 }}], + c = [{{ _base_.item4 }}], + d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}], + e = {{ _base_.item1 }} +) diff --git a/tests/data/config/u.yaml b/tests/data/config/u.yaml new file mode 100644 index 000000000..d201cb926 --- /dev/null +++ b/tests/data/config/u.yaml @@ -0,0 +1,15 @@ +_base_: ["./t.py"] +base: "_base_.item8" +item11: {{ _base_.item8 }} +item12: {{ _base_.item9 }} +item13: {{ _base_.item10 }} +item14: {{ _base_.item1 }} +item15: + a: + b: {{ _base_.item2 }} + b: [{{ _base_.item3 }}] + c: [{{ _base_.item4 }}] + d: + - [e: {{ _base_.item5.a }}] + - {{ _base_.item6 }} + e: {{ _base_.item1 }} diff --git a/tests/data/config/v.py b/tests/data/config/v.py new file mode 100644 index 000000000..3d2a1a436 --- /dev/null +++ b/tests/data/config/v.py @@ -0,0 +1,11 @@ +_base_ = ['./u.py'] +item21 = {{ _base_.item11 }} +item22 = item21 +item23 = {{ _base_.item10 }} +item24 = item23 +item25 = dict( + a = dict( b = item24 ), + b = [item24], + c = [[dict(e = item22)],{{ _base_.item6 }}], + e = item21 +) diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py index 5abafe80b..44a67ba50 100644 --- a/tests/test_utils/test_config.py +++ b/tests/test_utils/test_config.py @@ -224,6 +224,81 @@ def test_merge_from_multiple_bases(): Config.fromfile(osp.join(data_path, 'config/m.py')) +def test_base_variables(): + for file in ['t.py', 't.json', 't.yaml']: + cfg_file = osp.join(data_path, f'config/{file}') + cfg = Config.fromfile(cfg_file) + assert isinstance(cfg, Config) + assert cfg.filename == cfg_file + # cfg.field + assert cfg.item1 == [1, 2] + assert cfg.item2.a == 0 + assert cfg.item3 is False + assert cfg.item4 == 'test' + assert cfg.item5 == dict(a=0, b=1) + assert cfg.item6 == [dict(a=0), dict(b=1)] + assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg.item8 == file + assert cfg.item9 == dict(a=0) + assert cfg.item10 == [3.1, 4.2, 5.3] + + # test nested base + for file in ['u.py', 'u.json', 'u.yaml']: + cfg_file = osp.join(data_path, f'config/{file}') + cfg = Config.fromfile(cfg_file) + assert isinstance(cfg, Config) + assert cfg.filename == cfg_file + # cfg.field + assert cfg.base == '_base_.item8' + assert cfg.item1 == [1, 2] + assert cfg.item2.a == 0 + assert cfg.item3 is False + assert cfg.item4 == 'test' + assert cfg.item5 == dict(a=0, b=1) + assert cfg.item6 == [dict(a=0), dict(b=1)] + assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg.item8 == 't.py' + assert cfg.item9 == dict(a=0) + assert cfg.item10 == [3.1, 4.2, 5.3] + assert cfg.item11 == 't.py' + assert cfg.item12 == dict(a=0) + assert cfg.item13 == [3.1, 4.2, 5.3] + assert cfg.item14 == [1, 2] + assert cfg.item15 == dict( + a=dict(b=dict(a=0)), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) + + # test reference assignment for py + cfg_file = osp.join(data_path, 'config/v.py') + cfg = Config.fromfile(cfg_file) + assert isinstance(cfg, Config) + assert cfg.filename == cfg_file + assert cfg.item21 == 't.py' + assert cfg.item22 == 't.py' + assert cfg.item23 == [3.1, 4.2, 5.3] + assert cfg.item24 == [3.1, 4.2, 5.3] + assert cfg.item25 == dict( + a=dict(b=[3.1, 4.2, 5.3]), + b=[[3.1, 4.2, 5.3]], + c=[[{ + 'e': 't.py' + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e='t.py') + + def test_merge_recursive_bases(): cfg_file = osp.join(data_path, 'config/f.py') cfg = Config.fromfile(cfg_file)