diff --git a/mmengine/config/config.py b/mmengine/config/config.py index ddc24665..7d9e9423 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -261,6 +261,24 @@ class ConfigDict(Dict): for key, value in merged.items(): self[key] = value + def __getstate__(self): + state = {} + for key, value in super().items(): + state[key] = value + return state + + def __setstate__(self, state): + for key, value in state.items(): + self[key] = value + + def __eq__(self, other): + if isinstance(other, ConfigDict): + return other.to_dict() == self.to_dict() + elif isinstance(other, dict): + return {k: v for k, v in self.items()} == other + else: + return False + def _to_lazy_dict(self): """Convert the ConfigDict to a normal dictionary recursively, and keep the ``LazyObject`` or ``LazyAttr`` object not built.""" @@ -281,8 +299,8 @@ class ConfigDict(Dict): return _to_dict(self) def to_dict(self): - """Convert the ConfigDict to a normal dictionary recursively, and keep - the ``LazyObject`` or ``LazyAttr`` object not built.""" + """Convert the ConfigDict to a normal dictionary recursively, and + convert the ``LazyObject`` or ``LazyAttr`` to string.""" return _lazy2string(self, dict_type=dict) @@ -363,12 +381,14 @@ class Config: .. 
_config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html """ # noqa: E501 - def __init__(self, - cfg_dict: dict = None, - cfg_text: Optional[str] = None, - filename: Optional[Union[str, Path]] = None, - env_variables: Optional[dict] = None, - format_python_code: bool = True): + def __init__( + self, + cfg_dict: dict = None, + cfg_text: Optional[str] = None, + filename: Optional[Union[str, Path]] = None, + env_variables: Optional[dict] = None, + format_python_code: bool = True, + ): filename = str(filename) if isinstance(filename, Path) else filename if cfg_dict is None: cfg_dict = dict() @@ -384,6 +404,9 @@ class Config: super().__setattr__('_cfg_dict', cfg_dict) super().__setattr__('_filename', filename) super().__setattr__('_format_python_code', format_python_code) + if not hasattr(self, '_imported_names'): + super().__setattr__('_imported_names', set()) + if cfg_text: text = cfg_text elif filename: @@ -445,7 +468,8 @@ class Config: cfg_dict, cfg_text=cfg_text, filename=filename, - env_variables=env_variables) + env_variables=env_variables, + ) else: # Enable lazy import when parsing the config. # Using try-except to make sure ``ConfigDict.lazy`` will be reset @@ -457,15 +481,10 @@ class Config: except Exception as e: raise e finally: + # disable lazy import to get the real type. See more details + # about lazy in the docstring of ConfigDict ConfigDict.lazy = False - # delete builtin imported objects - for key, value in list(cfg_dict._to_lazy_dict().items()): - if isinstance(value, (types.FunctionType, types.ModuleType)): - cfg_dict.pop(key) - - # disable lazy import to get the real type. See more details about - # lazy in the docstring of ConfigDict cfg = Config( cfg_dict, filename=filename, @@ -996,7 +1015,7 @@ class Config: # accessed, but will not be dumped by default. 
with open(filename, encoding='utf-8') as f: - global_dict = {'LazyObject': LazyObject} + global_dict = {'LazyObject': LazyObject, '__file__': filename} base_dict = {} parsed_codes = ast.parse(f.read()) @@ -1470,9 +1489,13 @@ class Config: def __iter__(self): return iter(self._cfg_dict) - def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]: - return (self._cfg_dict, self._filename, self._text, - self._env_variables) + def __getstate__( + self + ) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]: + state = (self._cfg_dict, self._filename, self._text, + self._env_variables, self._format_python_code, + self._imported_names) + return state def __deepcopy__(self, memo): cls = self.__class__ @@ -1495,12 +1518,13 @@ class Config: copy = __copy__ def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], - dict]): - _cfg_dict, _filename, _text, _env_variables = state - super().__setattr__('_cfg_dict', _cfg_dict) - super().__setattr__('_filename', _filename) - super().__setattr__('_text', _text) - super().__setattr__('_text', _env_variables) + dict, bool, set]): + super().__setattr__('_cfg_dict', state[0]) + super().__setattr__('_filename', state[1]) + super().__setattr__('_text', state[2]) + super().__setattr__('_env_variables', state[3]) + super().__setattr__('_format_python_code', state[4]) + super().__setattr__('_imported_names', state[5]) def dump(self, file: Optional[Union[str, Path]] = None): """Dump config to file or return config text. 
@@ -1616,8 +1640,8 @@ class Config: return False def _to_lazy_dict(self, keep_imported: bool = False) -> dict: - """Convert config object to dictionary and filter the imported - object.""" + """Convert config object to dictionary with lazy object, and filter the + imported object.""" res = self._cfg_dict._to_lazy_dict() if hasattr(self, '_imported_names') and not keep_imported: res = { @@ -1637,7 +1661,14 @@ class Config: If you import third-party objects in the config file, all imported objects will be converted to a string like ``torch.optim.SGD`` """ - return self._cfg_dict.to_dict() + cfg_dict = self._cfg_dict.to_dict() + if hasattr(self, '_imported_names') and not keep_imported: + cfg_dict = { + key: value + for key, value in cfg_dict.items() + if key not in self._imported_names + } + return cfg_dict class DictAction(Action): diff --git a/mmengine/config/lazy.py b/mmengine/config/lazy.py index 9018b49a..e83cce7c 100644 --- a/mmengine/config/lazy.py +++ b/mmengine/config/lazy.py @@ -121,6 +121,16 @@ class LazyObject: __repr__ = __str__ + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. If these two methods are not defined, + # LazyObject will return a `__getstate__` LazyObject or `__setstate__` + # LazyObject. + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + class LazyAttr: """The attribute of the LazyObject. @@ -219,3 +229,13 @@ class LazyAttr: return self.name __repr__ = __str__ + + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. If these two methods are not defined, + # LazyAttr will return a `__getstate__` LazyAttr or `__setstate__` + # LazyAttr. 
+ def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py index 6bd6fd94..7f56fd57 100644 --- a/mmengine/config/utils.py +++ b/mmengine/config/utils.py @@ -165,6 +165,8 @@ def _is_builtin_module(module_name: str) -> bool: return False if module_name.startswith('mmengine.config'): return True + if module_name in sys.builtin_module_names: + return True spec = find_spec(module_name.split('.')[0]) # Module not found if spec is None: @@ -314,6 +316,15 @@ class ImportTransformer(ast.NodeTransformer): # Built-in modules will not be parsed as LazyObject module = f'{node.level*"."}{node.module}' if _is_builtin_module(module): + # Make sure builtin module will be added into `self.imported_obj` + for alias in node.names: + if alias.asname is not None: + self.imported_obj.add(alias.asname) + elif alias.name == '*': + raise ConfigParsingError( + 'Cannot import * from non-base config') + else: + self.imported_obj.add(alias.name) return node if module in self.base_dict: @@ -409,6 +420,8 @@ class ImportTransformer(ast.NodeTransformer): alias = alias_list[0] if alias.asname is not None: self.imported_obj.add(alias.asname) + if _is_builtin_module(alias.name.split('.')[0]): + return node return ast.parse( # type: ignore f'{alias.asname} = LazyObject(' f'"{alias.name}",' diff --git a/tests/data/config/lazy_module_config/test_mix_builtin.py b/tests/data/config/lazy_module_config/test_mix_builtin.py new file mode 100644 index 00000000..e36da58a --- /dev/null +++ b/tests/data/config/lazy_module_config/test_mix_builtin.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +from functools import partial +from itertools import chain +from os.path import basename +from os.path import exists as ex +from os.path import splitext + +import numpy as np + +path = osp.join('a', 'b') +name, suffix = splitext('a/b.py') +chained = list(chain([1, 2], [3, 4])) +existed = ex(__file__) +cfgname = partial(basename, __file__)() + diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 2e97b0af..660d6b0a 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -3,6 +3,7 @@ import argparse import copy import os import os.path as osp +import pickle import platform import sys import tempfile @@ -951,6 +952,19 @@ class TestConfig: assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss) DefaultScope._instance_dict.pop('test1') + def test_pickle(self): + # Text style config + cfg_path = osp.join(self.data_path, 'config/py_config/test_py_base.py') + cfg = Config.fromfile(cfg_path) + pickled = pickle.loads(pickle.dumps(cfg)) + assert pickled.__dict__ == cfg.__dict__ + + cfg_path = osp.join(self.data_path, + 'config/lazy_module_config/toy_model.py') + cfg = Config.fromfile(cfg_path) + pickled = pickle.loads(pickle.dumps(cfg)) + assert pickled.__dict__ == cfg.__dict__ + def test_lazy_import(self, tmp_path): lazy_import_cfg_path = osp.join( self.data_path, 'config/lazy_module_config/toy_model.py') @@ -1036,6 +1050,26 @@ error_attr = mmengine.error_attr osp.join(self.data_path, 'config/lazy_module_config/error_mix_using2.py')) + cfg = Config.fromfile( + osp.join(self.data_path, + 'config/lazy_module_config/test_mix_builtin.py')) + assert cfg.path == osp.join('a', 'b') + assert cfg.name == 'a/b' + assert cfg.suffix == '.py' + assert cfg.chained == [1, 2, 3, 4] + assert cfg.existed + assert cfg.cfgname == 'test_mix_builtin.py' + + cfg_dict = cfg.to_dict() + dumped_cfg_path = tmp_path / 'test_dump_lazy.py' + cfg.dump(dumped_cfg_path) + dumped_cfg = 
Config.fromfile(dumped_cfg_path) + + assert set(dumped_cfg.keys()) == { + 'path', 'name', 'suffix', 'chained', 'existed', 'cfgname' + } + assert dumped_cfg.to_dict() == cfg.to_dict() + class TestConfigDict(TestCase):