mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Implement copy
and __copy__
for ConfigDict
(#1230)
This commit is contained in:
parent
20d477dae1
commit
d5a46d4144
@ -125,6 +125,19 @@ class ConfigDict(Dict):
|
||||
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
||||
return other
|
||||
|
||||
def __copy__(self):
|
||||
other = self.__class__()
|
||||
for key, value in super().items():
|
||||
other[key] = value
|
||||
return other
|
||||
|
||||
copy = __copy__
|
||||
|
||||
def __iter__(self):
|
||||
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
|
||||
# to get the built lazy object
|
||||
return iter(self.keys())
|
||||
|
||||
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Get the value of the key. If class attribute ``lazy`` is True, the
|
||||
LazyObject will be built and returned.
|
||||
|
@ -155,11 +155,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
if auto_wrap_policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
elif isinstance(auto_wrap_policy, dict):
|
||||
ori_func = FUNCTIONS.get( # type: ignore
|
||||
auto_wrap_policy.pop('type'))
|
||||
if auto_wrap_policy is None:
|
||||
policy = auto_wrap_policy.pop('type')
|
||||
if isinstance(policy, str):
|
||||
policy = FUNCTIONS.get(policy) # type: ignore
|
||||
if policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
auto_wrap_policy = partial(ori_func, **auto_wrap_policy)
|
||||
auto_wrap_policy = partial(policy, **auto_wrap_policy)
|
||||
|
||||
if not (auto_wrap_policy is None
|
||||
or callable(auto_wrap_policy)): # type: ignore
|
||||
@ -182,10 +183,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
if param_init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
elif isinstance(param_init_fn, dict):
|
||||
param_init_fn = FUNCTIONS.get(param_init_fn.pop('type'))
|
||||
if param_init_fn is None:
|
||||
init_fn = param_init_fn.pop('type')
|
||||
if isinstance(param_init_fn, str):
|
||||
init_fn = FUNCTIONS.get(init_fn) # type: ignore
|
||||
if init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
param_init_fn = partial(param_init_fn, **param_init_fn)
|
||||
param_init_fn = partial(init_fn, **param_init_fn)
|
||||
|
||||
if not (callable(param_init_fn) or param_init_fn is None):
|
||||
raise TypeError('`param_init_fn` should be a str, a '
|
||||
|
@ -3,6 +3,7 @@ from mmengine.config import read_base
|
||||
from mmengine.dataset import DefaultSampler
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.model import MomentumAnnealingEMA
|
||||
from mmengine.runner import FlexibleRunner
|
||||
from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
|
||||
|
||||
with read_base():
|
||||
@ -44,3 +45,5 @@ custom_hooks = [
|
||||
strict_load=False,
|
||||
priority=49)
|
||||
]
|
||||
|
||||
runner_type = FlexibleRunner
|
||||
|
@ -965,6 +965,11 @@ class TestConfig:
|
||||
cfg.dump(dumped_cfg_path)
|
||||
dumped_cfg = Config.fromfile(dumped_cfg_path)
|
||||
|
||||
copied_cfg_path = tmp_path / 'test_dump_copied_lazy.py'
|
||||
cfg_copy = cfg.copy()
|
||||
cfg_copy.dump(copied_cfg_path)
|
||||
copied_cfg = Config.fromfile(copied_cfg_path)
|
||||
|
||||
def _compare_dict(a, b):
|
||||
if isinstance(a, dict):
|
||||
assert len(a) == len(b)
|
||||
@ -978,6 +983,7 @@ class TestConfig:
|
||||
assert str(a) == str(b)
|
||||
|
||||
_compare_dict(cfg.to_dict(), dumped_cfg.to_dict())
|
||||
_compare_dict(cfg.to_dict(), copied_cfg.to_dict())
|
||||
|
||||
# TODO reimplement this part of unit test when mmdetection adds the
|
||||
# new config.
|
||||
|
Loading…
x
Reference in New Issue
Block a user