update to_ms_config

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9891090

    * update to_ms_config
pull/191/head
jiangnana.jnn 2022-08-25 15:31:57 +08:00
parent 49dd7c2359
commit bc64851614
2 changed files with 27 additions and 3 deletions

View File

@ -7,10 +7,21 @@ from easycv.file import io
from easycv.utils.config_tools import Config
MODELSCOPE_PREFIX = 'modelscope'
EASYCV_ARCH = '__easycv_arch__'
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
class EasyCVMeta:
ARCH = '__easycv_arch__'
META = '__easycv_meta__'
RESERVED_KEYS = 'reserved_keys'
def to_ms_config(cfg,
task,
ms_model_name,
save_path=None,
reserved_keys=[],
dump=True):
"""Convert EasyCV config to ModelScope style.
Args:
@ -18,6 +29,8 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
reserved_keys (list of str): Keys conversion may loss some of the original global keys, not all keys will be retained.
If you need to keep some keys, for example, keep the `CLASSES` key of config for inference, you can specify: reserved_keys=['CLASSES'].
dump (bool): Whether dump the converted config to `save_path`.
"""
# TODO: support multi eval_pipelines
@ -90,7 +103,7 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
framework='pytorch',
model={
'type': ms_model_name,
**easycv_cfg.model, EASYCV_ARCH: {
**easycv_cfg.model, EasyCVMeta.ARCH: {
'type': ori_model_type
}
},
@ -118,6 +131,15 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
pipeline=dict(predictor_config=predict_config),
))
for key in reserved_keys:
ms_cfg.merge_from_dict({key: getattr(easycv_cfg, key)})
if len(reserved_keys) > 1:
ms_cfg.merge_from_dict(
{EasyCVMeta.META: {
EasyCVMeta.RESERVED_KEYS: reserved_keys
}})
if dump:
with io.open(save_path, 'w') as f:
res = jsonplus.dumps(

View File

@ -33,10 +33,12 @@ class MsConfigTest(unittest.TestCase):
config_path,
task='image-object-detection',
ms_model_name='yolox',
reserved_keys=['CLASSES'],
save_path=ms_cfg_file)
cfg = Config.fromfile(ms_cfg_file)
self.assertIn('task', cfg)
self.assertIn('framework', cfg)
self.assertIn('CLASSES', cfg)
self.assertEqual(cfg.model.type, 'yolox')
self.assertIn('dataset', cfg)
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)