commit
27bbad88a3
|
@ -14,13 +14,12 @@
|
|||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
from ppcls.utils import check
|
||||
from ppcls.utils import logger
|
||||
|
||||
__all__ = ['get_config']
|
||||
|
||||
CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, key):
|
||||
|
@ -47,13 +46,12 @@ def create_attr_dict(yaml_config):
|
|||
create_attr_dict(yaml_config[key])
|
||||
else:
|
||||
yaml_config[key] = value
|
||||
return
|
||||
|
||||
|
||||
def parse_config(cfg_file):
|
||||
"""Load a config file into AttrDict"""
|
||||
with open(cfg_file, 'r') as fopen:
|
||||
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader))
|
||||
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
|
||||
create_attr_dict(yaml_config)
|
||||
return yaml_config
|
||||
|
||||
|
@ -63,10 +61,8 @@ def print_dict(d, delimiter=0):
|
|||
Recursively visualize a dict and
|
||||
indenting acrrording by the relationship of keys.
|
||||
"""
|
||||
for k, v in d.items():
|
||||
if k in CONFIG_SECS:
|
||||
logger.info("-" * 60)
|
||||
|
||||
placeholder = "-" * 60
|
||||
for k, v in sorted(d.items()):
|
||||
if isinstance(v, dict):
|
||||
logger.info("{}{} : ".format(delimiter * " ", k))
|
||||
print_dict(v, delimiter + 4)
|
||||
|
@ -77,8 +73,8 @@ def print_dict(d, delimiter=0):
|
|||
else:
|
||||
logger.info("{}{} : {}".format(delimiter * " ", k, v))
|
||||
|
||||
if k in CONFIG_SECS:
|
||||
logger.info("-" * 60)
|
||||
if k.isupper():
|
||||
logger.info(placeholder)
|
||||
|
||||
|
||||
def print_config(config):
|
||||
|
@ -88,18 +84,22 @@ def print_config(config):
|
|||
Arguments:
|
||||
config: configs
|
||||
"""
|
||||
copyright = "PaddleClas is powered by PaddlePaddle !"
|
||||
info = "For more info please go to the following website."
|
||||
website = "https://github.com/PaddlePaddle/PaddleClas"
|
||||
AD_LEN = 55
|
||||
|
||||
copyright = "PaddleClas is powered by PaddlePaddle"
|
||||
ad = "https://github.com/PaddlePaddle/PaddleClas"
|
||||
|
||||
logger.info("\n" * 2)
|
||||
logger.info(copyright)
|
||||
logger.info(ad)
|
||||
|
||||
logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(copyright.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(info.center(AD_LEN)),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(website.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4), ))
|
||||
print_dict(config)
|
||||
|
||||
logger.info("-" * 60)
|
||||
|
||||
|
||||
def check_config(config):
|
||||
"""
|
||||
|
@ -157,7 +157,7 @@ def override(dl, ks, v):
|
|||
override(dl[ks[0]], ks[1:], v)
|
||||
|
||||
|
||||
def override_config(config, options=[]):
|
||||
def override_config(config, options=None):
|
||||
"""
|
||||
Recursively override the config
|
||||
|
||||
|
@ -172,32 +172,31 @@ def override_config(config, options=[]):
|
|||
Returns:
|
||||
config(dict): replaced config
|
||||
"""
|
||||
for opt in options:
|
||||
assert isinstance(opt, str), \
|
||||
("option({}) should be a str".format(opt))
|
||||
assert "=" in opt, ("option({}) should contain " \
|
||||
"a = to distinguish between key and value".format(opt))
|
||||
pair = opt.split('=')
|
||||
assert len(pair) == 2, ("there can be only a = in the option")
|
||||
key, value = pair
|
||||
keys = key.split('.')
|
||||
override(config, keys, value)
|
||||
if options is not None:
|
||||
for opt in options:
|
||||
assert isinstance(opt, str), (
|
||||
"option({}) should be a str".format(opt))
|
||||
assert "=" in opt, (
|
||||
"option({}) should contain a ="
|
||||
"to distinguish between key and value".format(opt))
|
||||
pair = opt.split('=')
|
||||
assert len(pair) == 2, ("there can be only a = in the option")
|
||||
key, value = pair
|
||||
keys = key.split('.')
|
||||
override(config, keys, value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_config(fname, overrides=[], show=True):
|
||||
def get_config(fname, overrides=None, show=True):
|
||||
"""
|
||||
Read config from file
|
||||
"""
|
||||
assert os.path.exists(fname), \
|
||||
('config file({}) is not exist'.format(fname))
|
||||
assert os.path.exists(fname), (
|
||||
'config file({}) is not exist'.format(fname))
|
||||
config = parse_config(fname)
|
||||
override_config(config, overrides)
|
||||
if show:
|
||||
print_config(config)
|
||||
if len(overrides) > 0:
|
||||
override_config(config, overrides)
|
||||
if show:
|
||||
print_config(config)
|
||||
check_config(config)
|
||||
return config
|
||||
|
|
Loading…
Reference in New Issue