mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Implement copy
and __copy__
for ConfigDict
(#1230)
This commit is contained in:
parent
20d477dae1
commit
d5a46d4144
@ -125,6 +125,19 @@ class ConfigDict(Dict):
|
|||||||
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
other = self.__class__()
|
||||||
|
for key, value in super().items():
|
||||||
|
other[key] = value
|
||||||
|
return other
|
||||||
|
|
||||||
|
copy = __copy__
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
|
||||||
|
# to get the built lazy object
|
||||||
|
return iter(self.keys())
|
||||||
|
|
||||||
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
||||||
"""Get the value of the key. If class attribute ``lazy`` is True, the
|
"""Get the value of the key. If class attribute ``lazy`` is True, the
|
||||||
LazyObject will be built and returned.
|
LazyObject will be built and returned.
|
||||||
|
@ -155,11 +155,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||||||
if auto_wrap_policy is None:
|
if auto_wrap_policy is None:
|
||||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||||
elif isinstance(auto_wrap_policy, dict):
|
elif isinstance(auto_wrap_policy, dict):
|
||||||
ori_func = FUNCTIONS.get( # type: ignore
|
policy = auto_wrap_policy.pop('type')
|
||||||
auto_wrap_policy.pop('type'))
|
if isinstance(policy, str):
|
||||||
if auto_wrap_policy is None:
|
policy = FUNCTIONS.get(policy) # type: ignore
|
||||||
|
if policy is None:
|
||||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||||
auto_wrap_policy = partial(ori_func, **auto_wrap_policy)
|
auto_wrap_policy = partial(policy, **auto_wrap_policy)
|
||||||
|
|
||||||
if not (auto_wrap_policy is None
|
if not (auto_wrap_policy is None
|
||||||
or callable(auto_wrap_policy)): # type: ignore
|
or callable(auto_wrap_policy)): # type: ignore
|
||||||
@ -182,10 +183,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||||||
if param_init_fn is None:
|
if param_init_fn is None:
|
||||||
raise ValueError('`param_init_fn` is not registered!')
|
raise ValueError('`param_init_fn` is not registered!')
|
||||||
elif isinstance(param_init_fn, dict):
|
elif isinstance(param_init_fn, dict):
|
||||||
param_init_fn = FUNCTIONS.get(param_init_fn.pop('type'))
|
init_fn = param_init_fn.pop('type')
|
||||||
if param_init_fn is None:
|
if isinstance(param_init_fn, str):
|
||||||
|
init_fn = FUNCTIONS.get(init_fn) # type: ignore
|
||||||
|
if init_fn is None:
|
||||||
raise ValueError('`param_init_fn` is not registered!')
|
raise ValueError('`param_init_fn` is not registered!')
|
||||||
param_init_fn = partial(param_init_fn, **param_init_fn)
|
param_init_fn = partial(init_fn, **param_init_fn)
|
||||||
|
|
||||||
if not (callable(param_init_fn) or param_init_fn is None):
|
if not (callable(param_init_fn) or param_init_fn is None):
|
||||||
raise TypeError('`param_init_fn` should be a str, a '
|
raise TypeError('`param_init_fn` should be a str, a '
|
||||||
|
@ -3,6 +3,7 @@ from mmengine.config import read_base
|
|||||||
from mmengine.dataset import DefaultSampler
|
from mmengine.dataset import DefaultSampler
|
||||||
from mmengine.hooks import EMAHook
|
from mmengine.hooks import EMAHook
|
||||||
from mmengine.model import MomentumAnnealingEMA
|
from mmengine.model import MomentumAnnealingEMA
|
||||||
|
from mmengine.runner import FlexibleRunner
|
||||||
from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
|
from mmengine.testing.runner_test_case import ToyDataset, ToyMetric
|
||||||
|
|
||||||
with read_base():
|
with read_base():
|
||||||
@ -44,3 +45,5 @@ custom_hooks = [
|
|||||||
strict_load=False,
|
strict_load=False,
|
||||||
priority=49)
|
priority=49)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
runner_type = FlexibleRunner
|
||||||
|
@ -965,6 +965,11 @@ class TestConfig:
|
|||||||
cfg.dump(dumped_cfg_path)
|
cfg.dump(dumped_cfg_path)
|
||||||
dumped_cfg = Config.fromfile(dumped_cfg_path)
|
dumped_cfg = Config.fromfile(dumped_cfg_path)
|
||||||
|
|
||||||
|
copied_cfg_path = tmp_path / 'test_dump_copied_lazy.py'
|
||||||
|
cfg_copy = cfg.copy()
|
||||||
|
cfg_copy.dump(copied_cfg_path)
|
||||||
|
copied_cfg = Config.fromfile(copied_cfg_path)
|
||||||
|
|
||||||
def _compare_dict(a, b):
|
def _compare_dict(a, b):
|
||||||
if isinstance(a, dict):
|
if isinstance(a, dict):
|
||||||
assert len(a) == len(b)
|
assert len(a) == len(b)
|
||||||
@ -978,6 +983,7 @@ class TestConfig:
|
|||||||
assert str(a) == str(b)
|
assert str(a) == str(b)
|
||||||
|
|
||||||
_compare_dict(cfg.to_dict(), dumped_cfg.to_dict())
|
_compare_dict(cfg.to_dict(), dumped_cfg.to_dict())
|
||||||
|
_compare_dict(cfg.to_dict(), copied_cfg.to_dict())
|
||||||
|
|
||||||
# TODO reimplement this part of unit test when mmdetection adds the
|
# TODO reimplement this part of unit test when mmdetection adds the
|
||||||
# new config.
|
# new config.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user