From bc648516147f0a972e2b6565719f8c5c1d717e01 Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Thu, 25 Aug 2022 15:31:57 +0800 Subject: [PATCH] update to_ms_config Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9891090 * update to_ms_config --- easycv/utils/ms_utils.py | 28 +++++++++++++++++++++++++--- tests/utils/test_ms_utils.py | 2 ++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/easycv/utils/ms_utils.py b/easycv/utils/ms_utils.py index e6625673..34273986 100644 --- a/easycv/utils/ms_utils.py +++ b/easycv/utils/ms_utils.py @@ -7,10 +7,21 @@ from easycv.file import io from easycv.utils.config_tools import Config MODELSCOPE_PREFIX = 'modelscope' -EASYCV_ARCH = '__easycv_arch__' -def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True): +class EasyCVMeta: + ARCH = '__easycv_arch__' + + META = '__easycv_meta__' + RESERVED_KEYS = 'reserved_keys' + + +def to_ms_config(cfg, + task, + ms_model_name, + save_path=None, + reserved_keys=[], + dump=True): """Convert EasyCV config to ModelScope style. Args: @@ -18,6 +29,8 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True): task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks. ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope. save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True. + reserved_keys (list of str): Keys conversion may loss some of the original global keys, not all keys will be retained. + If you need to keep some keys, for example, keep the `CLASSES` key of config for inference, you can specify: reserved_keys=['CLASSES']. dump (bool): Whether dump the converted config to `save_path`. """ # TODO: support multi eval_pipelines @@ -90,7 +103,7 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True): framework='pytorch', model={ 'type': ms_model_name, - **easycv_cfg.model, EASYCV_ARCH: { + **easycv_cfg.model, EasyCVMeta.ARCH: { 'type': ori_model_type } }, @@ -118,6 +131,15 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True): pipeline=dict(predictor_config=predict_config), )) + for key in reserved_keys: + ms_cfg.merge_from_dict({key: getattr(easycv_cfg, key)}) + + if len(reserved_keys) > 1: + ms_cfg.merge_from_dict( + {EasyCVMeta.META: { + EasyCVMeta.RESERVED_KEYS: reserved_keys + }}) + if dump: with io.open(save_path, 'w') as f: res = jsonplus.dumps( diff --git a/tests/utils/test_ms_utils.py b/tests/utils/test_ms_utils.py index 88da7d27..3b0eb347 100644 --- a/tests/utils/test_ms_utils.py +++ b/tests/utils/test_ms_utils.py @@ -33,10 +33,12 @@ class MsConfigTest(unittest.TestCase): config_path, task='image-object-detection', ms_model_name='yolox', + reserved_keys=['CLASSES'], save_path=ms_cfg_file) cfg = Config.fromfile(ms_cfg_file) self.assertIn('task', cfg) self.assertIn('framework', cfg) + self.assertIn('CLASSES', cfg) self.assertEqual(cfg.model.type, 'yolox') self.assertIn('dataset', cfg) self.assertIn('batch_size_per_gpu', cfg.train.dataloader)