mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Compare the difference of two configs (#1260)
This commit is contained in:
parent
78205c3254
commit
c8a1264568
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import ast
|
import ast
|
||||||
import copy
|
import copy
|
||||||
|
import difflib
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import platform
|
import platform
|
||||||
@ -17,6 +18,8 @@ from pathlib import Path
|
|||||||
from typing import Any, Optional, Sequence, Tuple, Union
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from addict import Dict
|
from addict import Dict
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.text import Text
|
||||||
from yapf.yapflib.yapf_api import FormatCode
|
from yapf.yapflib.yapf_api import FormatCode
|
||||||
|
|
||||||
from mmengine.fileio import dump, load
|
from mmengine.fileio import dump, load
|
||||||
@ -1432,7 +1435,8 @@ class Config:
|
|||||||
use_mapping = _contain_invalid_identifier(input_dict)
|
use_mapping = _contain_invalid_identifier(input_dict)
|
||||||
if use_mapping:
|
if use_mapping:
|
||||||
r += '{'
|
r += '{'
|
||||||
for idx, (k, v) in enumerate(input_dict.items()):
|
for idx, (k, v) in enumerate(
|
||||||
|
sorted(input_dict.items(), key=lambda x: str(x[0]))):
|
||||||
is_last = idx >= len(input_dict) - 1
|
is_last = idx >= len(input_dict) - 1
|
||||||
end = '' if outest_level or is_last else ','
|
end = '' if outest_level or is_last else ','
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
@ -1605,6 +1609,36 @@ class Config:
|
|||||||
Config._merge_a_into_b(
|
Config._merge_a_into_b(
|
||||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str:
|
||||||
|
if isinstance(cfg1, str):
|
||||||
|
cfg1 = Config.fromfile(cfg1)
|
||||||
|
|
||||||
|
if isinstance(cfg2, str):
|
||||||
|
cfg2 = Config.fromfile(cfg2)
|
||||||
|
|
||||||
|
res = difflib.unified_diff(
|
||||||
|
cfg1.pretty_text.split('\n'), cfg2.pretty_text.split('\n'))
|
||||||
|
|
||||||
|
# Convert into rich format for better visualization
|
||||||
|
console = Console()
|
||||||
|
text = Text()
|
||||||
|
for line in res:
|
||||||
|
if line.startswith('+'):
|
||||||
|
color = 'bright_green'
|
||||||
|
elif line.startswith('-'):
|
||||||
|
color = 'bright_red'
|
||||||
|
else:
|
||||||
|
color = 'bright_white'
|
||||||
|
_text = Text(line + '\n')
|
||||||
|
_text.stylize(color)
|
||||||
|
text.append(_text)
|
||||||
|
|
||||||
|
with console.capture() as capture:
|
||||||
|
console.print(text)
|
||||||
|
|
||||||
|
return capture.get()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_lazy_import(filename: str) -> bool:
|
def _is_lazy_import(filename: str) -> bool:
|
||||||
if not filename.endswith('.py'):
|
if not filename.endswith('.py'):
|
||||||
|
3
tests/data/config/py_config/test_diff_1.py
Normal file
3
tests/data/config/py_config/test_diff_1.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
a = 1
|
||||||
|
b = 2
|
3
tests/data/config/py_config/test_diff_2.py
Normal file
3
tests/data/config/py_config/test_diff_2.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
b = 3
|
||||||
|
a = 1
|
@ -183,6 +183,23 @@ class TestConfig:
|
|||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
cfg.merge_from_dict(input_options, allow_list_keys=True)
|
cfg.merge_from_dict(input_options, allow_list_keys=True)
|
||||||
|
|
||||||
|
def test_diff(self):
|
||||||
|
cfg1 = Config(dict(a=1, b=2))
|
||||||
|
cfg2 = Config(dict(a=1, b=3))
|
||||||
|
|
||||||
|
diff_str = \
|
||||||
|
'--- \n\n+++ \n\n@@ -1,3 +1,3 @@\n\n a = 1\n-b = 2\n+b = 3\n \n\n'
|
||||||
|
|
||||||
|
assert Config.diff(cfg1, cfg2) == diff_str
|
||||||
|
|
||||||
|
cfg1_file = osp.join(self.data_path, 'config/py_config/test_diff_1.py')
|
||||||
|
cfg1 = Config.fromfile(cfg1_file)
|
||||||
|
|
||||||
|
cfg2_file = osp.join(self.data_path, 'config/py_config/test_diff_2.py')
|
||||||
|
cfg2 = Config.fromfile(cfg2_file)
|
||||||
|
|
||||||
|
assert Config.diff(cfg1, cfg2) == diff_str
|
||||||
|
|
||||||
def test_auto_argparser(self):
|
def test_auto_argparser(self):
|
||||||
# Temporarily make sys.argv only has one argument and keep backups
|
# Temporarily make sys.argv only has one argument and keep backups
|
||||||
tmp = sys.argv[1:]
|
tmp = sys.argv[1:]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user