mirror of https://github.com/open-mmlab/mmcv.git
Allow to replace nested tuple and list via options (#731)
* Allow to replace nested tuple and list via options * Add comments * Fix single nested items * Simplify the code * Simplify the code * Simplify the code * Simplify the code * Update docstring * Update docstring * Support quotation mark * modify docstringpull/740/head
parent
f0e68404d2
commit
96ebfa652b
|
@ -488,8 +488,10 @@ class Config:
|
|||
class DictAction(Action):
|
||||
"""
|
||||
argparse action to split an argument into KEY=VALUE form
|
||||
on the first = and append to a dictionary. List options should
|
||||
be passed as comma separated values, i.e KEY=V1,V2,V3
|
||||
on the first = and append to a dictionary. List options can
|
||||
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
|
||||
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
|
||||
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
|
@ -506,12 +508,72 @@ class DictAction(Action):
|
|||
return True if val.lower() == 'true' else False
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def _parse_iterable(val):
|
||||
"""Parse iterable values in the string.
|
||||
|
||||
All elements inside '()' or '[]' are treated as iterable values.
|
||||
|
||||
Args:
|
||||
val (str): Value string.
|
||||
|
||||
Returns:
|
||||
list | tuple: The expanded list or tuple from the string.
|
||||
|
||||
Examples:
|
||||
>>> DictAction._parse_iterable('1,2,3')
|
||||
[1, 2, 3]
|
||||
>>> DictAction._parse_iterable('[a, b, c]')
|
||||
['a', 'b', 'c']
|
||||
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
|
||||
[(1, 2, 3), ['a', 'b], 'c']
|
||||
"""
|
||||
|
||||
def find_next_comma(string):
|
||||
"""Find the position of next comma in the string.
|
||||
|
||||
If no ',' is found in the string, return the string length. All
|
||||
chars inside '()' and '[]' are treated as one element and thus ','
|
||||
inside these brackets are ignored.
|
||||
"""
|
||||
assert (string.count('(') == string.count(')')) and (
|
||||
string.count('[') == string.count(']')), \
|
||||
f'Imbalanced brackets exist in {string}'
|
||||
end = len(string)
|
||||
for idx, char in enumerate(string):
|
||||
pre = string[:idx]
|
||||
# The string before this ',' is balanced
|
||||
if ((char == ',') and (pre.count('(') == pre.count(')'))
|
||||
and (pre.count('[') == pre.count(']'))):
|
||||
end = idx
|
||||
break
|
||||
return end
|
||||
|
||||
# Strip ' and " characters and replace whitespace.
|
||||
val = val.strip('\'\"').replace(' ', '')
|
||||
is_tuple = False
|
||||
if val.startswith('(') and val.endswith(')'):
|
||||
is_tuple = True
|
||||
val = val[1:-1]
|
||||
elif val.startswith('[') and val.endswith(']'):
|
||||
val = val[1:-1]
|
||||
elif ',' not in val:
|
||||
# val is a single value
|
||||
return DictAction._parse_int_float_bool(val)
|
||||
|
||||
values = []
|
||||
while len(val) > 0:
|
||||
comma_idx = find_next_comma(val)
|
||||
element = DictAction._parse_iterable(val[:comma_idx])
|
||||
values.append(element)
|
||||
val = val[comma_idx + 1:]
|
||||
if is_tuple:
|
||||
values = tuple(values)
|
||||
return values
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
options = {}
|
||||
for kv in values:
|
||||
key, val = kv.split('=', maxsplit=1)
|
||||
val = [self._parse_int_float_bool(v) for v in val.split(',')]
|
||||
if len(val) == 1:
|
||||
val = val[0]
|
||||
options[key] = val
|
||||
options[key] = self._parse_iterable(val)
|
||||
setattr(namespace, self.dest, options)
|
||||
|
|
|
@ -347,11 +347,23 @@ def test_dict_action():
|
|||
parser = argparse.ArgumentParser(description='Train a detector')
|
||||
parser.add_argument(
|
||||
'--options', nargs='+', action=DictAction, help='custom options')
|
||||
# Nested brackets
|
||||
args = parser.parse_args(
|
||||
['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]'])
|
||||
out_dict = {'item2.a': ['a', 'b'], 'item2.b': [('a', 'b'), [1, 2], False]}
|
||||
assert args.options == out_dict
|
||||
# Single Nested brackets
|
||||
args = parser.parse_args(['--options', 'item2.a=[[1]]'])
|
||||
out_dict = {'item2.a': [[1]]}
|
||||
assert args.options == out_dict
|
||||
# Imbalance bracket
|
||||
with pytest.raises(AssertionError):
|
||||
parser.parse_args(['--options', 'item2.a=[(a,b), [1,2], false'])
|
||||
# Normal values
|
||||
args = parser.parse_args(
|
||||
['--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false'])
|
||||
out_dict = {'item2.a': 1, 'item2.b': 0.1, 'item2.c': 'x', 'item3': False}
|
||||
assert args.options == out_dict
|
||||
|
||||
cfg_file = osp.join(data_path, 'config/a.py')
|
||||
cfg = Config.fromfile(cfg_file)
|
||||
cfg.merge_from_dict(args.options)
|
||||
|
|
Loading…
Reference in New Issue