mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix dump config without self.filename (#202)
* fix config * add docstring and unit test * update tutorial * update tutorial * fix markdown format * fix markdown format
This commit is contained in:
parent
4dcbd269aa
commit
452b3656a1
@ -194,6 +194,71 @@ a = {{_base_.model}}
|
|||||||
# 等价于 a = dict(type='ResNet', depth=50)
|
# 等价于 a = dict(type='ResNet', depth=50)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 配置文件的导出
|
||||||
|
|
||||||
|
在启动训练脚本时,用户可能通过传参的方式来修改配置文件的部分字段,为此我们提供了 `dump`
|
||||||
|
接口来导出更改后的配置文件。与读取配置文件类似,用户可以通过 `cfg.dump('config.xxx')` 来选择导出文件的格式。`dump`
|
||||||
|
同样可以导出有继承关系的配置文件,导出的文件可以被独立使用,不再依赖于 `_base_` 中定义的文件。
|
||||||
|
|
||||||
|
基于继承一节定义的 `resnet50.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
|
||||||
|
model = dict(type='ResNet', depth=50)
|
||||||
|
```
|
||||||
|
|
||||||
|
我们将其加载后导出:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cfg = Config.fromfile('resnet50.py')
|
||||||
|
cfg.dump('resnet50_dump.py')
|
||||||
|
```
|
||||||
|
|
||||||
|
`dumped_resnet50.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
||||||
|
gpu_ids = [0, 1]
|
||||||
|
model = dict(type='ResNet', depth=50)
|
||||||
|
```
|
||||||
|
|
||||||
|
类似的,我们可以导出 json、yaml 格式的配置文件
|
||||||
|
|
||||||
|
`dumped_resnet50.yaml`
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
gpu_ids:
|
||||||
|
- 0
|
||||||
|
- 1
|
||||||
|
model:
|
||||||
|
depth: 50
|
||||||
|
type: ResNet
|
||||||
|
optimizer:
|
||||||
|
lr: 0.02
|
||||||
|
momentum: 0.9
|
||||||
|
type: SGD
|
||||||
|
weight_decay: 0.0001
|
||||||
|
```
|
||||||
|
|
||||||
|
`dumped_resnet50.json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"optimizer": {"type": "SGD", "lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001}, "gpu_ids": [0, 1], "model": {"type": "ResNet", "depth": 50}}
|
||||||
|
```
|
||||||
|
|
||||||
|
此外,`dump` 不仅能导出加载自文件的 `cfg`,还能导出加载自字典的 `cfg`
|
||||||
|
|
||||||
|
```python
|
||||||
|
cfg = Config(dict(a=1, b=2))
|
||||||
|
cfg.dump('demo.py')
|
||||||
|
```
|
||||||
|
|
||||||
|
`demo.py`
|
||||||
|
```python
|
||||||
|
a=1
|
||||||
|
b=2
|
||||||
|
```
|
||||||
|
|
||||||
## 其他进阶用法
|
## 其他进阶用法
|
||||||
|
|
||||||
这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。
|
这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。
|
||||||
|
@ -12,6 +12,7 @@ import warnings
|
|||||||
from argparse import Action, ArgumentParser, Namespace
|
from argparse import Action, ArgumentParser, Namespace
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Sequence, Tuple, Union
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from addict import Dict
|
from addict import Dict
|
||||||
@ -99,7 +100,8 @@ class Config:
|
|||||||
Args:
|
Args:
|
||||||
cfg_dict (dict, optional): A config dictionary. Defaults to None.
|
cfg_dict (dict, optional): A config dictionary. Defaults to None.
|
||||||
cfg_text (str, optional): Text of config. Defaults to None.
|
cfg_text (str, optional): Text of config. Defaults to None.
|
||||||
filename (str, optional): Name of config file. Defaults to None.
|
filename (str or Path, optional): Name of config file.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||||
@ -123,7 +125,8 @@ class Config:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
cfg_dict: dict = None,
|
cfg_dict: dict = None,
|
||||||
cfg_text: Optional[str] = None,
|
cfg_text: Optional[str] = None,
|
||||||
filename: str = None):
|
filename: Optional[Union[str, Path]] = None):
|
||||||
|
filename = str(filename) if isinstance(filename, Path) else filename
|
||||||
if cfg_dict is None:
|
if cfg_dict is None:
|
||||||
cfg_dict = dict()
|
cfg_dict = dict()
|
||||||
elif not isinstance(cfg_dict, dict):
|
elif not isinstance(cfg_dict, dict):
|
||||||
@ -145,13 +148,13 @@ class Config:
|
|||||||
super().__setattr__('_text', text)
|
super().__setattr__('_text', text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fromfile(filename: str,
|
def fromfile(filename: Union[str, Path],
|
||||||
use_predefined_variables: bool = True,
|
use_predefined_variables: bool = True,
|
||||||
import_custom_modules: bool = True) -> 'Config':
|
import_custom_modules: bool = True) -> 'Config':
|
||||||
"""Build a Config instance from config file.
|
"""Build a Config instance from config file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename (str): Name of config file.
|
filename (str or Path): Name of config file.
|
||||||
use_predefined_variables (bool, optional): Whether to use
|
use_predefined_variables (bool, optional): Whether to use
|
||||||
predefined variables. Defaults to True.
|
predefined variables. Defaults to True.
|
||||||
import_custom_modules (bool, optional): Whether to support
|
import_custom_modules (bool, optional): Whether to support
|
||||||
@ -160,6 +163,7 @@ class Config:
|
|||||||
Returns:
|
Returns:
|
||||||
Config: Config instance built from config file.
|
Config: Config instance built from config file.
|
||||||
"""
|
"""
|
||||||
|
filename = str(filename) if isinstance(filename, Path) else filename
|
||||||
cfg_dict, cfg_text = Config._file2dict(filename,
|
cfg_dict, cfg_text = Config._file2dict(filename,
|
||||||
use_predefined_variables)
|
use_predefined_variables)
|
||||||
if import_custom_modules and cfg_dict.get('custom_imports', None):
|
if import_custom_modules and cfg_dict.get('custom_imports', None):
|
||||||
@ -675,32 +679,31 @@ class Config:
|
|||||||
super().__setattr__('_filename', _filename)
|
super().__setattr__('_filename', _filename)
|
||||||
super().__setattr__('_text', _text)
|
super().__setattr__('_text', _text)
|
||||||
|
|
||||||
def dump(self, file: Optional[str] = None):
|
def dump(self, file: Optional[Union[str, Path]] = None):
|
||||||
"""Dump config to file or return config text.
|
"""Dump config to file or return config text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str, optional): If not specified, then the object
|
file (str or Path, optional): If not specified, then the object
|
||||||
is dumped to a str, otherwise to a file specified by the filename.
|
is dumped to a str, otherwise to a file specified by the filename.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str or None: Config text.
|
str or None: Config text.
|
||||||
"""
|
"""
|
||||||
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
|
file = str(file) if isinstance(file, Path) else file
|
||||||
if self.filename.endswith('.py'):
|
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
|
||||||
if file is None:
|
if file is None:
|
||||||
|
if self.filename is None or self.filename.endswith('.py'):
|
||||||
return self.pretty_text
|
return self.pretty_text
|
||||||
else:
|
else:
|
||||||
with open(file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(self.pretty_text)
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
if file is None:
|
|
||||||
file_format = self.filename.split('.')[-1]
|
file_format = self.filename.split('.')[-1]
|
||||||
return dump(cfg_dict, file_format=file_format)
|
return dump(cfg_dict, file_format=file_format)
|
||||||
else:
|
elif file.endswith('.py'):
|
||||||
dump(cfg_dict, file)
|
with open(file, 'w', encoding='utf-8') as f:
|
||||||
return None
|
f.write(self.pretty_text)
|
||||||
|
else:
|
||||||
|
file_format = file.split('.')[-1]
|
||||||
|
return dump(cfg_dict, file=file, file_format=file_format)
|
||||||
|
|
||||||
def merge_from_dict(self,
|
def merge_from_dict(self,
|
||||||
options: dict,
|
options: dict,
|
||||||
|
@ -224,8 +224,26 @@ class TestConfig:
|
|||||||
pkl_cfg_filename = tmp_path / '_pickle.pkl'
|
pkl_cfg_filename = tmp_path / '_pickle.pkl'
|
||||||
dump(cfg, pkl_cfg_filename)
|
dump(cfg, pkl_cfg_filename)
|
||||||
pkl_cfg = load(pkl_cfg_filename)
|
pkl_cfg = load(pkl_cfg_filename)
|
||||||
|
|
||||||
assert pkl_cfg._cfg_dict == cfg._cfg_dict
|
assert pkl_cfg._cfg_dict == cfg._cfg_dict
|
||||||
|
# Test dump config from dict.
|
||||||
|
cfg_dict = dict(a=1, b=2)
|
||||||
|
cfg = Config(cfg_dict)
|
||||||
|
assert cfg.pretty_text == cfg.dump()
|
||||||
|
# Test dump python format config.
|
||||||
|
dump_file = tmp_path / 'dump_from_dict.py'
|
||||||
|
cfg.dump(dump_file)
|
||||||
|
with open(dump_file, 'r') as f:
|
||||||
|
assert f.read() == 'a = 1\nb = 2\n'
|
||||||
|
# Test dump json format config.
|
||||||
|
dump_file = tmp_path / 'dump_from_dict.json'
|
||||||
|
cfg.dump(dump_file)
|
||||||
|
with open(dump_file, 'r') as f:
|
||||||
|
assert f.read() == '{"a": 1, "b": 2}'
|
||||||
|
# Test dump yaml format config.
|
||||||
|
dump_file = tmp_path / 'dump_from_dict.yaml'
|
||||||
|
cfg.dump(dump_file)
|
||||||
|
with open(dump_file, 'r') as f:
|
||||||
|
assert f.read() == 'a: 1\nb: 2\n'
|
||||||
|
|
||||||
def test_pretty_text(self, tmp_path):
|
def test_pretty_text(self, tmp_path):
|
||||||
cfg_file = osp.join(
|
cfg_file = osp.join(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user