60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import logging
|
|
import os.path as osp
|
|
|
|
from mmdeploy.apis import torch2onnx
|
|
from mmdeploy.utils import get_root_logger
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the PyTorch -> ONNX export script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour when called with no arguments.

    Returns:
        argparse.Namespace: With attributes ``deploy_cfg``, ``model_cfg``,
        ``checkpoint``, ``img``, ``output``, ``device`` (default ``'cpu'``)
        and ``log_level`` (default ``'INFO'``).
    """
    parser = argparse.ArgumentParser(description='Export model to ONNX.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument('output', help='output onnx path')
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        # NOTE: `_nameToLevel` is a private attribute of the logging module
        # mapping level names ('DEBUG', 'INFO', ...) to numeric levels; the
        # public `logging.getLevelNamesMapping()` needs Python 3.11+.
        choices=list(logging._nameToLevel.keys()))
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Convert a PyTorch checkpoint to ONNX using the command-line arguments.

    The ``output`` path is split into a directory (``work_dir``) and a file
    name (``save_file``), which ``torch2onnx`` receives separately.

    Raises:
        SystemExit: With exit code 1 if the conversion raises any exception,
            so shell scripts and CI pipelines can detect the failure.
    """
    args = parse_args()
    logger = get_root_logger(log_level=args.log_level)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    img = args.img
    output_path = args.output
    # e.g. 'work_dirs/end2end.onnx' -> ('work_dirs', 'end2end.onnx');
    # a bare file name yields an empty work_dir.
    work_dir, save_file = osp.split(output_path)
    device = args.device

    logger.info(f'torch2onnx: \n\tmodel_cfg: {model_cfg_path} '
                f'\n\tdeploy_cfg: {deploy_cfg_path}')
    try:
        torch2onnx(
            img,
            work_dir,
            save_file,
            deploy_cfg=deploy_cfg_path,
            model_cfg=model_cfg_path,
            model_checkpoint=checkpoint_path,
            device=device)
    except Exception:
        # Top-level CLI boundary. Log the full traceback; logging only str(e)
        # hid where the failure happened. Then exit non-zero rather than
        # reporting success (status 0) to the shell.
        logger.exception('torch2onnx failed.')
        raise SystemExit(1) from None
    logger.info('torch2onnx success.')
|
|
|
|
|
|
if __name__ == '__main__':
    # Entry point: run the conversion only when executed as a script,
    # not when this module is imported.
    main()
|