mmdeploy/tools/torch2onnx.py

86 lines
2.7 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from mmdeploy.apis import (extract_model, get_predefined_partition_cfg,
torch2onnx)
from mmdeploy.utils import (get_ir_config, get_partition_config,
get_root_logger, load_config)
def parse_args():
    """Parse the command-line arguments of the ONNX export script.

    Arguments are read from ``sys.argv`` by ``argparse``. Four positional
    arguments are required and three options are accepted.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``deploy_cfg``, ``model_cfg``, ``checkpoint``, ``img``,
            ``work_dir`` (default ``'./work-dir'``), ``device``
            (default ``'cpu'``) and ``log_level`` (default ``'INFO'``).
    """
    parser = argparse.ArgumentParser(description='Export model to ONNX.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    # Help text previously read "convert model model" (duplicated word).
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--work-dir',
        default='./work-dir',
        help='Directory to save output files.')
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        # Accept every level name known to the logging module. Iterating the
        # mapping yields its keys, so no explicit ``.keys()`` is needed.
        choices=list(logging._nameToLevel))
    return parser.parse_args()
def main():
    """Run the export: convert the model, then extract any partitions.

    Steps, in order:

    1. Parse the command line and set up the root logger at the requested
       level.
    2. Create the work directory if it does not exist.
    3. Load the deploy config, read the output file name from its IR
       config (``'save_file'``) and call ``torch2onnx``.
    4. If the deploy config carries a partition config, resolve the list
       of partitions (either the explicit ``'partition_cfg'`` entry or the
       predefined one selected by ``'type'``) and call ``extract_model``
       once per partition, writing each result into the work directory.
    """
    args = parse_args()
    logger = get_root_logger(log_level=args.log_level)
    logger.info(f'torch2onnx: \n\tmodel_cfg: {args.model_cfg} '
                f'\n\tdeploy_cfg: {args.deploy_cfg}')

    out_dir = args.work_dir
    os.makedirs(out_dir, exist_ok=True)

    # The loaded config is only used for lookups here; torch2onnx itself
    # receives the config *paths*.
    deploy_cfg = load_config(args.deploy_cfg)[0]
    ir_file_name = get_ir_config(deploy_cfg)['save_file']

    torch2onnx(
        args.img,
        out_dir,
        ir_file_name,
        deploy_cfg=args.deploy_cfg,
        model_cfg=args.model_cfg,
        model_checkpoint=args.checkpoint,
        device=args.device)

    partition_spec = get_partition_config(deploy_cfg)
    if partition_spec is not None:
        if 'partition_cfg' in partition_spec:
            # Partitions are spelled out explicitly in the deploy config.
            partitions = partition_spec.get('partition_cfg', None)
        else:
            # Otherwise a predefined partition set is selected by its type.
            assert 'type' in partition_spec
            partitions = get_predefined_partition_cfg(
                deploy_cfg, partition_spec['type'])

        full_ir_path = osp.join(out_dir, ir_file_name)
        for part in partitions:
            part_path = osp.join(out_dir, part['save_file'])
            extract_model(
                full_ir_path,
                part['start'],
                part['end'],
                dynamic_axes=part.get('dynamic_axes', None),
                save_file=part_path)

    logger.info(f'torch2onnx finished. Results saved to {args.work_dir}')
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()