mirror of https://github.com/open-mmlab/mmcv.git
Allow list index keys in Config.merge_from_dict (#696)
* Allow list keys in Config merge from dict * Reformat * Set allow_list_keys default as True in merge_from_dict * Fix docstring * fix a small typopull/702/head
parent
eaf25af6c4
commit
c9f96855b0
|
@ -1,5 +1,4 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
|
@ -32,7 +31,6 @@ class CephBackend(BaseStorageBackend):
|
|||
def __init__(self, path_mapping=None):
|
||||
try:
|
||||
import ceph
|
||||
warnings.warn('Ceph is deprecate in favor of Petrel.')
|
||||
except ImportError:
|
||||
raise ImportError('Please install ceph to enable CephBackend.')
|
||||
|
||||
|
|
|
@ -190,20 +190,55 @@ class Config:
|
|||
return cfg_dict, cfg_text
|
||||
|
||||
@staticmethod
|
||||
def _merge_a_into_b(a, b):
|
||||
# merge dict `a` into dict `b` (non-inplace). values in `a` will
|
||||
# overwrite `b`.
|
||||
# copy first to avoid inplace modification
|
||||
def _merge_a_into_b(a, b, allow_list_keys=False):
|
||||
"""merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
|
||||
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
|
||||
in-place modifications.
|
||||
|
||||
Args:
|
||||
a (dict): The source dict to be merged into ``b``.
|
||||
b (dict): The origin dict to be fetch keys from ``a``.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in source ``a`` and will replace the element of the
|
||||
corresponding index in b if b is a list. Default: False.
|
||||
|
||||
Returns:
|
||||
dict: The modified dict of ``b`` using ``a``.
|
||||
|
||||
Examples:
|
||||
# Normally merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# Delete b first and merge a into b.
|
||||
>>> Config._merge_a_into_b(
|
||||
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
|
||||
{'obj': {'a': 2}}
|
||||
|
||||
# b is a list
|
||||
>>> Config._merge_a_into_b(
|
||||
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
|
||||
[{'a': 2}, {'b': 2}]
|
||||
"""
|
||||
b = b.copy()
|
||||
for k, v in a.items():
|
||||
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
|
||||
if not isinstance(b[k], dict):
|
||||
if allow_list_keys and k.isdigit() and isinstance(b, list):
|
||||
k = int(k)
|
||||
if len(b) <= k:
|
||||
raise KeyError(f'Index {k} exceeds the length of list {b}')
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
elif isinstance(v,
|
||||
dict) and k in b and not v.pop(DELETE_KEY, False):
|
||||
allowed_types = (dict, list) if allow_list_keys else dict
|
||||
if not isinstance(b[k], allowed_types):
|
||||
raise TypeError(
|
||||
f'{k}={v} in child config cannot inherit from base '
|
||||
f'because {k} is a dict in the child config but is of '
|
||||
f'type {type(b[k])} in base config. You may set '
|
||||
f'`{DELETE_KEY}=True` to ignore the base config')
|
||||
b[k] = Config._merge_a_into_b(v, b[k])
|
||||
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
||||
else:
|
||||
b[k] = v
|
||||
return b
|
||||
|
@ -403,7 +438,7 @@ class Config:
|
|||
else:
|
||||
mmcv.dump(cfg_dict, file)
|
||||
|
||||
def merge_from_dict(self, options):
|
||||
def merge_from_dict(self, options, allow_list_keys=True):
|
||||
"""Merge list into cfg_dict.
|
||||
|
||||
Merge the dict parsed by MultipleKVAction into this cfg.
|
||||
|
@ -417,8 +452,21 @@ class Config:
|
|||
>>> assert cfg_dict == dict(
|
||||
... model=dict(backbone=dict(depth=50, with_cp=True)))
|
||||
|
||||
# Merge list element
|
||||
>>> cfg = Config(dict(pipeline=[
|
||||
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
|
||||
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
|
||||
>>> cfg.merge_from_dict(options, allow_list_keys=True)
|
||||
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
>>> assert cfg_dict == dict(pipeline=[
|
||||
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
|
||||
|
||||
Args:
|
||||
options (dict): dict of configs to merge from.
|
||||
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
|
||||
are allowed in ``options`` and will replace the element of the
|
||||
corresponding index in the config if the config is a list.
|
||||
Default: True.
|
||||
"""
|
||||
option_cfg_dict = {}
|
||||
for full_key, v in options.items():
|
||||
|
@ -432,7 +480,9 @@ class Config:
|
|||
|
||||
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
|
||||
super(Config, self).__setattr__(
|
||||
'_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
|
||||
'_cfg_dict',
|
||||
Config._merge_a_into_b(
|
||||
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
|
||||
|
||||
|
||||
class DictAction(Action):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
item = [{'a': 0}, {'b': 0, 'c': 0}]
|
|
@ -70,9 +70,6 @@ class TestFileClient:
|
|||
|
||||
@patch('ceph.S3Client', MockS3Client)
|
||||
def test_ceph_backend(self):
|
||||
with pytest.warns(
|
||||
Warning, match='Ceph is deprecate in favor of Petrel.'):
|
||||
FileClient('ceph')
|
||||
ceph_backend = FileClient('ceph')
|
||||
|
||||
# input path is Path object
|
||||
|
|
|
@ -219,6 +219,24 @@ def test_merge_from_dict():
|
|||
assert cfg.item2 == dict(a=1, b=0.1)
|
||||
assert cfg.item3 is False
|
||||
|
||||
cfg_file = osp.join(data_path, 'config/s.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
|
||||
# Allow list keys
|
||||
input_options = {'item.0.a': 1, 'item.1.b': 1}
|
||||
cfg.merge_from_dict(input_options, allow_list_keys=True)
|
||||
assert cfg.item == [{'a': 1}, {'b': 1, 'c': 0}]
|
||||
|
||||
# allow_list_keys is False
|
||||
input_options = {'item.0.a': 1, 'item.1.b': 1}
|
||||
with pytest.raises(TypeError):
|
||||
cfg.merge_from_dict(input_options, allow_list_keys=False)
|
||||
|
||||
# Overflowed index number
|
||||
input_options = {'item.2.a': 1}
|
||||
with pytest.raises(KeyError):
|
||||
cfg.merge_from_dict(input_options, allow_list_keys=True)
|
||||
|
||||
|
||||
def test_merge_delete():
|
||||
cfg_file = osp.join(data_path, 'config/delete.py')
|
||||
|
|
Loading…
Reference in New Issue