[Fix] Fix pickle the Python style config (#1241)

This commit is contained in:
Mashiro 2023-07-11 20:37:48 +08:00 committed by GitHub
parent b2295a258c
commit 955b5712c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 143 additions and 29 deletions

View File

@ -261,6 +261,24 @@ class ConfigDict(Dict):
for key, value in merged.items(): for key, value in merged.items():
self[key] = value self[key] = value
def __getstate__(self):
    """Return the pickling state: a plain ``dict`` of the parent's items.

    The items are read through ``super().items()``, i.e. the parent class'
    view of the stored key/value pairs rather than this class' own
    ``items``.
    """
    return dict(super().items())
def __setstate__(self, state):
    """Restore the pickled state by assigning every entry of ``state``.

    Each pair goes through ``self[key] = value``, so item assignment on
    this object is used for every restored key.
    """
    for name in state:
        self[name] = state[name]
def __eq__(self, other):
    """Compare this object with ``other``.

    * Another ``ConfigDict``: both sides are compared via ``to_dict()``.
    * A plain ``dict``: the pairs yielded by ``self.items()`` are compared
      with ``other``.
    * Anything else: never equal.
    """
    if isinstance(other, ConfigDict):
        return other.to_dict() == self.to_dict()
    if isinstance(other, dict):
        return dict(self.items()) == other
    return False
def _to_lazy_dict(self): def _to_lazy_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep """Convert the ConfigDict to a normal dictionary recursively, and keep
the ``LazyObject`` or ``LazyAttr`` object not built.""" the ``LazyObject`` or ``LazyAttr`` object not built."""
@ -281,8 +299,8 @@ class ConfigDict(Dict):
return _to_dict(self) return _to_dict(self)
def to_dict(self): def to_dict(self):
"""Convert the ConfigDict to a normal dictionary recursively, and keep """Convert the ConfigDict to a normal dictionary recursively, and
the ``LazyObject`` or ``LazyAttr`` object not built.""" convert the ``LazyObject`` or ``LazyAttr`` to string."""
return _lazy2string(self, dict_type=dict) return _lazy2string(self, dict_type=dict)
@ -363,12 +381,14 @@ class Config:
.. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html .. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html
""" # noqa: E501 """ # noqa: E501
def __init__(self, def __init__(
self,
cfg_dict: dict = None, cfg_dict: dict = None,
cfg_text: Optional[str] = None, cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None, filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None, env_variables: Optional[dict] = None,
format_python_code: bool = True): format_python_code: bool = True,
):
filename = str(filename) if isinstance(filename, Path) else filename filename = str(filename) if isinstance(filename, Path) else filename
if cfg_dict is None: if cfg_dict is None:
cfg_dict = dict() cfg_dict = dict()
@ -384,6 +404,9 @@ class Config:
super().__setattr__('_cfg_dict', cfg_dict) super().__setattr__('_cfg_dict', cfg_dict)
super().__setattr__('_filename', filename) super().__setattr__('_filename', filename)
super().__setattr__('_format_python_code', format_python_code) super().__setattr__('_format_python_code', format_python_code)
if not hasattr(self, '_imported_names'):
super().__setattr__('_imported_names', set())
if cfg_text: if cfg_text:
text = cfg_text text = cfg_text
elif filename: elif filename:
@ -445,7 +468,8 @@ class Config:
cfg_dict, cfg_dict,
cfg_text=cfg_text, cfg_text=cfg_text,
filename=filename, filename=filename,
env_variables=env_variables) env_variables=env_variables,
)
else: else:
# Enable lazy import when parsing the config. # Enable lazy import when parsing the config.
# Using try-except to make sure ``ConfigDict.lazy`` will be reset # Using try-except to make sure ``ConfigDict.lazy`` will be reset
@ -457,15 +481,10 @@ class Config:
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
# disable lazy import to get the real type. See more details
# about lazy in the docstring of ConfigDict
ConfigDict.lazy = False ConfigDict.lazy = False
# delete builtin imported objects
for key, value in list(cfg_dict._to_lazy_dict().items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
cfg_dict.pop(key)
# disable lazy import to get the real type. See more details about
# lazy in the docstring of ConfigDict
cfg = Config( cfg = Config(
cfg_dict, cfg_dict,
filename=filename, filename=filename,
@ -996,7 +1015,7 @@ class Config:
# accessed, but will not be dumped by default. # accessed, but will not be dumped by default.
with open(filename, encoding='utf-8') as f: with open(filename, encoding='utf-8') as f:
global_dict = {'LazyObject': LazyObject} global_dict = {'LazyObject': LazyObject, '__file__': filename}
base_dict = {} base_dict = {}
parsed_codes = ast.parse(f.read()) parsed_codes = ast.parse(f.read())
@ -1470,9 +1489,13 @@ class Config:
def __iter__(self): def __iter__(self):
return iter(self._cfg_dict) return iter(self._cfg_dict)
def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]: def __getstate__(
return (self._cfg_dict, self._filename, self._text, self
self._env_variables) ) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]:
state = (self._cfg_dict, self._filename, self._text,
self._env_variables, self._format_python_code,
self._imported_names)
return state
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
cls = self.__class__ cls = self.__class__
@ -1495,12 +1518,13 @@ class Config:
copy = __copy__ copy = __copy__
def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
dict]): dict, bool, set]):
_cfg_dict, _filename, _text, _env_variables = state super().__setattr__('_cfg_dict', state[0])
super().__setattr__('_cfg_dict', _cfg_dict) super().__setattr__('_filename', state[1])
super().__setattr__('_filename', _filename) super().__setattr__('_text', state[2])
super().__setattr__('_text', _text) super().__setattr__('_env_variables', state[3])
super().__setattr__('_text', _env_variables) super().__setattr__('_format_python_code', state[4])
super().__setattr__('_imported_names', state[5])
def dump(self, file: Optional[Union[str, Path]] = None): def dump(self, file: Optional[Union[str, Path]] = None):
"""Dump config to file or return config text. """Dump config to file or return config text.
@ -1616,8 +1640,8 @@ class Config:
return False return False
def _to_lazy_dict(self, keep_imported: bool = False) -> dict: def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
"""Convert config object to dictionary and filter the imported """Convert config object to dictionary with lazy object, and filter the
object.""" imported object."""
res = self._cfg_dict._to_lazy_dict() res = self._cfg_dict._to_lazy_dict()
if hasattr(self, '_imported_names') and not keep_imported: if hasattr(self, '_imported_names') and not keep_imported:
res = { res = {
@ -1637,7 +1661,14 @@ class Config:
If you import third-party objects in the config file, all imported If you import third-party objects in the config file, all imported
objects will be converted to a string like ``torch.optim.SGD`` objects will be converted to a string like ``torch.optim.SGD``
""" """
return self._cfg_dict.to_dict() cfg_dict = self._cfg_dict.to_dict()
if hasattr(self, '_imported_names') and not keep_imported:
cfg_dict = {
key: value
for key, value in cfg_dict.items()
if key not in self._imported_names
}
return cfg_dict
class DictAction(Action): class DictAction(Action):

View File

@ -121,6 +121,16 @@ class LazyObject:
__repr__ = __str__ __repr__ = __str__
# `pickle.dump` looks up the `__getstate__` and `__setstate__` methods of
# the dumped object. If these two methods are not defined explicitly,
# that lookup on a LazyObject returns a `__getstate__` or `__setstate__`
# LazyObject instead of a real method.
def __getstate__(self):
    """Return the instance attribute dictionary as the pickling state."""
    return vars(self)
def __setstate__(self, state):
    """Restore the instance by adopting ``state`` as its ``__dict__``."""
    # Wholesale replacement, not an update: attributes that are absent
    # from ``state`` do not survive.
    self.__dict__ = state
class LazyAttr: class LazyAttr:
"""The attribute of the LazyObject. """The attribute of the LazyObject.
@ -219,3 +229,13 @@ class LazyAttr:
return self.name return self.name
__repr__ = __str__ __repr__ = __str__
# `pickle.dump` looks up the `__getstate__` and `__setstate__` methods of
# the dumped object. If these two methods are not defined explicitly,
# that lookup on a LazyAttr returns a `__getstate__` or `__setstate__`
# LazyAttr instead of a real method.
def __getstate__(self):
    """Return the instance attribute dictionary as the pickling state."""
    return vars(self)
def __setstate__(self, state):
    """Restore the instance by adopting ``state`` as its ``__dict__``."""
    # Wholesale replacement, not an update: attributes that are absent
    # from ``state`` do not survive.
    self.__dict__ = state

View File

@ -165,6 +165,8 @@ def _is_builtin_module(module_name: str) -> bool:
return False return False
if module_name.startswith('mmengine.config'): if module_name.startswith('mmengine.config'):
return True return True
if module_name in sys.builtin_module_names:
return True
spec = find_spec(module_name.split('.')[0]) spec = find_spec(module_name.split('.')[0])
# Module not found # Module not found
if spec is None: if spec is None:
@ -314,6 +316,15 @@ class ImportTransformer(ast.NodeTransformer):
# Built-in modules will not be parsed as LazyObject # Built-in modules will not be parsed as LazyObject
module = f'{node.level*"."}{node.module}' module = f'{node.level*"."}{node.module}'
if _is_builtin_module(module): if _is_builtin_module(module):
# Make sure builtin module will be added into `self.imported_obj`
for alias in node.names:
if alias.asname is not None:
self.imported_obj.add(alias.asname)
elif alias.name == '*':
raise ConfigParsingError(
'Cannot import * from non-base config')
else:
self.imported_obj.add(alias.name)
return node return node
if module in self.base_dict: if module in self.base_dict:
@ -409,6 +420,8 @@ class ImportTransformer(ast.NodeTransformer):
alias = alias_list[0] alias = alias_list[0]
if alias.asname is not None: if alias.asname is not None:
self.imported_obj.add(alias.asname) self.imported_obj.add(alias.asname)
if _is_builtin_module(alias.name.split('.')[0]):
return node
return ast.parse( # type: ignore return ast.parse( # type: ignore
f'{alias.asname} = LazyObject(' f'{alias.asname} = LazyObject('
f'"{alias.name}",' f'"{alias.name}",'

View File

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from functools import partial
from itertools import chain
from os.path import basename
from os.path import exists as ex
from os.path import splitext
import numpy as np
# ``osp`` is the ``os.path`` alias imported above.
path = osp.join('a', 'b')
# Split 'a/b.py' into its root and its extension.
name, suffix = splitext('a/b.py')
# Concatenate the two lists through ``itertools.chain``.
chained = list(chain([1, 2], [3, 4]))
# ``ex`` is ``os.path.exists`` imported under an alias; this relies on
# ``__file__`` being defined when the file is executed.
existed = ex(__file__)
# Bind ``basename`` to ``__file__`` with ``functools.partial`` and call the
# resulting zero-argument callable immediately.
cfgname = partial(basename, __file__)()

View File

@ -3,6 +3,7 @@ import argparse
import copy import copy
import os import os
import os.path as osp import os.path as osp
import pickle
import platform import platform
import sys import sys
import tempfile import tempfile
@ -951,6 +952,19 @@ class TestConfig:
assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss) assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss)
DefaultScope._instance_dict.pop('test1') DefaultScope._instance_dict.pop('test1')
def test_pickle(self):
    """A config must survive a ``pickle`` round trip with an equal
    ``__dict__``, for each of the two fixture configs below."""
    relative_paths = (
        # Text style config
        'config/py_config/test_py_base.py',
        # Config from the lazy_module_config fixtures
        'config/lazy_module_config/toy_model.py',
    )
    for relative_path in relative_paths:
        original = Config.fromfile(osp.join(self.data_path, relative_path))
        restored = pickle.loads(pickle.dumps(original))
        assert restored.__dict__ == original.__dict__
def test_lazy_import(self, tmp_path): def test_lazy_import(self, tmp_path):
lazy_import_cfg_path = osp.join( lazy_import_cfg_path = osp.join(
self.data_path, 'config/lazy_module_config/toy_model.py') self.data_path, 'config/lazy_module_config/toy_model.py')
@ -1036,6 +1050,26 @@ error_attr = mmengine.error_attr
osp.join(self.data_path, osp.join(self.data_path,
'config/lazy_module_config/error_mix_using2.py')) 'config/lazy_module_config/error_mix_using2.py'))
cfg = Config.fromfile(
osp.join(self.data_path,
'config/lazy_module_config/test_mix_builtin.py'))
assert cfg.path == osp.join('a', 'b')
assert cfg.name == 'a/b'
assert cfg.suffix == '.py'
assert cfg.chained == [1, 2, 3, 4]
assert cfg.existed
assert cfg.cfgname == 'test_mix_builtin.py'
cfg_dict = cfg.to_dict()
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
cfg.dump(dumped_cfg_path)
dumped_cfg = Config.fromfile(dumped_cfg_path)
assert set(dumped_cfg.keys()) == {
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname'
}
assert dumped_cfg.to_dict() == cfg.to_dict()
class TestConfigDict(TestCase): class TestConfigDict(TestCase):