mirror of https://github.com/alibaba/EasyCV.git
update to_ms_config
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/9891090 * update to_ms_configpull/191/head
parent
49dd7c2359
commit
bc64851614
|
@ -7,10 +7,21 @@ from easycv.file import io
|
||||||
from easycv.utils.config_tools import Config
|
from easycv.utils.config_tools import Config
|
||||||
|
|
||||||
MODELSCOPE_PREFIX = 'modelscope'
|
MODELSCOPE_PREFIX = 'modelscope'
|
||||||
EASYCV_ARCH = '__easycv_arch__'
|
|
||||||
|
|
||||||
|
|
||||||
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
class EasyCVMeta:
|
||||||
|
ARCH = '__easycv_arch__'
|
||||||
|
|
||||||
|
META = '__easycv_meta__'
|
||||||
|
RESERVED_KEYS = 'reserved_keys'
|
||||||
|
|
||||||
|
|
||||||
|
def to_ms_config(cfg,
|
||||||
|
task,
|
||||||
|
ms_model_name,
|
||||||
|
save_path=None,
|
||||||
|
reserved_keys=[],
|
||||||
|
dump=True):
|
||||||
"""Convert EasyCV config to ModelScope style.
|
"""Convert EasyCV config to ModelScope style.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -18,6 +29,8 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||||
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
|
task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
|
||||||
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
|
ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
|
||||||
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
|
save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
|
||||||
|
reserved_keys (list of str): Keys conversion may loss some of the original global keys, not all keys will be retained.
|
||||||
|
If you need to keep some keys, for example, keep the `CLASSES` key of config for inference, you can specify: reserved_keys=['CLASSES'].
|
||||||
dump (bool): Whether dump the converted config to `save_path`.
|
dump (bool): Whether dump the converted config to `save_path`.
|
||||||
"""
|
"""
|
||||||
# TODO: support multi eval_pipelines
|
# TODO: support multi eval_pipelines
|
||||||
|
@ -90,7 +103,7 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||||
framework='pytorch',
|
framework='pytorch',
|
||||||
model={
|
model={
|
||||||
'type': ms_model_name,
|
'type': ms_model_name,
|
||||||
**easycv_cfg.model, EASYCV_ARCH: {
|
**easycv_cfg.model, EasyCVMeta.ARCH: {
|
||||||
'type': ori_model_type
|
'type': ori_model_type
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -118,6 +131,15 @@ def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
|
||||||
pipeline=dict(predictor_config=predict_config),
|
pipeline=dict(predictor_config=predict_config),
|
||||||
))
|
))
|
||||||
|
|
||||||
|
for key in reserved_keys:
|
||||||
|
ms_cfg.merge_from_dict({key: getattr(easycv_cfg, key)})
|
||||||
|
|
||||||
|
if len(reserved_keys) > 1:
|
||||||
|
ms_cfg.merge_from_dict(
|
||||||
|
{EasyCVMeta.META: {
|
||||||
|
EasyCVMeta.RESERVED_KEYS: reserved_keys
|
||||||
|
}})
|
||||||
|
|
||||||
if dump:
|
if dump:
|
||||||
with io.open(save_path, 'w') as f:
|
with io.open(save_path, 'w') as f:
|
||||||
res = jsonplus.dumps(
|
res = jsonplus.dumps(
|
||||||
|
|
|
@ -33,10 +33,12 @@ class MsConfigTest(unittest.TestCase):
|
||||||
config_path,
|
config_path,
|
||||||
task='image-object-detection',
|
task='image-object-detection',
|
||||||
ms_model_name='yolox',
|
ms_model_name='yolox',
|
||||||
|
reserved_keys=['CLASSES'],
|
||||||
save_path=ms_cfg_file)
|
save_path=ms_cfg_file)
|
||||||
cfg = Config.fromfile(ms_cfg_file)
|
cfg = Config.fromfile(ms_cfg_file)
|
||||||
self.assertIn('task', cfg)
|
self.assertIn('task', cfg)
|
||||||
self.assertIn('framework', cfg)
|
self.assertIn('framework', cfg)
|
||||||
|
self.assertIn('CLASSES', cfg)
|
||||||
self.assertEqual(cfg.model.type, 'yolox')
|
self.assertEqual(cfg.model.type, 'yolox')
|
||||||
self.assertIn('dataset', cfg)
|
self.assertIn('dataset', cfg)
|
||||||
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
|
self.assertIn('batch_size_per_gpu', cfg.train.dataloader)
|
||||||
|
|
Loading…
Reference in New Issue