[Fix] Fix pickle the Python style config (#1241)

This commit is contained in:
Mashiro 2023-07-11 20:37:48 +08:00 committed by GitHub
parent b2295a258c
commit 955b5712c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 143 additions and 29 deletions

View File

@ -261,6 +261,24 @@ class ConfigDict(Dict):
for key, value in merged.items():
self[key] = value
def __getstate__(self):
    """Return the raw key/value pairs of this dict as the pickle state.

    The pairs are read through ``super().items()``, i.e. the parent class
    implementation of ``items`` rather than the one defined on this class.
    """
    return dict(super().items())
def __setstate__(self, state):
    """Restore the content from ``state``.

    Every entry is written back with ``self[key] = value`` so that the
    regular item-assignment path of this object is used.
    """
    for field, content in state.items():
        self[field] = content
def __eq__(self, other):
    """Compare this object with ``other``.

    Two ``ConfigDict`` objects are compared through their ``to_dict()``
    results. A plain ``dict`` is compared against the items of this
    object. Any other type compares unequal.
    """
    # ``ConfigDict`` must be checked before ``dict``: the first branch is
    # the more specific one.
    if isinstance(other, ConfigDict):
        return other.to_dict() == self.to_dict()
    if isinstance(other, dict):
        return dict(self.items()) == other
    return False
def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built."""
@ -281,8 +299,8 @@ class ConfigDict(Dict):
return _to_dict(self)
def to_dict(self):
    """Convert the ConfigDict to a normal dictionary recursively, and
    convert the ``LazyObject`` or ``LazyAttr`` to string.

    Returns:
        dict: The result of ``_lazy2string(self, dict_type=dict)``.
    """
    return _lazy2string(self, dict_type=dict)
@ -363,12 +381,14 @@ class Config:
.. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html
""" # noqa: E501
def __init__(self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
format_python_code: bool = True):
def __init__(
self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
format_python_code: bool = True,
):
filename = str(filename) if isinstance(filename, Path) else filename
if cfg_dict is None:
cfg_dict = dict()
@ -384,6 +404,9 @@ class Config:
super().__setattr__('_cfg_dict', cfg_dict)
super().__setattr__('_filename', filename)
super().__setattr__('_format_python_code', format_python_code)
if not hasattr(self, '_imported_names'):
super().__setattr__('_imported_names', set())
if cfg_text:
text = cfg_text
elif filename:
@ -445,7 +468,8 @@ class Config:
cfg_dict,
cfg_text=cfg_text,
filename=filename,
env_variables=env_variables)
env_variables=env_variables,
)
else:
# Enable lazy import when parsing the config.
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
@ -457,15 +481,10 @@ class Config:
except Exception as e:
raise e
finally:
# disable lazy import to get the real type. See more details
# about lazy in the docstring of ConfigDict
ConfigDict.lazy = False
# delete builtin imported objects
for key, value in list(cfg_dict._to_lazy_dict().items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
cfg_dict.pop(key)
# disable lazy import to get the real type. See more details about
# lazy in the docstring of ConfigDict
cfg = Config(
cfg_dict,
filename=filename,
@ -996,7 +1015,7 @@ class Config:
# accessed, but will not be dumped by default.
with open(filename, encoding='utf-8') as f:
global_dict = {'LazyObject': LazyObject}
global_dict = {'LazyObject': LazyObject, '__file__': filename}
base_dict = {}
parsed_codes = ast.parse(f.read())
@ -1470,9 +1489,13 @@ class Config:
def __iter__(self):
    """Return an iterator over the wrapped ``_cfg_dict``."""
    wrapped = self._cfg_dict
    return iter(wrapped)
def __getstate__(
    self
) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]:
    """Return the state used to pickle this config.

    Returns:
        tuple: ``(_cfg_dict, _filename, _text, _env_variables,
        _format_python_code, _imported_names)``, in that order. The same
        order is expected by ``__setstate__``.
    """
    return (self._cfg_dict, self._filename, self._text,
            self._env_variables, self._format_python_code,
            self._imported_names)
def __deepcopy__(self, memo):
cls = self.__class__
@ -1495,12 +1518,13 @@ class Config:
copy = __copy__
def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
                                    dict, bool, set]):
    """Restore the config from the tuple produced by ``__getstate__``.

    Args:
        state (tuple): ``(_cfg_dict, _filename, _text, _env_variables,
            _format_python_code, _imported_names)``. A legacy 4-item
            state without the last two entries is accepted as well; in
            that case ``_format_python_code`` falls back to ``True`` and
            ``_imported_names`` to an empty set, the same defaults that
            ``__init__`` uses.
    """
    cfg_dict, filename, text, env_variables, *extra = state
    format_python_code = extra[0] if len(extra) > 0 else True
    imported_names = extra[1] if len(extra) > 1 else set()

    # Attributes are written through ``super().__setattr__``, in the same
    # way ``__init__`` stores them.
    super().__setattr__('_cfg_dict', cfg_dict)
    super().__setattr__('_filename', filename)
    super().__setattr__('_text', text)
    # The environment variables go to their own attribute; they must not
    # overwrite ``_text``.
    super().__setattr__('_env_variables', env_variables)
    super().__setattr__('_format_python_code', format_python_code)
    super().__setattr__('_imported_names', imported_names)
def dump(self, file: Optional[Union[str, Path]] = None):
"""Dump config to file or return config text.
@ -1616,8 +1640,8 @@ class Config:
return False
def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
"""Convert config object to dictionary and filter the imported
object."""
"""Convert config object to dictionary with lazy object, and filter the
imported object."""
res = self._cfg_dict._to_lazy_dict()
if hasattr(self, '_imported_names') and not keep_imported:
res = {
@ -1637,7 +1661,14 @@ class Config:
If you import third-party objects in the config file, all imported
objects will be converted to a string like ``torch.optim.SGD``
"""
return self._cfg_dict.to_dict()
cfg_dict = self._cfg_dict.to_dict()
if hasattr(self, '_imported_names') and not keep_imported:
cfg_dict = {
key: value
for key, value in cfg_dict.items()
if key not in self._imported_names
}
return cfg_dict
class DictAction(Action):

View File

@ -121,6 +121,16 @@ class LazyObject:
__repr__ = __str__
# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
# methods of the dumped object. If these two methods are not defined,
# LazyObject will return a `__getstate__` LazyObject or a `__setstate__`
# LazyObject instead of a real method.
# Pickle support: the instance ``__dict__`` itself is handed out as the
# state (no copy is made).
def __getstate__(self):
return self.__dict__
# Pickle support: the unpickled ``state`` replaces the instance
# ``__dict__`` wholesale.
def __setstate__(self, state):
self.__dict__ = state
class LazyAttr:
"""The attribute of the LazyObject.
@ -219,3 +229,13 @@ class LazyAttr:
return self.name
__repr__ = __str__
# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
# methods of the dumped object. If these two methods are not defined,
# LazyAttr will return a `__getstate__` LazyAttr or a `__setstate__`
# LazyAttr instead of a real method.
# Pickle support: the instance ``__dict__`` itself is handed out as the
# state (no copy is made).
def __getstate__(self):
return self.__dict__
# Pickle support: the unpickled ``state`` replaces the instance
# ``__dict__`` wholesale.
def __setstate__(self, state):
self.__dict__ = state

View File

@ -165,6 +165,8 @@ def _is_builtin_module(module_name: str) -> bool:
return False
if module_name.startswith('mmengine.config'):
return True
if module_name in sys.builtin_module_names:
return True
spec = find_spec(module_name.split('.')[0])
# Module not found
if spec is None:
@ -314,6 +316,15 @@ class ImportTransformer(ast.NodeTransformer):
# Built-in modules will not be parsed as LazyObject
module = f'{node.level*"."}{node.module}'
if _is_builtin_module(module):
# Make sure builtin module will be added into `self.imported_obj`
for alias in node.names:
if alias.asname is not None:
self.imported_obj.add(alias.asname)
elif alias.name == '*':
raise ConfigParsingError(
'Cannot import * from non-base config')
else:
self.imported_obj.add(alias.name)
return node
if module in self.base_dict:
@ -409,6 +420,8 @@ class ImportTransformer(ast.NodeTransformer):
alias = alias_list[0]
if alias.asname is not None:
self.imported_obj.add(alias.asname)
if _is_builtin_module(alias.name.split('.')[0]):
return node
return ast.parse( # type: ignore
f'{alias.asname} = LazyObject('
f'"{alias.name}",'

View File

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from functools import partial
from itertools import chain
from os.path import basename
from os.path import exists as ex
from os.path import splitext
import numpy as np
# Each assignment below calls a name brought in by one of the import forms
# above, so the resulting config fields hold the computed values.
# NOTE(review): ``numpy`` is imported but unused here; presumably it is the
# third-party import that gets mixed with the built-in ones -- confirm.

# Module alias: ``import os.path as osp``.
path = osp.join('a', 'b')
# Plain ``from ... import``; tuple unpacking yields two config fields.
name, suffix = splitext('a/b.py')
# ``from itertools import chain``.
chained = list(chain([1, 2], [3, 4]))
# Aliased ``from ... import ... as``; relies on ``__file__`` being defined
# when this config is executed.
existed = ex(__file__)
# ``partial`` wrapping ``basename``; evaluates to this file's base name.
cfgname = partial(basename, __file__)()

View File

@ -3,6 +3,7 @@ import argparse
import copy
import os
import os.path as osp
import pickle
import platform
import sys
import tempfile
@ -951,6 +952,19 @@ class TestConfig:
assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss)
DefaultScope._instance_dict.pop('test1')
def test_pickle(self):
    """A config must survive a ``pickle`` round trip unchanged."""
    config_files = (
        # Text style config
        'config/py_config/test_py_base.py',
        # Config from the ``lazy_module_config`` test data
        'config/lazy_module_config/toy_model.py',
    )
    for relative_path in config_files:
        original = Config.fromfile(osp.join(self.data_path, relative_path))
        restored = pickle.loads(pickle.dumps(original))
        assert restored.__dict__ == original.__dict__
def test_lazy_import(self, tmp_path):
lazy_import_cfg_path = osp.join(
self.data_path, 'config/lazy_module_config/toy_model.py')
@ -1036,6 +1050,26 @@ error_attr = mmengine.error_attr
osp.join(self.data_path,
'config/lazy_module_config/error_mix_using2.py'))
cfg = Config.fromfile(
osp.join(self.data_path,
'config/lazy_module_config/test_mix_builtin.py'))
assert cfg.path == osp.join('a', 'b')
assert cfg.name == 'a/b'
assert cfg.suffix == '.py'
assert cfg.chained == [1, 2, 3, 4]
assert cfg.existed
assert cfg.cfgname == 'test_mix_builtin.py'
cfg_dict = cfg.to_dict()
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
cfg.dump(dumped_cfg_path)
dumped_cfg = Config.fromfile(dumped_cfg_path)
assert set(dumped_cfg.keys()) == {
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname'
}
assert dumped_cfg.to_dict() == cfg.to_dict()
class TestConfigDict(TestCase):