# EasyCV/easycv/utils/ms_utils.py
# (viewer metadata: 126 lines, 4.4 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import jsonplus
from easycv.file import io
from easycv.utils.config_tools import Config
# Prefix prepended to the auto-generated ModelScope config file name
# (see `to_ms_config`, used when no `save_path` is given).
MODELSCOPE_PREFIX = 'modelscope'
# Key in the converted model config under which the original EasyCV model
# type is stored.
EASYCV_ARCH = '__easycv_arch__'
def to_ms_config(cfg, task, ms_model_name, save_path=None, dump=True):
    """Convert EasyCV config to ModelScope style.

    The input config is not modified: every sub-config that needs to be
    altered during the conversion is shallow-copied first, so the same
    ``Config`` object can be converted more than once.

    Args:
        cfg (str | Config): Easycv config file or Config object.
        task (str): Task name in modelscope, refer to: modelscope.utils.constant.Tasks.
        ms_model_name (str): Model name registered in modelscope, model type will be replaced with `ms_model_name`, used in modelscope.
        save_path (str): Save path for saving the generated modelscope configuration file. Only valid when dump is True.
            If `cfg` is a file path and `save_path` is None, the file is written next to `cfg` as
            ``modelscope_<cfg basename>.json``.
        dump (bool): Whether dump the converted config to `save_path`.

    Returns:
        Config: The converted ModelScope style config.

    Raises:
        ValueError: If `dump` is True, `cfg` is not a file path and
            `save_path` is not provided.
        AssertionError: If `dump` is True and `save_path` does not end
            with ``json``.
    """
    # TODO: support multi eval_pipelines
    # TODO: support for adding customized required keys to the configuration file

    if isinstance(cfg, str):
        easycv_cfg = Config.fromfile(cfg)
        if dump and save_path is None:
            # Derive the output location from the input config file.
            save_dir = os.path.dirname(cfg)
            save_name = MODELSCOPE_PREFIX + '_' + os.path.splitext(
                os.path.basename(cfg))[0] + '.json'
            save_path = os.path.join(save_dir, save_name)
    else:
        easycv_cfg = cfg
        if dump and save_path is None:
            raise ValueError('Please provide `save_path`!')

    if dump:
        # `save_path` is only meaningful when a file is written; checking it
        # unconditionally crashed on `None` when called with dump=False.
        assert save_path.endswith('json'), 'Only support json file!'

    # Work on shallow copies so the caller's config is left untouched.
    optimizer_options = dict(easycv_cfg.optimizer_config)
    optimizer_options['loss_keys'] = 'total_loss'

    # Per-GPU settings move from the val dataset config to the evaluation
    # dataloader; fall back to the global data settings when absent.
    val_dataset_cfg = dict(easycv_cfg.data.val)
    val_imgs_per_gpu = val_dataset_cfg.pop('imgs_per_gpu',
                                           easycv_cfg.data.imgs_per_gpu)
    val_workers_per_gpu = val_dataset_cfg.pop('workers_per_gpu',
                                              easycv_cfg.data.workers_per_gpu)

    log_config = easycv_cfg.log_config
    predict_config = easycv_cfg.get('predict', None)

    hooks = [{
        'type': 'CheckpointHook',
        'interval': easycv_cfg.checkpoint_config.interval
    }, {
        'type': 'EvaluationHook',
        'interval': easycv_cfg.eval_config.interval
    }, {
        'type': 'AddLrLogHook'
    }, {
        'type': 'IterTimerHook'
    }]

    custom_hooks = easycv_cfg.get('custom_hooks', [])
    hooks.extend(custom_hooks)

    for log_hook_i in log_config.hooks:
        if log_hook_i['type'] == 'TensorboardLoggerHook':
            # replace with modelscope api
            hooks.append({
                'type': 'TensorboardHook',
                'interval': log_config.interval
            })
        elif log_hook_i['type'] == 'TextLoggerHook':
            # use modelscope api
            hooks.append({
                'type': 'TextLoggerHook',
                'interval': log_config.interval
            })
        else:
            # Keep other logger hooks as they are, but attach the shared log
            # interval on a copy rather than on the caller's hook config.
            hooks.append({**log_hook_i, 'interval': log_config.interval})

    # The EasyCV model type is kept under EASYCV_ARCH while the top-level
    # `type` is replaced by the ModelScope model name.
    model_cfg = dict(easycv_cfg.model)
    ori_model_type = model_cfg.pop('type')

    ms_cfg = Config(
        dict(
            task=task,
            framework='pytorch',
            model={
                'type': ms_model_name,
                **model_cfg, EASYCV_ARCH: {
                    'type': ori_model_type
                }
            },
            dataset=dict(train=easycv_cfg.data.train, val=val_dataset_cfg),
            train=dict(
                work_dir=easycv_cfg.get('work_dir', None),
                max_epochs=easycv_cfg.total_epochs,
                dataloader=dict(
                    batch_size_per_gpu=easycv_cfg.data.imgs_per_gpu,
                    workers_per_gpu=easycv_cfg.data.workers_per_gpu,
                ),
                optimizer=dict(
                    **easycv_cfg.optimizer, options=optimizer_options),
                lr_scheduler=easycv_cfg.lr_config,
                hooks=hooks),
            evaluation=dict(
                dataloader=dict(
                    batch_size_per_gpu=val_imgs_per_gpu,
                    workers_per_gpu=val_workers_per_gpu,
                ),
                metrics={
                    'type': 'EasyCVMetric',
                    # Only the first eval pipeline is converted (see TODO).
                    'evaluators': easycv_cfg.eval_pipelines[0].evaluators
                }),
            pipeline=dict(predictor_config=predict_config),
        ))

    if dump:
        with io.open(save_path, 'w') as f:
            res = jsonplus.dumps(
                ms_cfg._cfg_dict.to_dict(), indent=4, sort_keys=False)
            f.write(res)

    return ms_cfg