mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix pickle the Python style config (#1241)
This commit is contained in:
parent
b2295a258c
commit
955b5712c4
@ -261,6 +261,24 @@ class ConfigDict(Dict):
|
||||
for key, value in merged.items():
|
||||
self[key] = value
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
for key, value in super().items():
|
||||
state[key] = value
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
for key, value in state.items():
|
||||
self[key] = value
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, ConfigDict):
|
||||
return other.to_dict() == self.to_dict()
|
||||
elif isinstance(other, dict):
|
||||
return {k: v for k, v in self.items()} == other
|
||||
else:
|
||||
return False
|
||||
|
||||
def _to_lazy_dict(self):
|
||||
"""Convert the ConfigDict to a normal dictionary recursively, and keep
|
||||
the ``LazyObject`` or ``LazyAttr`` object not built."""
|
||||
@ -281,8 +299,8 @@ class ConfigDict(Dict):
|
||||
return _to_dict(self)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert the ConfigDict to a normal dictionary recursively, and keep
|
||||
the ``LazyObject`` or ``LazyAttr`` object not built."""
|
||||
"""Convert the ConfigDict to a normal dictionary recursively, and
|
||||
convert the ``LazyObject`` or ``LazyAttr`` to string."""
|
||||
return _lazy2string(self, dict_type=dict)
|
||||
|
||||
|
||||
@ -363,12 +381,14 @@ class Config:
|
||||
.. _config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
cfg_dict: dict = None,
|
||||
cfg_text: Optional[str] = None,
|
||||
filename: Optional[Union[str, Path]] = None,
|
||||
env_variables: Optional[dict] = None,
|
||||
format_python_code: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_dict: dict = None,
|
||||
cfg_text: Optional[str] = None,
|
||||
filename: Optional[Union[str, Path]] = None,
|
||||
env_variables: Optional[dict] = None,
|
||||
format_python_code: bool = True,
|
||||
):
|
||||
filename = str(filename) if isinstance(filename, Path) else filename
|
||||
if cfg_dict is None:
|
||||
cfg_dict = dict()
|
||||
@ -384,6 +404,9 @@ class Config:
|
||||
super().__setattr__('_cfg_dict', cfg_dict)
|
||||
super().__setattr__('_filename', filename)
|
||||
super().__setattr__('_format_python_code', format_python_code)
|
||||
if not hasattr(self, '_imported_names'):
|
||||
super().__setattr__('_imported_names', set())
|
||||
|
||||
if cfg_text:
|
||||
text = cfg_text
|
||||
elif filename:
|
||||
@ -445,7 +468,8 @@ class Config:
|
||||
cfg_dict,
|
||||
cfg_text=cfg_text,
|
||||
filename=filename,
|
||||
env_variables=env_variables)
|
||||
env_variables=env_variables,
|
||||
)
|
||||
else:
|
||||
# Enable lazy import when parsing the config.
|
||||
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
|
||||
@ -457,15 +481,10 @@ class Config:
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
# disable lazy import to get the real type. See more details
|
||||
# about lazy in the docstring of ConfigDict
|
||||
ConfigDict.lazy = False
|
||||
|
||||
# delete builtin imported objects
|
||||
for key, value in list(cfg_dict._to_lazy_dict().items()):
|
||||
if isinstance(value, (types.FunctionType, types.ModuleType)):
|
||||
cfg_dict.pop(key)
|
||||
|
||||
# disable lazy import to get the real type. See more details about
|
||||
# lazy in the docstring of ConfigDict
|
||||
cfg = Config(
|
||||
cfg_dict,
|
||||
filename=filename,
|
||||
@ -996,7 +1015,7 @@ class Config:
|
||||
# accessed, but will not be dumped by default.
|
||||
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
global_dict = {'LazyObject': LazyObject}
|
||||
global_dict = {'LazyObject': LazyObject, '__file__': filename}
|
||||
base_dict = {}
|
||||
|
||||
parsed_codes = ast.parse(f.read())
|
||||
@ -1470,9 +1489,13 @@ class Config:
|
||||
def __iter__(self):
|
||||
return iter(self._cfg_dict)
|
||||
|
||||
def __getstate__(self) -> Tuple[dict, Optional[str], Optional[str], dict]:
|
||||
return (self._cfg_dict, self._filename, self._text,
|
||||
self._env_variables)
|
||||
def __getstate__(
|
||||
self
|
||||
) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]:
|
||||
state = (self._cfg_dict, self._filename, self._text,
|
||||
self._env_variables, self._format_python_code,
|
||||
self._imported_names)
|
||||
return state
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
cls = self.__class__
|
||||
@ -1495,12 +1518,13 @@ class Config:
|
||||
copy = __copy__
|
||||
|
||||
def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str],
|
||||
dict]):
|
||||
_cfg_dict, _filename, _text, _env_variables = state
|
||||
super().__setattr__('_cfg_dict', _cfg_dict)
|
||||
super().__setattr__('_filename', _filename)
|
||||
super().__setattr__('_text', _text)
|
||||
super().__setattr__('_text', _env_variables)
|
||||
dict, bool, set]):
|
||||
super().__setattr__('_cfg_dict', state[0])
|
||||
super().__setattr__('_filename', state[1])
|
||||
super().__setattr__('_text', state[2])
|
||||
super().__setattr__('_env_variables', state[3])
|
||||
super().__setattr__('_format_python_code', state[4])
|
||||
super().__setattr__('_imported_names', state[5])
|
||||
|
||||
def dump(self, file: Optional[Union[str, Path]] = None):
|
||||
"""Dump config to file or return config text.
|
||||
@ -1616,8 +1640,8 @@ class Config:
|
||||
return False
|
||||
|
||||
def _to_lazy_dict(self, keep_imported: bool = False) -> dict:
|
||||
"""Convert config object to dictionary and filter the imported
|
||||
object."""
|
||||
"""Convert config object to dictionary with lazy object, and filter the
|
||||
imported object."""
|
||||
res = self._cfg_dict._to_lazy_dict()
|
||||
if hasattr(self, '_imported_names') and not keep_imported:
|
||||
res = {
|
||||
@ -1637,7 +1661,14 @@ class Config:
|
||||
If you import third-party objects in the config file, all imported
|
||||
objects will be converted to a string like ``torch.optim.SGD``
|
||||
"""
|
||||
return self._cfg_dict.to_dict()
|
||||
cfg_dict = self._cfg_dict.to_dict()
|
||||
if hasattr(self, '_imported_names') and not keep_imported:
|
||||
cfg_dict = {
|
||||
key: value
|
||||
for key, value in cfg_dict.items()
|
||||
if key not in self._imported_names
|
||||
}
|
||||
return cfg_dict
|
||||
|
||||
|
||||
class DictAction(Action):
|
||||
|
@ -121,6 +121,16 @@ class LazyObject:
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
|
||||
# methods of the dumped object. If these two methods are not defined,
|
||||
# LazyObject will return a `__getstate__` LazyObject` or `__setstate__`
|
||||
# LazyObject.
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
|
||||
|
||||
class LazyAttr:
|
||||
"""The attribute of the LazyObject.
|
||||
@ -219,3 +229,13 @@ class LazyAttr:
|
||||
return self.name
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
# `pickle.dump` will try to get the `__getstate__` and `__setstate__`
|
||||
# methods of the dumped object. If these two methods are not defined,
|
||||
# LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__`
|
||||
# LazyAttr.
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
|
@ -165,6 +165,8 @@ def _is_builtin_module(module_name: str) -> bool:
|
||||
return False
|
||||
if module_name.startswith('mmengine.config'):
|
||||
return True
|
||||
if module_name in sys.builtin_module_names:
|
||||
return True
|
||||
spec = find_spec(module_name.split('.')[0])
|
||||
# Module not found
|
||||
if spec is None:
|
||||
@ -314,6 +316,15 @@ class ImportTransformer(ast.NodeTransformer):
|
||||
# Built-in modules will not be parsed as LazyObject
|
||||
module = f'{node.level*"."}{node.module}'
|
||||
if _is_builtin_module(module):
|
||||
# Make sure builtin module will be added into `self.imported_obj`
|
||||
for alias in node.names:
|
||||
if alias.asname is not None:
|
||||
self.imported_obj.add(alias.asname)
|
||||
elif alias.name == '*':
|
||||
raise ConfigParsingError(
|
||||
'Cannot import * from non-base config')
|
||||
else:
|
||||
self.imported_obj.add(alias.name)
|
||||
return node
|
||||
|
||||
if module in self.base_dict:
|
||||
@ -409,6 +420,8 @@ class ImportTransformer(ast.NodeTransformer):
|
||||
alias = alias_list[0]
|
||||
if alias.asname is not None:
|
||||
self.imported_obj.add(alias.asname)
|
||||
if _is_builtin_module(alias.name.split('.')[0]):
|
||||
return node
|
||||
return ast.parse( # type: ignore
|
||||
f'{alias.asname} = LazyObject('
|
||||
f'"{alias.name}",'
|
||||
|
16
tests/data/config/lazy_module_config/test_mix_builtin.py
Normal file
16
tests/data/config/lazy_module_config/test_mix_builtin.py
Normal file
@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from os.path import basename
|
||||
from os.path import exists as ex
|
||||
from os.path import splitext
|
||||
|
||||
import numpy as np
|
||||
|
||||
path = osp.join('a', 'b')
|
||||
name, suffix = splitext('a/b.py')
|
||||
chained = list(chain([1, 2], [3, 4]))
|
||||
existed = ex(__file__)
|
||||
cfgname = partial(basename, __file__)()
|
||||
|
@ -3,6 +3,7 @@ import argparse
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import platform
|
||||
import sys
|
||||
import tempfile
|
||||
@ -951,6 +952,19 @@ class TestConfig:
|
||||
assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss)
|
||||
DefaultScope._instance_dict.pop('test1')
|
||||
|
||||
def test_pickle(self):
|
||||
# Text style config
|
||||
cfg_path = osp.join(self.data_path, 'config/py_config/test_py_base.py')
|
||||
cfg = Config.fromfile(cfg_path)
|
||||
pickled = pickle.loads(pickle.dumps(cfg))
|
||||
assert pickled.__dict__ == cfg.__dict__
|
||||
|
||||
cfg_path = osp.join(self.data_path,
|
||||
'config/lazy_module_config/toy_model.py')
|
||||
cfg = Config.fromfile(cfg_path)
|
||||
pickled = pickle.loads(pickle.dumps(cfg))
|
||||
assert pickled.__dict__ == cfg.__dict__
|
||||
|
||||
def test_lazy_import(self, tmp_path):
|
||||
lazy_import_cfg_path = osp.join(
|
||||
self.data_path, 'config/lazy_module_config/toy_model.py')
|
||||
@ -1036,6 +1050,26 @@ error_attr = mmengine.error_attr
|
||||
osp.join(self.data_path,
|
||||
'config/lazy_module_config/error_mix_using2.py'))
|
||||
|
||||
cfg = Config.fromfile(
|
||||
osp.join(self.data_path,
|
||||
'config/lazy_module_config/test_mix_builtin.py'))
|
||||
assert cfg.path == osp.join('a', 'b')
|
||||
assert cfg.name == 'a/b'
|
||||
assert cfg.suffix == '.py'
|
||||
assert cfg.chained == [1, 2, 3, 4]
|
||||
assert cfg.existed
|
||||
assert cfg.cfgname == 'test_mix_builtin.py'
|
||||
|
||||
cfg_dict = cfg.to_dict()
|
||||
dumped_cfg_path = tmp_path / 'test_dump_lazy.py'
|
||||
cfg.dump(dumped_cfg_path)
|
||||
dumped_cfg = Config.fromfile(dumped_cfg_path)
|
||||
|
||||
assert set(dumped_cfg.keys()) == {
|
||||
'path', 'name', 'suffix', 'chained', 'existed', 'cfgname'
|
||||
}
|
||||
assert dumped_cfg.to_dict() == cfg.to_dict()
|
||||
|
||||
|
||||
class TestConfigDict(TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user