[Fix] Implement copy and __copy__ for ConfigDict (#1230)

This commit is contained in:
Mashiro 2023-07-03 15:11:10 +08:00 committed by GitHub
parent 20d477dae1
commit d5a46d4144
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 7 deletions

View File

@ -125,6 +125,19 @@ class ConfigDict(Dict):
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
return other
def __copy__(self):
other = self.__class__()
for key, value in super().items():
other[key] = value
return other
copy = __copy__
def __iter__(self):
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
# to get the built lazy object
return iter(self.keys())
def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Get the value of the key. If class attribute ``lazy`` is True, the
LazyObject will be built and returned.

View File

@ -155,11 +155,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
if auto_wrap_policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
elif isinstance(auto_wrap_policy, dict):
ori_func = FUNCTIONS.get( # type: ignore
auto_wrap_policy.pop('type'))
if auto_wrap_policy is None:
policy = auto_wrap_policy.pop('type')
if isinstance(policy, str):
policy = FUNCTIONS.get(policy) # type: ignore
if policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
auto_wrap_policy = partial(ori_func, **auto_wrap_policy)
auto_wrap_policy = partial(policy, **auto_wrap_policy)
if not (auto_wrap_policy is None
or callable(auto_wrap_policy)): # type: ignore
@ -182,10 +183,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
if param_init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
elif isinstance(param_init_fn, dict):
param_init_fn = FUNCTIONS.get(param_init_fn.pop('type'))
if param_init_fn is None:
init_fn = param_init_fn.pop('type')
if isinstance(param_init_fn, str):
init_fn = FUNCTIONS.get(init_fn) # type: ignore
if init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
param_init_fn = partial(param_init_fn, **param_init_fn)
param_init_fn = partial(init_fn, **param_init_fn)
if not (callable(param_init_fn) or param_init_fn is None):
raise TypeError('`param_init_fn` should be a str, a '

View File

@ -3,6 +3,7 @@ from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import EMAHook
from mmengine.model import MomentumAnnealingEMA
from mmengine.runner import FlexibleRunner
from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
with read_base():
@ -44,3 +45,5 @@ custom_hooks = [
strict_load=False,
priority=49)
]
runner_type = FlexibleRunner

View File

@ -965,6 +965,11 @@ class TestConfig:
cfg.dump(dumped_cfg_path)
dumped_cfg = Config.fromfile(dumped_cfg_path)
copied_cfg_path = tmp_path / 'test_dump_copied_lazy.py'
cfg_copy = cfg.copy()
cfg_copy.dump(copied_cfg_path)
copied_cfg = Config.fromfile(copied_cfg_path)
def _compare_dict(a, b):
if isinstance(a, dict):
assert len(a) == len(b)
@ -978,6 +983,7 @@ class TestConfig:
assert str(a) == str(b)
_compare_dict(cfg.to_dict(), dumped_cfg.to_dict())
_compare_dict(cfg.to_dict(), copied_cfg.to_dict())
# TODO reimplement this part of unit test when mmdetection adds the
# new config.