86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import os.path as osp
|
|
|
|
from mmdeploy.apis import (extract_model, get_predefined_partition_cfg,
|
|
torch2onnx)
|
|
from mmdeploy.utils import (get_ir_config, get_partition_config,
|
|
get_root_logger, load_config)
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description='Export model to ONNX.')
|
|
parser.add_argument('deploy_cfg', help='deploy config path')
|
|
parser.add_argument('model_cfg', help='model config path')
|
|
parser.add_argument('checkpoint', help='model checkpoint path')
|
|
parser.add_argument('img', help='image used to convert model model')
|
|
parser.add_argument(
|
|
'--work-dir',
|
|
default='./work-dir',
|
|
help='Directory to save output files.')
|
|
parser.add_argument(
|
|
'--device', help='device used for conversion', default='cpu')
|
|
parser.add_argument(
|
|
'--log-level',
|
|
help='set log level',
|
|
default='INFO',
|
|
choices=list(logging._nameToLevel.keys()))
|
|
args = parser.parse_args()
|
|
|
|
return args
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
logger = get_root_logger(log_level=args.log_level)
|
|
|
|
logger.info(f'torch2onnx: \n\tmodel_cfg: {args.model_cfg} '
|
|
f'\n\tdeploy_cfg: {args.deploy_cfg}')
|
|
|
|
os.makedirs(args.work_dir, exist_ok=True)
|
|
# load deploy_cfg
|
|
deploy_cfg = load_config(args.deploy_cfg)[0]
|
|
save_file = get_ir_config(deploy_cfg)['save_file']
|
|
|
|
torch2onnx(
|
|
args.img,
|
|
args.work_dir,
|
|
save_file,
|
|
deploy_cfg=args.deploy_cfg,
|
|
model_cfg=args.model_cfg,
|
|
model_checkpoint=args.checkpoint,
|
|
device=args.device)
|
|
|
|
# partition model
|
|
partition_cfgs = get_partition_config(deploy_cfg)
|
|
|
|
if partition_cfgs is not None:
|
|
if 'partition_cfg' in partition_cfgs:
|
|
partition_cfgs = partition_cfgs.get('partition_cfg', None)
|
|
else:
|
|
assert 'type' in partition_cfgs
|
|
partition_cfgs = get_predefined_partition_cfg(
|
|
deploy_cfg, partition_cfgs['type'])
|
|
|
|
origin_ir_file = osp.join(args.work_dir, save_file)
|
|
for partition_cfg in partition_cfgs:
|
|
save_file = partition_cfg['save_file']
|
|
save_path = osp.join(args.work_dir, save_file)
|
|
start = partition_cfg['start']
|
|
end = partition_cfg['end']
|
|
dynamic_axes = partition_cfg.get('dynamic_axes', None)
|
|
|
|
extract_model(
|
|
origin_ir_file,
|
|
start,
|
|
end,
|
|
dynamic_axes=dynamic_axes,
|
|
save_file=save_path)
|
|
logger.info(f'torch2onnx finished. Results saved to {args.work_dir}')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|