From 452b3656a1d5abe64f815d3f15a705d40c3ffe3a Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 27 Apr 2022 19:43:12 +0800 Subject: [PATCH] [Fix] Fix dump config without self.filename (#202) * fix config * add docstring and unit test * update tutorial * update tutorial * fix markdown format * fix markdown format --- docs/zh_cn/tutorials/config.md | 65 ++++++++++++++++++++++++++++++++ mmengine/config/config.py | 37 +++++++++--------- tests/test_config/test_config.py | 20 +++++++++- 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/docs/zh_cn/tutorials/config.md b/docs/zh_cn/tutorials/config.md index 8e530490..72ed898e 100644 --- a/docs/zh_cn/tutorials/config.md +++ b/docs/zh_cn/tutorials/config.md @@ -194,6 +194,71 @@ a = {{_base_.model}} # 等价于 a = dict(type='ResNet', depth=50) ``` +## 配置文件的导出 + +在启动训练脚本时,用户可能通过传参的方式来修改配置文件的部分字段,为此我们提供了 `dump` +接口来导出更改后的配置文件。与读取配置文件类似,用户可以通过 `cfg.dump('config.xxx')` 来选择导出文件的格式。`dump` +同样可以导出有继承关系的配置文件,导出的文件可以被独立使用,不再依赖于 `_base_` 中定义的文件。 + +基于继承一节定义的 `resnet50.py` + +```python +_base_ = ['optimizer_cfg.py', 'runtime_cfg.py'] +model = dict(type='ResNet', depth=50) +``` + +我们将其加载后导出: + +```python +cfg = Config.fromfile('resnet50.py') +cfg.dump('resnet50_dump.py') +``` + +`dumped_resnet50.py` + +```python +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +gpu_ids = [0, 1] +model = dict(type='ResNet', depth=50) +``` + +类似的,我们可以导出 json、yaml 格式的配置文件 + +`dumped_resnet50.yaml` + +```yaml +gpu_ids: +- 0 +- 1 +model: + depth: 50 + type: ResNet +optimizer: + lr: 0.02 + momentum: 0.9 + type: SGD + weight_decay: 0.0001 +``` + +`dumped_resnet50.json` + +```json +{"optimizer": {"type": "SGD", "lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001}, "gpu_ids": [0, 1], "model": {"type": "ResNet", "depth": 50}} +``` + +此外,`dump` 不仅能导出加载自文件的 `cfg`,还能导出加载自字典的 `cfg` + +```python +cfg = Config(dict(a=1, b=2)) +cfg.dump('demo.py') +``` + +`demo.py` +```python +a=1 +b=2 +``` + ## 其他进阶用法 这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。 diff --git a/mmengine/config/config.py b/mmengine/config/config.py index d18164dc..d90d7423 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -12,6 +12,7 @@ import warnings from argparse import Action, ArgumentParser, Namespace from collections import abc from importlib import import_module +from pathlib import Path from typing import Any, Optional, Sequence, Tuple, Union from addict import Dict @@ -99,7 +100,8 @@ class Config: Args: cfg_dict (dict, optional): A config dictionary. Defaults to None. cfg_text (str, optional): Text of config. Defaults to None. - filename (str, optional): Name of config file. Defaults to None. + filename (str or Path, optional): Name of config file. + Defaults to None. Examples: >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) @@ -123,7 +125,8 @@ class Config: def __init__(self, cfg_dict: dict = None, cfg_text: Optional[str] = None, - filename: str = None): + filename: Optional[Union[str, Path]] = None): + filename = str(filename) if isinstance(filename, Path) else filename if cfg_dict is None: cfg_dict = dict() elif not isinstance(cfg_dict, dict): @@ -145,13 +148,13 @@ class Config: super().__setattr__('_text', text) @staticmethod - def fromfile(filename: str, + def fromfile(filename: Union[str, Path], use_predefined_variables: bool = True, import_custom_modules: bool = True) -> 'Config': """Build a Config instance from config file. Args: - filename (str): Name of config file. + filename (str or Path): Name of config file. use_predefined_variables (bool, optional): Whether to use predefined variables. Defaults to True. import_custom_modules (bool, optional): Whether to support @@ -160,6 +163,7 @@ class Config: Returns: Config: Config instance built from config file. """ + filename = str(filename) if isinstance(filename, Path) else filename cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) if import_custom_modules and cfg_dict.get('custom_imports', None): @@ -675,32 +679,31 @@ class Config: super().__setattr__('_filename', _filename) super().__setattr__('_text', _text) - def dump(self, file: Optional[str] = None): + def dump(self, file: Optional[Union[str, Path]] = None): """Dump config to file or return config text. Args: - file (str, optional): If not specified, then the object + file (str or Path, optional): If not specified, then the object is dumped to a str, otherwise to a file specified by the filename. Defaults to None. Returns: str or None: Config text. """ - cfg_dict = super().__getattribute__('_cfg_dict').to_dict() - if self.filename.endswith('.py'): - if file is None: + file = str(file) if isinstance(file, Path) else file + cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() + if file is None: + if self.filename is None or self.filename.endswith('.py'): return self.pretty_text else: - with open(file, 'w', encoding='utf-8') as f: - f.write(self.pretty_text) - return None - else: - if file is None: file_format = self.filename.split('.')[-1] return dump(cfg_dict, file_format=file_format) - else: - dump(cfg_dict, file) - return None + elif file.endswith('.py'): + with open(file, 'w', encoding='utf-8') as f: + f.write(self.pretty_text) + else: + file_format = file.split('.')[-1] + return dump(cfg_dict, file=file, file_format=file_format) def merge_from_dict(self, options: dict, diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index bd4078b2..c4cccb2b 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -224,8 +224,26 @@ class TestConfig: pkl_cfg_filename = tmp_path / '_pickle.pkl' dump(cfg, pkl_cfg_filename) pkl_cfg = load(pkl_cfg_filename) - assert pkl_cfg._cfg_dict == cfg._cfg_dict + # Test dump config from dict. + cfg_dict = dict(a=1, b=2) + cfg = Config(cfg_dict) + assert cfg.pretty_text == cfg.dump() + # Test dump python format config. + dump_file = tmp_path / 'dump_from_dict.py' + cfg.dump(dump_file) + with open(dump_file, 'r') as f: + assert f.read() == 'a = 1\nb = 2\n' + # Test dump json format config. + dump_file = tmp_path / 'dump_from_dict.json' + cfg.dump(dump_file) + with open(dump_file, 'r') as f: + assert f.read() == '{"a": 1, "b": 2}' + # Test dump yaml format config. + dump_file = tmp_path / 'dump_from_dict.yaml' + cfg.dump(dump_file) + with open(dump_file, 'r') as f: + assert f.read() == 'a: 1\nb: 2\n' def test_pretty_text(self, tmp_path): cfg_file = osp.join(