EasyCV/easycv/utils/config_tools.py
tuofeilun 23f2b0e399
Adapt designer (#235)
1. Use original config as startup script. (For details, see refactor config parsing method #225)
2. Refactor the splicing rules of the check_base_cfg_path function in the EasyCV/easycv/utils/config_tools.py
3. Support three ways to pass class_list parameter.
4. Fix the bug that clsevalutor may make mistakes when evaluating top5.
5. Fix the bug that the distributed export cannot export the model.
6. Fix the bug that the load pretrained model key does not match.
7. support cls data source itag.
2022-12-01 17:47:10 +08:00

573 lines
22 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import os.path as osp
import platform
import sys
import tempfile
from importlib import import_module
from mmcv import Config, import_modules_from_strings
import easycv
from easycv.file import io
from easycv.framework.errors import IOError, KeyError, ValueError
from .user_config_params_utils import check_value_type
if platform.system() == 'Windows':
import regex as re
else:
import re
def traverse_replace(d, key, value):
    """Overwrite, in place, every mapping entry named ``key`` with ``value``.

    The search walks nested containers:

    * dict / ``Config``: an entry whose key equals ``key`` is overwritten and
      not descended into; every other entry's value is searched recursively.
    * list / tuple / set: each element is searched recursively.
    * anything else is ignored.

    Args:
        d: The (possibly nested) container to search and modify.
        key: The mapping key whose value should be replaced.
        value: The replacement value.
    """
    if isinstance(d, (list, tuple, set)):
        for item in d:
            traverse_replace(item, key, value)
        return
    if not isinstance(d, (dict, Config)):
        return
    for name, child in d.items():
        if name != key:
            traverse_replace(child, key, value)
            continue
        # Assigning to an existing key does not resize the mapping, so it is
        # safe to do while iterating over it.
        d[name] = value
class WrapperConfig(Config):
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    This subclass overrides ``_substitute_predefined_vars`` so that top-level
    ("first-order") assignments of a config file can be replaced while the
    file is copied to its temporary location.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _substitute_predefined_vars(filename,
                                    temp_config_name,
                                    first_order_params=None):
        """Copy ``filename`` to ``temp_config_name`` with substitutions applied.

        Override of ``Config._substitute_predefined_vars``. Two kinds of
        substitution are performed:

        1. A line that starts and ends outside of any bracket pair and whose
           text before the first ``=`` is a key of ``first_order_params`` is
           rewritten as ``key = repr(value)``. This allows first-order
           parameter reuse without rebuilding custom config.py templates.
        2. The placeholders ``{{ fileDirname }}``, ``{{ fileBasename }}``,
           ``{{ fileBasenameNoExtension }}`` and ``{{ fileExtname }}`` are
           replaced by the corresponding parts of ``filename``.

        Args:
            filename (str): Original script file.
            temp_config_name (str): Path the substituted text is written to.
            first_order_params (dict): first-order parameters, mapping a
                top-level name to its new value.

        Returns:
            No return value.
        """
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname)

        left_match, right_match = '{([', '])}'
        # Number of brackets currently open. Only the nesting depth matters,
        # so a counter is enough.
        depth = 0
        line_list = []
        # Setting encoding explicitly to resolve coding issue on windows
        with open(filename, encoding='utf-8') as f:
            for line in f:
                depth_before = depth
                for char in line:
                    if char in left_match:
                        depth += 1
                    elif char in right_match and depth > 0:
                        # An unmatched closing bracket (for instance inside a
                        # comment or a string) is ignored instead of raising
                        # IndexError from popping an empty stack.
                        depth -= 1
                key = line.split('=')[0].strip()
                # Check whether it is a first-order parameter: the line must
                # begin and end at nesting depth zero.
                if (depth_before == depth == 0 and first_order_params
                        and key in first_order_params):
                    value = first_order_params[key]
                    # repr() converts the value into a Python expression that
                    # the interpreter can read back.
                    line = ' '.join([key, '=', repr(value)]) + '\n'
                line_list.append(line)

        config_file = ''.join(line_list)
        for key, value in support_templates.items():
            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
            # Forward slashes keep the replacement free of backslashes, which
            # re.sub would otherwise interpret as escapes.
            value = value.replace('\\', '/')
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
            tmp_config_file.write(config_file)
def check_base_cfg_path(base_cfg_name='configs/base.py',
                        father_cfg_name=None,
                        easycv_root=None):
    """Resolve the path of a base config referenced from another config.

    Rules (paths are split on ``/``):

    1. If the first component of ``base_cfg_name`` is ``configs`` or
       ``benchmarks``, the path is joined onto ``easycv_root`` (when given).
    2. Otherwise the path is taken relative to the directory of
       ``father_cfg_name`` (when given): leading ``.`` components are dropped
       and each leading ``..`` component also removes one trailing directory
       of the father path.

    When the relevant optional argument is ``None`` the name is returned
    unchanged.

    Args:
        base_cfg_name (str): Path of the base config as written in ``_base_``.
        father_cfg_name (str): Path of the config that references the base.
        easycv_root (str): Root directory used for rule 1.

    Returns:
        str: The resolved base config path.
    """
    head = base_cfg_name.split('/')[0]
    if head in ('configs', 'benchmarks'):
        if easycv_root is None:
            return base_cfg_name
        return osp.join(easycv_root, base_cfg_name)

    if father_cfg_name is None:
        return base_cfg_name

    relative_parts = base_cfg_name.split('/')
    # Directory components of the referencing file (file name dropped).
    parent_parts = father_cfg_name.split('/')[:-1]
    consumed = 0
    for part in relative_parts:
        if part == '.':
            consumed += 1
        elif part == '..':
            consumed += 1
            parent_parts.pop()
        else:
            # Only the leading run of '.' / '..' is interpreted.
            break
    return '/'.join(parent_parts + relative_parts[consumed:])
# Read a single config file without resolving its _base_ entries
def mmcv_file2dict_raw(filename, first_order_params=None):
    """Load one config file into a dict without resolving ``_base_``.

    The file is first copied into a temporary directory with predefined
    variables substituted (and, when ``first_order_params`` is given, with the
    matching top-level assignments replaced). ``.py`` configs are then
    imported as a module and their non-dunder globals collected; yaml/json
    configs are parsed with ``mmcv.load``.

    Args:
        filename (str): Path of the config file (.py/.json/.yaml/.yml).
        first_order_params (dict, optional): Top-level names mapped to the
            values that should replace them. ``None`` uses the plain
            ``Config._substitute_predefined_vars``.

    Returns:
        tuple: ``(cfg_dict, cfg_text)`` where ``cfg_text`` is ``filename``, a
        newline, and the unmodified content of the file.

    Raises:
        IOError: If the file extension is not supported (``IOError`` is the
            name imported at the top of this file).
    """
    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    with tempfile.TemporaryDirectory() as temp_config_dir:
        temp_config_file = tempfile.NamedTemporaryFile(
            dir=temp_config_dir, suffix=fileExtname)
        try:
            if platform.system() == 'Windows':
                # Closed early so the substituted config can be written to
                # the same name on Windows.
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            if first_order_params is not None:
                WrapperConfig._substitute_predefined_vars(
                    filename, temp_config_file.name, first_order_params)
            else:
                Config._substitute_predefined_vars(filename,
                                                   temp_config_file.name)

            if filename.endswith('.py'):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                try:
                    Config._validate_py_syntax(filename)
                    mod = import_module(temp_module_name)
                finally:
                    # Always undo the import-time state, even when the config
                    # fails to import. Remove our own entry rather than
                    # whatever happens to be first, in case the config itself
                    # modified sys.path.
                    if temp_config_dir in sys.path:
                        sys.path.remove(temp_config_dir)
                    # delete imported module
                    sys.modules.pop(temp_module_name, None)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
            elif filename.endswith(('.yml', '.yaml', '.json')):
                import mmcv
                cfg_dict = mmcv.load(temp_config_file.name)
        finally:
            # close temp file (a no-op if it was already closed above); must
            # happen before the temporary directory is removed
            temp_config_file.close()

    cfg_text = filename + '\n'
    # Setting encoding explicitly to resolve coding issue on windows
    with open(filename, 'r', encoding='utf-8') as f:
        cfg_text += f.read()

    return cfg_dict, cfg_text
# Read config and merge in the files listed under _base_
def mmcv_file2dict_base(ori_filename,
                        first_order_params=None,
                        easycv_root=None):
    """Load a config file and recursively merge everything under ``_base_``.

    Args:
        ori_filename (str): Path of the config file to load.
        first_order_params (dict, optional): Forwarded to
            ``mmcv_file2dict_raw`` for this file and for every base file.
        easycv_root (str, optional): Root directory used by
            ``check_base_cfg_path`` to resolve base paths that start with
            ``configs`` or ``benchmarks``.

    Returns:
        tuple: ``(cfg_dict, cfg_text)``. ``cfg_dict`` is this file's dict
        merged into the combined dict of its bases via
        ``Config._merge_a_into_b``; ``cfg_text`` is the texts of all bases
        followed by this file's text, joined with newlines.

    Raises:
        KeyError: If two base files define the same top-level key
            (``KeyError`` is the name imported at the top of this file).
    """
    cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename, first_order_params)

    BASE_KEY = '_base_'
    if BASE_KEY in cfg_dict:
        base_filename = cfg_dict.pop(BASE_KEY)
        if not isinstance(base_filename, list):
            base_filename = [base_filename]

        cfg_dict_list = []
        cfg_text_list = []
        for f in base_filename:
            base_cfg_path = check_base_cfg_path(
                f, ori_filename, easycv_root=easycv_root)
            # easycv_root must be propagated, otherwise 'configs/...' bases
            # referenced from a base file are no longer resolved against the
            # package root.
            _cfg_dict, _cfg_text = mmcv_file2dict_base(
                base_cfg_path, first_order_params, easycv_root=easycv_root)
            cfg_dict_list.append(_cfg_dict)
            cfg_text_list.append(_cfg_text)

        base_cfg_dict = {}
        for c in cfg_dict_list:
            if len(base_cfg_dict.keys() & c.keys()) > 0:
                raise KeyError('Duplicate key is not allowed among bases')
            base_cfg_dict.update(c)

        # Values of the current file take precedence over its bases.
        base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
        cfg_dict = base_cfg_dict

        # merge cfg_text
        cfg_text_list.append(cfg_text)
        cfg_text = '\n'.join(cfg_text_list)

    return cfg_dict, cfg_text
def grouping_params(user_config_params):
    """Split user parameters into first-order and multi-order groups.

    A key without any ``.`` is first-order (a top-level config name); a key
    containing ``.`` (a dotted path into the config) is multi-order.

    Args:
        user_config_params (dict): Mapping of parameter key to value.

    Returns:
        tuple: ``(first_order_params, multi_order_params)``, two dicts.
    """
    first_order_params = {}
    multi_order_params = {}
    for name, val in user_config_params.items():
        is_top_level = name.find('.') == -1
        bucket = first_order_params if is_top_level else multi_order_params
        bucket[name] = val
    return first_order_params, multi_order_params
def adapt_pai_params(cfg_dict, class_list_params=None):
    """Add class-list, export and oss settings to ``cfg_dict`` in place.

    Args:
        cfg_dict (dict): All parameters of cfg. It is modified and returned.
        class_list_params (list, optional): ``[class_list, num_classes]``.
            ``class_list`` supports three forms:

            * a value containing ``.txt``: treated as the path of a text file
              whose lines hold comma-separated class names (one or several
              per line, trailing commas allowed), for example ``0, 1, 2`` or
              ``person, dog, cat``;
            * any other non-empty sequence, for example ``[0, 1, 2]`` or
              ``['person', 'dog', 'cat']``: each element is converted with
              ``str``;
            * an empty value: the names default to ``'0'`` through
              ``str(num_classes - 1)``.

            When ``None``, no ``class_list`` entry is written.

    Returns:
        cfg_dict (dict): The same dict with the export and oss entries added.
    """
    if class_list_params is not None:
        class_list, num_classes = class_list_params[0], class_list_params[1]
        if '.txt' in class_list:
            names = []
            # Setting encoding explicitly to resolve coding issue on windows
            with open(class_list, 'r', encoding='utf-8') as f:
                for line in f:
                    fields = line.strip().strip(',').replace(' ',
                                                             '').split(',')
                    # Blank lines and doubled commas yield empty fields;
                    # they are not class names, so skip them.
                    names.extend(field for field in fields if field)
            cfg_dict['class_list'] = names
        elif len(class_list) > 0:
            cfg_dict['class_list'] = list(map(str, class_list))
        else:
            cfg_dict['class_list'] = list(map(str, range(0, num_classes)))

    # export config
    cfg_dict['export'] = dict(export_neck=True)
    cfg_dict['checkpoint_sync_export'] = True

    # oss config (the credential values are placeholders)
    cfg_dict['oss_sync_config'] = dict(
        other_file_list=['**/events.out.tfevents*', '**/*log*'])
    cfg_dict['oss_io_config'] = dict(
        ak_id='your oss ak id',
        ak_secret='your oss ak secret',
        hosts='oss-cn-zhangjiakou.aliyuncs.com',
        buckets=['your_bucket_2'])

    return cfg_dict
def init_path(ori_filename):
    """Resolve a config path against the installed easycv package.

    If the first ``/``-separated component of ``ori_filename`` is ``configs``
    or ``benchmarks`` and the file exists under the easycv package directory,
    the path inside the package is used; otherwise the name is returned
    unchanged.

    Args:
        ori_filename (str): Config path as supplied by the caller.

    Returns:
        tuple: ``(ori_filename, easycv_root)`` - the (possibly rewritten)
        config path and the directory containing the easycv package.
    """
    # easycv package root path
    easycv_root = osp.dirname(easycv.__file__)
    top_dir = ori_filename.split('/')[0]
    if top_dir in ('configs', 'benchmarks'):
        packaged_path = osp.join(easycv_root, ori_filename)
        if osp.exists(packaged_path):
            ori_filename = packaged_path
    return ori_filename, easycv_root
# gen mmcv.Config
def mmcv_config_fromfile(ori_filename):
    """Build an ``mmcv.Config`` from a config file, resolving ``_base_``.

    The path is first resolved with ``init_path``; modules listed under the
    ``custom_imports`` key (if any) are imported before the ``Config`` object
    is created.

    Args:
        ori_filename (str): Path of the config file.

    Returns:
        Config: The loaded configuration.
    """
    resolved_path, package_root = init_path(ori_filename)
    cfg_dict, cfg_text = mmcv_file2dict_base(
        resolved_path, easycv_root=package_root)

    custom_imports = cfg_dict.get('custom_imports', None)
    if custom_imports:
        import_modules_from_strings(**custom_imports)

    return Config(cfg_dict, cfg_text=cfg_text, filename=resolved_path)
def pai_config_fromfile(ori_filename,
                        user_config_params=None,
                        model_type=None):
    """Build a ``Config`` from a file plus user-supplied parameter overrides.

    Args:
        ori_filename (str): Path of the config file; resolved with
            ``init_path``.
        user_config_params (dict, optional): Parameter overrides. A
            ``class_list`` entry, if present, is popped from this dict (the
            caller's dict is modified). The remaining entries are split by
            ``grouping_params``: keys without a dot are substituted while the
            file is read, dotted keys are merged into the final ``Config``.
        model_type: When truthy, ``adapt_pai_params`` is applied to add the
            export and oss related configuration.

    Returns:
        Config: The resulting configuration.
    """
    ori_filename, easycv_root = init_path(ori_filename)

    class_list_params = None
    first_order_params = None
    multi_order_params = None
    if user_config_params is not None:
        # set class_list
        if 'class_list' in user_config_params:
            class_list = user_config_params.pop('class_list')
            # Pair the class list with the first parameter whose key mentions
            # num_classes; without such a key the class list is not used.
            for name in user_config_params:
                if 'num_classes' in name:
                    class_list_params = [
                        class_list, user_config_params[name]
                    ]
                    break

        # grouping params
        first_order_params, multi_order_params = grouping_params(
            user_config_params)

    # replace first-order parameters
    cfg_dict, cfg_text = mmcv_file2dict_base(
        ori_filename, first_order_params, easycv_root=easycv_root)

    # Add export and oss related configuration to adapt to pai platform
    if model_type:
        cfg_dict = adapt_pai_params(cfg_dict, class_list_params)

    custom_imports = cfg_dict.get('custom_imports', None)
    if custom_imports:
        import_modules_from_strings(**custom_imports)

    cfg = Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename)

    # replace multi-order parameters
    if multi_order_params:
        cfg.merge_from_dict(multi_order_params)
    return cfg
# get the true value for ori_key in cfg_dict
def get_config_class_value(cfg_dict, ori_key, dict_mem_helper):
    """Look up the value addressed by a dotted key, with memoisation.

    Args:
        cfg_dict: Nested config structure to read from.
        ori_key (str): Dotted path such as ``'a.b.0.c'``. When the current
            node is a list or tuple the path component is converted with
            ``int`` and used as an index; otherwise it is used as a key.
        dict_mem_helper (dict): Cache of already resolved keys; it is
            consulted first and updated with the result.

    Returns:
        The value found at ``ori_key``.
    """
    if ori_key not in dict_mem_helper:
        node = cfg_dict
        for part in ori_key.split('.'):
            index = int(part) if isinstance(node, (list, tuple)) else part
            node = node[index]
        dict_mem_helper[ori_key] = node
    return dict_mem_helper[ori_key]
def _resolve_config_string(ori_cfg_dict, value, reg, dict_mem_helper):
    """Resolve the ``${...}`` references inside one string leaf.

    Returns:
        tuple: ``(found, resolved)``. ``found`` is False when ``value``
        contains no reference (``resolved`` is then ``value`` itself).
        Otherwise ``resolved`` is:

        * the referenced value itself when the whole string is a single
          reference such as ``'${a.b}'``;
        * the string with references substituted when any referenced value
          is a string, like ``'${data_root_path}/images/'``;
        * the result of evaluating the substituted text as a Python
          expression otherwise, like ``'-${img_size}//2'``.
    """
    var_set = set(reg.findall(value))
    if not var_set:
        return False, value

    n_value = value
    is_arithmetic_exp = True
    for var in var_set:
        # get the true value for the key between '${' and '}'
        var_value = get_config_class_value(ori_cfg_dict, var[2:-1],
                                           dict_mem_helper)
        if var == value:
            # the string is exactly one reference: use the value directly
            return True, var_value
        if isinstance(var_value, str):
            # str expression, like '${data_root_path}/images/'
            is_arithmetic_exp = False
        n_value = n_value.replace(var, str(var_value))

    if is_arithmetic_exp:
        # NOTE(security): eval() executes arbitrary Python. Config files and
        # user overrides reaching this point must come from a trusted source.
        return True, eval(n_value)
    return True, n_value


def config_dict_edit(ori_cfg_dict, cfg_dict, reg, dict_mem_helper):
    """
    edit ${configs.variables} in config dict to solve dependencies in config

    ori_cfg_dict: to find the true value of ${configs.variables}
    cfg_dict: the dict or list whose leaves are searched recursively and
        updated in place
    reg: Regular expression pattern for find all ${configs.variables} in leafs of dict
    dict_mem_helper: to store the true value of ${configs.variables} which have been found
    """
    if isinstance(cfg_dict, dict):
        for key, value in cfg_dict.items():
            if isinstance(value, str):
                found, resolved = _resolve_config_string(
                    ori_cfg_dict, value, reg, dict_mem_helper)
                if found:
                    cfg_dict[key] = resolved
            else:
                # recursively find the leafs of dict
                config_dict_edit(ori_cfg_dict, value, reg, dict_mem_helper)
    elif isinstance(cfg_dict, list):
        for i, value in enumerate(cfg_dict):
            if isinstance(value, (dict, list)):
                # recursively find the leafs of dict / list
                config_dict_edit(ori_cfg_dict, value, reg, dict_mem_helper)
            elif isinstance(value, str):
                found, resolved = _resolve_config_string(
                    ori_cfg_dict, value, reg, dict_mem_helper)
                if found:
                    cfg_dict[i] = resolved
def rebuild_config(cfg, user_config_params):
    """Rebuild a config from command-line style user parameters.

    Modifies ``cfg`` according to the user config params and then replaces
    every ``${configs.variables}`` reference by its true value.

    Args:
        cfg (Config): The config to modify; its underlying dict is edited in
            place.
        user_config_params (list): Flat list ``['--key', value, ...]``. Each
            key is a dotted path; a path component that addresses a list or
            tuple is converted with ``int`` and used as an index. Each value
            is converted with ``check_value_type`` using the value currently
            stored at that location.

    Returns:
        Config: A new ``Config`` built from the edited dict, with the user
        parameters appended to the config text.

    Raises:
        IndexError: If a list index in a key cannot be used.
        KeyError: If an attribute in a key cannot be read or set
            (``KeyError`` is the name imported at the top of this file).
    """
    print(user_config_params)
    assert len(user_config_params
               ) % 2 == 0, 'user_config_params must be setted as --key value'
    new_text = cfg.text + '\n\n#user config params\n'
    for ori_args_k, args_v in zip(user_config_params[0::2],
                                  user_config_params[1::2]):
        assert ori_args_k.startswith('--')
        args_k = ori_args_k[2:]
        t = cfg
        attr_list = args_k.split('.')

        # find the location of key
        for at_k in attr_list[:-1]:
            if isinstance(t, (list, tuple)):
                try:
                    t = t[int(at_k)]
                except Exception as err:
                    raise IndexError('%s is out of index' % at_k) from err
            else:
                try:
                    t = getattr(t, at_k)
                except Exception as err:
                    raise KeyError('%s is not in dict' % at_k) from err

        # replace the value of key by user config params
        last_k = attr_list[-1]
        if isinstance(t, (list, tuple)):  # set value of list[index]
            # convert str_type to value_type
            formatted_v = check_value_type(args_v, t[int(last_k)])
            try:
                t[int(last_k)] = formatted_v
            except Exception as err:
                raise IndexError('%s is out of index' % last_k) from err
        else:  # set value of dict[map]
            # convert str_type to value_type
            formatted_v = check_value_type(args_v, getattr(t, last_k))
            try:
                setattr(t, last_k, formatted_v)
            except Exception as err:
                raise KeyError('set %s is error' % last_k) from err

        new_text += '%s %s\n' % (ori_args_k, args_v)

    # reg example: '-${a.b.ca}+${a.b.ca}+${d.c}'
    # Raw string: the pattern is unchanged, but '\$' and '\{' are no longer
    # invalid escape sequences of a normal string literal.
    reg = re.compile(r'\$\{[a-zA-Z_][a-zA-Z0-9_\.]*\}')
    dict_mem_helper = {}

    # edit ${configs.variables} in config dict to solve dependencies in config
    config_dict_edit(cfg._cfg_dict, cfg._cfg_dict, reg, dict_mem_helper)

    return Config(
        cfg_dict=cfg._cfg_dict, cfg_text=new_text, filename=cfg._filename)
# validate config for export, which may be from training config, disable pretrained
def validate_export_config(cfg):
    """Return a deep copy of ``cfg`` with pretrained loading disabled.

    On the copy, ``model.pretrained`` is set to ``None`` if it was set, and
    ``model.backbone.pretrained`` is set to ``False`` if the backbone exists
    and had a non-None ``pretrained``. The input ``cfg`` is not modified.

    Args:
        cfg: Config object exposing a ``model`` attribute.

    Returns:
        The modified deep copy of ``cfg``.
    """
    export_cfg = copy.deepcopy(cfg)
    model_cfg = export_cfg.model

    if getattr(model_cfg, 'pretrained', None) is not None:
        model_cfg.pretrained = None

    backbone_cfg = getattr(model_cfg, 'backbone', None)
    if (backbone_cfg is not None
            and getattr(backbone_cfg, 'pretrained', None) is not None):
        backbone_cfg.pretrained = False

    return export_cfg
# Maps an upper-case template name to the path of its config file. Every path
# starts with 'configs/'.
CONFIG_TEMPLATE_ZOO = {
    # ---- cls ----
    'CLASSIFICATION_RESNET': (
        'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
    ),
    'CLASSIFICATION_RESNEXT': (
        'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py'
    ),
    'CLASSIFICATION_HRNET': (
        'configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py'
    ),
    'CLASSIFICATION_VIT': (
        'configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py'
    ),
    'CLASSIFICATION_SWINT': (
        'configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py'
    ),

    # ---- metric learning ----
    'METRICLEARNING': (
        'configs/metric_learning/imagenet_timm_softmaxbased_jpg.py'
    ),
    'MODELPARALLEL_METRICLEARNING': (
        'configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py'
    ),

    # ---- detection ----
    'YOLOX': 'configs/config_templates/yolox.py',
    'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py',

    # ---- ssl ----
    'MOCO_R50_TFRECORD': 'configs/config_templates/moco_r50_tfrecord.py',
    'MOCO_R50_TFRECORD_OSS': (
        'configs/config_templates/moco_r50_tfrecord_oss.py'
    ),
    'MOCO_TIMM_TFRECORD': 'configs/config_templates/moco_timm_tfrecord.py',
    'MOCO_TIMM_TFRECORD_OSS': (
        'configs/config_templates/moco_timm_tfrecord_oss.py'
    ),
    'SWAV_R50_TFRECORD': 'configs/config_templates/swav_r50_tfrecord.py',
    'SWAV_R50_TFRECORD_OSS': (
        'configs/config_templates/swav_r50_tfrecord_oss.py'
    ),
    'MOBY_TIMM_TFRECORD_OSS': (
        'configs/config_templates/moby_timm_tfrecord_oss.py'
    ),
    'DINO_TIMM': 'configs/config_templates/dino_timm.py',
    'DINO_TIMM_TFRECORD_OSS': (
        'configs/config_templates/dino_timm_tfrecord_oss.py'
    ),
    'DINO_R50_TFRECORD_OSS': (
        'configs/config_templates/dino_rn50_tfrecord_oss.py'
    ),
    'MAE': 'configs/config_templates/mae_vit_base_patch16.py',

    # ---- edge models ----
    'YOLOX_EDGE': 'configs/config_templates/yolox_edge.py',
    'YOLOX_EDGE_ITAG': 'configs/config_templates/yolox_edge_itag.py',

    # ---- pose ----
    'TOPDOWN_HRNET': 'configs/config_templates/topdown_hrnet_w48_udp.py',
    'TOPDOWN_LITEHRNET': 'configs/config_templates/topdown_litehrnet_30.py',
}