64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import argparse
|
||
|
import logging
|
||
|
import os.path as osp
|
||
|
|
||
|
from mmdeploy.apis import torch2onnx
|
||
|
|
||
|
|
||
|
def parse_args(argv=None):
    """Parse command-line arguments for the PyTorch -> ONNX export tool.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour. Passing an explicit list lets tests and
            other scripts reuse this parser.

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``deploy_cfg``,
        ``model_cfg``, ``checkpoint``, ``img``, ``output``, ``device``
        (default ``'cpu'``) and ``log_level`` (default ``'INFO'``).
    """
    parser = argparse.ArgumentParser(description='Export model to ONNX.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument('output', help='output onnx path')
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        # NOTE: ``_nameToLevel`` is a private attribute of ``logging``. The
        # public ``logging.getLevelNamesMapping()`` only exists on 3.11+.
        choices=list(logging._nameToLevel.keys()))
    args = parser.parse_args(argv)

    return args
|
||
|
|
||
|
|
||
|
def main():
    """Command-line entry point: export a PyTorch checkpoint to ONNX.

    Reads the CLI arguments via ``parse_args`` and configures root logging
    at the requested level. It then calls ``torch2onnx`` with the image, the
    directory and file name split out of ``output``, and the deploy/model
    configs, checkpoint and device.

    If the conversion raises, the full traceback is logged and the process
    exits with status 1, so shell scripts and CI jobs can detect the failure.
    """
    args = parse_args()
    logging.basicConfig(
        format='%(asctime)s,%(name)s %(levelname)-8s'
        ' [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S')
    logger = logging.getLogger()
    logger.setLevel(args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    img = args.img
    output_path = args.output
    # torch2onnx is given the destination as (directory, file name).
    # NOTE(review): a bare file name yields work_dir == '' -- confirm that
    # torch2onnx treats this as the current directory.
    work_dir, save_file = osp.split(output_path)
    device = args.device

    logger.info(f'torch2onnx: \n\tmodel_cfg: {model_cfg_path} '
                f'\n\tdeploy_cfg: {deploy_cfg_path}')
    try:
        torch2onnx(
            img,
            work_dir,
            save_file,
            deploy_cfg=deploy_cfg_path,
            model_cfg=model_cfg_path,
            model_checkpoint=checkpoint_path,
            device=device)
    except Exception:
        # Top-level CLI boundary. Log the full traceback and exit non-zero
        # instead of swallowing the error, which would report success
        # (exit status 0) to the calling shell.
        logger.exception('torch2onnx failed.')
        raise SystemExit(1)
    logger.info('torch2onnx success.')
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Run the exporter only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|