mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix dump config without self.filename (#202)
* fix config * add docstring and unit test * update tutorial * update tutorial * fix markdown format * fix markdown format
This commit is contained in:
parent
4dcbd269aa
commit
452b3656a1
@ -194,6 +194,71 @@ a = {{_base_.model}}
|
||||
# 等价于 a = dict(type='ResNet', depth=50)
|
||||
```
|
||||
|
||||
## 配置文件的导出
|
||||
|
||||
在启动训练脚本时,用户可能通过传参的方式来修改配置文件的部分字段,为此我们提供了 `dump`
|
||||
接口来导出更改后的配置文件。与读取配置文件类似,用户可以通过 `cfg.dump('config.xxx')` 来选择导出文件的格式。`dump`
|
||||
同样可以导出有继承关系的配置文件,导出的文件可以被独立使用,不再依赖于 `_base_` 中定义的文件。
|
||||
|
||||
基于继承一节定义的 `resnet50.py`
|
||||
|
||||
```python
|
||||
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
|
||||
model = dict(type='ResNet', depth=50)
|
||||
```
|
||||
|
||||
我们将其加载后导出:
|
||||
|
||||
```python
|
||||
cfg = Config.fromfile('resnet50.py')
|
||||
cfg.dump('resnet50_dump.py')
|
||||
```
|
||||
|
||||
`dumped_resnet50.py`
|
||||
|
||||
```python
|
||||
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
||||
gpu_ids = [0, 1]
|
||||
model = dict(type='ResNet', depth=50)
|
||||
```
|
||||
|
||||
类似的,我们可以导出 json、yaml 格式的配置文件
|
||||
|
||||
`dumped_resnet50.yaml`
|
||||
|
||||
```yaml
|
||||
gpu_ids:
|
||||
- 0
|
||||
- 1
|
||||
model:
|
||||
depth: 50
|
||||
type: ResNet
|
||||
optimizer:
|
||||
lr: 0.02
|
||||
momentum: 0.9
|
||||
type: SGD
|
||||
weight_decay: 0.0001
|
||||
```
|
||||
|
||||
`dumped_resnet50.json`
|
||||
|
||||
```json
|
||||
{"optimizer": {"type": "SGD", "lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001}, "gpu_ids": [0, 1], "model": {"type": "ResNet", "depth": 50}}
|
||||
```
|
||||
|
||||
此外,`dump` 不仅能导出加载自文件的 `cfg`,还能导出加载自字典的 `cfg`
|
||||
|
||||
```python
|
||||
cfg = Config(dict(a=1, b=2))
|
||||
cfg.dump('demo.py')
|
||||
```
|
||||
|
||||
`demo.py`
|
||||
```python
|
||||
a=1
|
||||
b=2
|
||||
```
|
||||
|
||||
## 其他进阶用法
|
||||
|
||||
这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。
|
||||
|
@ -12,6 +12,7 @@ import warnings
|
||||
from argparse import Action, ArgumentParser, Namespace
|
||||
from collections import abc
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
from addict import Dict
|
||||
@ -99,7 +100,8 @@ class Config:
|
||||
Args:
|
||||
cfg_dict (dict, optional): A config dictionary. Defaults to None.
|
||||
cfg_text (str, optional): Text of config. Defaults to None.
|
||||
filename (str, optional): Name of config file. Defaults to None.
|
||||
filename (str or Path, optional): Name of config file.
|
||||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
@ -123,7 +125,8 @@ class Config:
|
||||
def __init__(self,
|
||||
cfg_dict: dict = None,
|
||||
cfg_text: Optional[str] = None,
|
||||
filename: str = None):
|
||||
filename: Optional[Union[str, Path]] = None):
|
||||
filename = str(filename) if isinstance(filename, Path) else filename
|
||||
if cfg_dict is None:
|
||||
cfg_dict = dict()
|
||||
elif not isinstance(cfg_dict, dict):
|
||||
@ -145,13 +148,13 @@ class Config:
|
||||
super().__setattr__('_text', text)
|
||||
|
||||
@staticmethod
|
||||
def fromfile(filename: str,
|
||||
def fromfile(filename: Union[str, Path],
|
||||
use_predefined_variables: bool = True,
|
||||
import_custom_modules: bool = True) -> 'Config':
|
||||
"""Build a Config instance from config file.
|
||||
|
||||
Args:
|
||||
filename (str): Name of config file.
|
||||
filename (str or Path): Name of config file.
|
||||
use_predefined_variables (bool, optional): Whether to use
|
||||
predefined variables. Defaults to True.
|
||||
import_custom_modules (bool, optional): Whether to support
|
||||
@ -160,6 +163,7 @@ class Config:
|
||||
Returns:
|
||||
Config: Config instance built from config file.
|
||||
"""
|
||||
filename = str(filename) if isinstance(filename, Path) else filename
|
||||
cfg_dict, cfg_text = Config._file2dict(filename,
|
||||
use_predefined_variables)
|
||||
if import_custom_modules and cfg_dict.get('custom_imports', None):
|
||||
@ -675,32 +679,31 @@ class Config:
|
||||
super().__setattr__('_filename', _filename)
|
||||
super().__setattr__('_text', _text)
|
||||
|
||||
def dump(self, file: Optional[str] = None):
|
||||
def dump(self, file: Optional[Union[str, Path]] = None):
|
||||
"""Dump config to file or return config text.
|
||||
|
||||
Args:
|
||||
file (str, optional): If not specified, then the object
|
||||
file (str or Path, optional): If not specified, then the object
|
||||
is dumped to a str, otherwise to a file specified by the filename.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
str or None: Config text.
|
||||
"""
|
||||
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
|
||||
if self.filename.endswith('.py'):
|
||||
if file is None:
|
||||
file = str(file) if isinstance(file, Path) else file
|
||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
|
||||
if file is None:
|
||||
if self.filename is None or self.filename.endswith('.py'):
|
||||
return self.pretty_text
|
||||
else:
|
||||
with open(file, 'w', encoding='utf-8') as f:
|
||||
f.write(self.pretty_text)
|
||||
return None
|
||||
else:
|
||||
if file is None:
|
||||
file_format = self.filename.split('.')[-1]
|
||||
return dump(cfg_dict, file_format=file_format)
|
||||
else:
|
||||
dump(cfg_dict, file)
|
||||
return None
|
||||
elif file.endswith('.py'):
|
||||
with open(file, 'w', encoding='utf-8') as f:
|
||||
f.write(self.pretty_text)
|
||||
else:
|
||||
file_format = file.split('.')[-1]
|
||||
return dump(cfg_dict, file=file, file_format=file_format)
|
||||
|
||||
def merge_from_dict(self,
|
||||
options: dict,
|
||||
|
@ -224,8 +224,26 @@ class TestConfig:
|
||||
pkl_cfg_filename = tmp_path / '_pickle.pkl'
|
||||
dump(cfg, pkl_cfg_filename)
|
||||
pkl_cfg = load(pkl_cfg_filename)
|
||||
|
||||
assert pkl_cfg._cfg_dict == cfg._cfg_dict
|
||||
# Test dump config from dict.
|
||||
cfg_dict = dict(a=1, b=2)
|
||||
cfg = Config(cfg_dict)
|
||||
assert cfg.pretty_text == cfg.dump()
|
||||
# Test dump python format config.
|
||||
dump_file = tmp_path / 'dump_from_dict.py'
|
||||
cfg.dump(dump_file)
|
||||
with open(dump_file, 'r') as f:
|
||||
assert f.read() == 'a = 1\nb = 2\n'
|
||||
# Test dump json format config.
|
||||
dump_file = tmp_path / 'dump_from_dict.json'
|
||||
cfg.dump(dump_file)
|
||||
with open(dump_file, 'r') as f:
|
||||
assert f.read() == '{"a": 1, "b": 2}'
|
||||
# Test dump yaml format config.
|
||||
dump_file = tmp_path / 'dump_from_dict.yaml'
|
||||
cfg.dump(dump_file)
|
||||
with open(dump_file, 'r') as f:
|
||||
assert f.read() == 'a: 1\nb: 2\n'
|
||||
|
||||
def test_pretty_text(self, tmp_path):
|
||||
cfg_file = osp.join(
|
||||
|
Loading…
x
Reference in New Issue
Block a user