mirror of https://github.com/alibaba/EasyCV.git
update to_ms_config
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9891090 * update to_ms_configpull/191/head
parent
49dd7c2359
commit
bc64851614
|
@ -7,10 +7,21 @@ from easycv.file import io
|
|||
from easycv.utils.config_tools import Config
|
||||
|
||||
MODELSCOPE_PREFIX = 'modelscope'
|
||||
EASYCV_ARCH = '__easycv_arch__'
|
||||
|
||||
|
||||
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||
class EasyCVMeta:
|
||||
ARCH = '__easycv_arch__'
|
||||
|
||||
META = '__easycv_meta__'
|
||||
RESERVED_KEYS = 'reserved_keys'
|
||||
|
||||
|
||||
def to_ms_config(cfg,
|
||||
task,
|
||||
ms_model_name,
|
||||
save_path=None,
|
||||
reserved_keys=[],
|
||||
dump=True):
|
||||
"""Convert EasyCV config to ModelScope style.
|
||||
|
||||
Args:
|
||||
|
@ -18,6 +29,8 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
|||
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
|
||||
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
|
||||
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
|
||||
reserved_keys (list of str): Keys conversion may loss some of the original global keys, not all keys will be retained.
|
||||
If you need to keep some keys, for example, keep the `CLASSES` key of config for inference, you can specify: reserved_keys=['CLASSES'].
|
||||
dump (bool): Whether dump the converted config to `save_path`.
|
||||
"""
|
||||
# TODO: support multi eval_pipelines
|
||||
|
@ -90,7 +103,7 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
|||
framework='pytorch',
|
||||
model={
|
||||
'type': ms_model_name,
|
||||
**easycv_cfg.model, EASYCV_ARCH: {
|
||||
**easycv_cfg.model, EasyCVMeta.ARCH: {
|
||||
'type': ori_model_type
|
||||
}
|
||||
},
|
||||
|
@ -118,6 +131,15 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
|||
pipeline=dict(predictor_config=predict_config),
|
||||
))
|
||||
|
||||
for key in reserved_keys:
|
||||
ms_cfg.merge_from_dict({key: getattr(easycv_cfg, key)})
|
||||
|
||||
if len(reserved_keys) > 1:
|
||||
ms_cfg.merge_from_dict(
|
||||
{EasyCVMeta.META: {
|
||||
EasyCVMeta.RESERVED_KEYS: reserved_keys
|
||||
}})
|
||||
|
||||
if dump:
|
||||
with io.open(save_path, 'w') as f:
|
||||
res = jsonplus.dumps(
|
||||
|
|
|
@ -33,10 +33,12 @@ class MsConfigTest(unittest.TestCase):
|
|||
config_path,
|
||||
task='image-object-detection',
|
||||
ms_model_name='yolox',
|
||||
reserved_keys=['CLASSES'],
|
||||
save_path=ms_cfg_file)
|
||||
cfg = Config.fromfile(ms_cfg_file)
|
||||
self.assertIn('task', cfg)
|
||||
self.assertIn('framework', cfg)
|
||||
self.assertIn('CLASSES', cfg)
|
||||
self.assertEqual(cfg.model.type, 'yolox')
|
||||
self.assertIn('dataset', cfg)
|
||||
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
|
||||
|
|
Loading…
Reference in New Issue