[Feature] Compare the difference of two configs (#1260)

This commit is contained in:
vugia truong 2023-07-26 16:48:59 +09:00 committed by GitHub
parent 78205c3254
commit c8a1264568
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 1 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import ast import ast
import copy import copy
import difflib
import os import os
import os.path as osp import os.path as osp
import platform import platform
@ -17,6 +18,8 @@ from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union from typing import Any, Optional, Sequence, Tuple, Union
from addict import Dict from addict import Dict
from rich.console import Console
from rich.text import Text
from yapf.yapflib.yapf_api import FormatCode from yapf.yapflib.yapf_api import FormatCode
from mmengine.fileio import dump, load from mmengine.fileio import dump, load
@ -1432,7 +1435,8 @@ class Config:
use_mapping = _contain_invalid_identifier(input_dict) use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping: if use_mapping:
r += '{' r += '{'
for idx, (k, v) in enumerate(input_dict.items()): for idx, (k, v) in enumerate(
sorted(input_dict.items(), key=lambda x: str(x[0]))):
is_last = idx >= len(input_dict) - 1 is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ',' end = '' if outest_level or is_last else ','
if isinstance(v, dict): if isinstance(v, dict):
@ -1605,6 +1609,36 @@ class Config:
Config._merge_a_into_b( Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
@staticmethod
def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str:
if isinstance(cfg1, str):
cfg1 = Config.fromfile(cfg1)
if isinstance(cfg2, str):
cfg2 = Config.fromfile(cfg2)
res = difflib.unified_diff(
cfg1.pretty_text.split('\n'), cfg2.pretty_text.split('\n'))
# Convert into rich format for better visualization
console = Console()
text = Text()
for line in res:
if line.startswith('+'):
color = 'bright_green'
elif line.startswith('-'):
color = 'bright_red'
else:
color = 'bright_white'
_text = Text(line + '\n')
_text.stylize(color)
text.append(_text)
with console.capture() as capture:
console.print(text)
return capture.get()
@staticmethod @staticmethod
def _is_lazy_import(filename: str) -> bool: def _is_lazy_import(filename: str) -> bool:
if not filename.endswith('.py'): if not filename.endswith('.py'):

View File

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
a = 1
b = 2

View File

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
b = 3
a = 1

View File

@ -183,6 +183,23 @@ class TestConfig:
with pytest.raises(KeyError): with pytest.raises(KeyError):
cfg.merge_from_dict(input_options, allow_list_keys=True) cfg.merge_from_dict(input_options, allow_list_keys=True)
def test_diff(self):
cfg1 = Config(dict(a=1, b=2))
cfg2 = Config(dict(a=1, b=3))
diff_str = \
'--- \n\n+++ \n\n@@ -1,3 +1,3 @@\n\n a = 1\n-b = 2\n+b = 3\n \n\n'
assert Config.diff(cfg1, cfg2) == diff_str
cfg1_file = osp.join(self.data_path, 'config/py_config/test_diff_1.py')
cfg1 = Config.fromfile(cfg1_file)
cfg2_file = osp.join(self.data_path, 'config/py_config/test_diff_2.py')
cfg2 = Config.fromfile(cfg2_file)
assert Config.diff(cfg1, cfg2) == diff_str
def test_auto_argparser(self): def test_auto_argparser(self):
# Temporarily make sys.argv only has one argument and keep backups # Temporarily make sys.argv only has one argument and keep backups
tmp = sys.argv[1:] tmp = sys.argv[1:]