[Fix] Implement copy and __copy__ for ConfigDict (#1230)

This commit is contained in:
Mashiro 2023-07-03 15:11:10 +08:00 committed by GitHub
parent 20d477dae1
commit d5a46d4144
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 7 deletions

View File

@ -125,6 +125,19 @@ class ConfigDict(Dict):
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
return other return other
def __copy__(self):
other = self.__class__()
for key, value in super().items():
other[key] = value
return other
copy = __copy__
def __iter__(self):
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
# to get the built lazy object
return iter(self.keys())
def get(self, key: str, default: Optional[Any] = None) -> Any: def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Get the value of the key. If class attribute ``lazy`` is True, the """Get the value of the key. If class attribute ``lazy`` is True, the
LazyObject will be built and returned. LazyObject will be built and returned.

View File

@ -155,11 +155,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
if auto_wrap_policy is None: if auto_wrap_policy is None:
raise ValueError('`auto_wrap_policy` is not registered!') raise ValueError('`auto_wrap_policy` is not registered!')
elif isinstance(auto_wrap_policy, dict): elif isinstance(auto_wrap_policy, dict):
ori_func = FUNCTIONS.get( # type: ignore policy = auto_wrap_policy.pop('type')
auto_wrap_policy.pop('type')) if isinstance(policy, str):
if auto_wrap_policy is None: policy = FUNCTIONS.get(policy) # type: ignore
if policy is None:
raise ValueError('`auto_wrap_policy` is not registered!') raise ValueError('`auto_wrap_policy` is not registered!')
auto_wrap_policy = partial(ori_func, **auto_wrap_policy) auto_wrap_policy = partial(policy, **auto_wrap_policy)
if not (auto_wrap_policy is None if not (auto_wrap_policy is None
or callable(auto_wrap_policy)): # type: ignore or callable(auto_wrap_policy)): # type: ignore
@ -182,10 +183,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
if param_init_fn is None: if param_init_fn is None:
raise ValueError('`param_init_fn` is not registered!') raise ValueError('`param_init_fn` is not registered!')
elif isinstance(param_init_fn, dict): elif isinstance(param_init_fn, dict):
param_init_fn = FUNCTIONS.get(param_init_fn.pop('type')) init_fn = param_init_fn.pop('type')
if param_init_fn is None: if isinstance(param_init_fn, str):
init_fn = FUNCTIONS.get(init_fn) # type: ignore
if init_fn is None:
raise ValueError('`param_init_fn` is not registered!') raise ValueError('`param_init_fn` is not registered!')
param_init_fn = partial(param_init_fn, **param_init_fn) param_init_fn = partial(init_fn, **param_init_fn)
if not (callable(param_init_fn) or param_init_fn is None): if not (callable(param_init_fn) or param_init_fn is None):
raise TypeError('`param_init_fn` should be a str, a ' raise TypeError('`param_init_fn` should be a str, a '

View File

@ -3,6 +3,7 @@ from mmengine.config import read_base
from mmengine.dataset import DefaultSampler from mmengine.dataset import DefaultSampler
from mmengine.hooks import EMAHook from mmengine.hooks import EMAHook
from mmengine.model import MomentumAnnealingEMA from mmengine.model import MomentumAnnealingEMA
from mmengine.runner import FlexibleRunner
from mmengine.testing.runner_test_case import ToyDataset, ToyMetric from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
with read_base(): with read_base():
@ -44,3 +45,5 @@ custom_hooks = [
strict_load=False, strict_load=False,
priority=49) priority=49)
] ]
runner_type = FlexibleRunner

View File

@ -965,6 +965,11 @@ class TestConfig:
cfg.dump(dumped_cfg_path) cfg.dump(dumped_cfg_path)
dumped_cfg = Config.fromfile(dumped_cfg_path) dumped_cfg = Config.fromfile(dumped_cfg_path)
copied_cfg_path = tmp_path / 'test_dump_copied_lazy.py'
cfg_copy = cfg.copy()
cfg_copy.dump(copied_cfg_path)
copied_cfg = Config.fromfile(copied_cfg_path)
def _compare_dict(a, b): def _compare_dict(a, b):
if isinstance(a, dict): if isinstance(a, dict):
assert len(a) == len(b) assert len(a) == len(b)
@ -978,6 +983,7 @@ class TestConfig:
assert str(a) == str(b) assert str(a) == str(b)
_compare_dict(cfg.to_dict(), dumped_cfg.to_dict()) _compare_dict(cfg.to_dict(), dumped_cfg.to_dict())
_compare_dict(cfg.to_dict(), copied_cfg.to_dict())
# TODO reimplement this part of unit test when mmdetection adds the # TODO reimplement this part of unit test when mmdetection adds the
# new config. # new config.