mmdeploy/.github/scripts/test_onnx2ncnn.py
tpoisonooo d04c8dc9c0
refactor(onnx2ncnn): add test case and simplify code (#436)
* refactor(onnx2ncnn.cpp): split it to shape_inference, pass and utils

* refactor(onnx2ncnn.cpp): split it to shape_inference, pass and utils

* refactor(onnx2ncnn.cpp): split code

* refactor(net_module.cpp): fix build error

* ci(test_onnx2ncnn.py): add generate model and run

* ci(onnx2ncnn): add ncnn backend

* ci(test_onnx2ncnn): add converted onnx model

* ci(onnx2ncnn): fix ncnn tar

* ci(backend-ncnn): simplify dependency install

* ci(onnx2ncnn): fix apt install

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* Update backend-ncnn.yml

* fix(ci): add include algorithm

* Update build.yml

* parent aa857605319f63bc624a11956e1cd66b5389e4bf
author q.yao <streetyao@live.com> 1651287879 +0800
committer tpoisonooo <khj.application@aliyun.com> 1652169959 +0800

[Fix] Fix ci (#426)

* fix ci

* add nvidia key

* remote torch

* recover pytorch

refactor(onnx2ncnn.cpp): split it to shape_inference, pass and utils

* fix(onnx2ncnn): review

* fix(onnx2ncnn): build error

Co-authored-by: q.yao <streetyao@live.com>
2022-05-16 10:36:25 +08:00

98 lines
3.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import subprocess
# Models exercised by this script. Each entry is a tuple of
#   (mmcls config path relative to --repo-dir, checkpoint URL,
#    output ONNX filename[, URL of a pre-converted ONNX model])
# `generate_onnx` uses the first three fields; `run` only tests the entries
# that carry the optional fourth (download URL) field.
CONFIGS = [
    (
        'mmclassification/configs/vision_transformer/vit-base-p32_ft-64xb64_in1k-384.py',  # noqa: E501
        'https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth',  # noqa: E501
        'vit.onnx',
    ),
    (
        'mmclassification/configs/resnet/resnet50_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',  # noqa: E501
        'resnet50.onnx',
    ),
    (
        'mmclassification/configs/resnet/resnet18_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',  # noqa: E501
        'resnet18.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/resnet18.onnx',  # noqa: E501
    ),
    (
        'mmclassification/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py',
        'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth',  # noqa: E501
        'mobilenet-v2.onnx',
        'https://media.githubusercontent.com/media/tpoisonooo/mmdeploy-onnx2ncnn-testdata/main/mobilenet-v2.onnx',  # noqa: E501
    ),
]
def _str2bool(value):
    """Convert a command-line boolean such as ``1``/``0`` or ``true``/``false``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``'False'`` and ``'0'`` -- would become
    ``True``. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """Parse the command-line options of the onnx2ncnn test script.

    Args:
        argv (list[str], optional): arguments to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: with attributes ``run`` (bool), ``repo_dir``
        (str), ``out`` (str) and ``generate_onnx`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='MMDeploy onnx2ncnn test tool.')
    # The boolean options accept an explicit value (``--run 1``,
    # ``--run false``) for compatibility with existing invocations, and may
    # also be given bare (``--run``), which means True.
    parser.add_argument(
        '--run',
        type=_str2bool,
        nargs='?',
        const=True,
        default=False,
        help='Execute onnx2ncnn bin.')
    parser.add_argument(
        '--repo-dir', type=str, default='~/', help='mmcls directory.')
    parser.add_argument(
        '--out',
        type=str,
        default='onnx_output',
        help='onnx model output directory.')
    parser.add_argument(
        '--generate-onnx',
        type=_str2bool,
        nargs='?',
        const=True,
        default=False,
        help='Generate onnx model.')
    return parser.parse_args(argv)
def generate_onnx(args):
    """Export every model in ``CONFIGS`` to ONNX with ``tools/deploy.py``.

    The deploy script, the ncnn deploy config and ``cat-dog.png`` are passed
    as relative paths, so this must be run from the directory that contains
    them. After each export, ``work_dir/end2end.onnx`` is moved into
    ``args.out`` under the filename listed in ``CONFIGS``. Exit codes of the
    export and move commands are printed but not acted upon.

    Args:
        args (argparse.Namespace): parsed CLI arguments; reads ``repo_dir``
            (directory containing the ``mmclassification`` checkout) and
            ``out`` (destination directory for the ONNX files).
    """
    import mmcv

    # Subprocess calls with an argument list perform no shell '~' expansion,
    # and os.path.join does none either, so the default ``--repo-dir ~/``
    # would otherwise be passed on as a literal '~/...' path.
    repo_dir = os.path.expanduser(args.repo_dir)
    out_dir = os.path.expanduser(args.out)
    mmcv.mkdir_or_exist(out_dir)
    for conf in CONFIGS:
        config = os.path.join(repo_dir, conf[0])
        checkpoint = conf[1]
        onnx_name = conf[2]
        convert_cmd = [
            'python3', 'tools/deploy.py',
            'configs/mmcls/classification_ncnn_static.py', config, checkpoint,
            'cat-dog.png', '--work-dir', 'work_dir', '--device', 'cpu'
        ]
        print(subprocess.call(convert_cmd))
        move_cmd = [
            'mv', 'work_dir/end2end.onnx',
            os.path.join(out_dir, onnx_name)
        ]
        print(subprocess.call(move_cmd))
def run(args):
    """Download reference ONNX models and convert each with ``./onnx2ncnn``.

    Only ``CONFIGS`` entries that carry a fourth element (a download URL) are
    tested. An ``onnx2ncnn`` executable is expected in the current directory.
    The generated ``onnx.param`` / ``onnx.bin`` files are overwritten for
    every model.

    Args:
        args (argparse.Namespace): parsed CLI arguments (not read here).

    Raises:
        subprocess.CalledProcessError: if a download or a conversion exits
            with a non-zero status.
    """
    for conf in CONFIGS:
        if len(conf) < 4:
            continue
        download_url = conf[3]
        filename = conf[2]
        # Save under exactly the name onnx2ncnn reads below: a plain
        # ``wget <url>`` writes ``<name>.1`` when the file already exists,
        # which would silently re-test a stale copy. Output is not captured,
        # so wget's progress bar still appears in the log.
        subprocess.run(['wget', '-O', filename, download_url], check=True)
        convert_cmd = ['./onnx2ncnn', filename, 'onnx.param', 'onnx.bin']
        try:
            subprocess.run(convert_cmd, capture_output=True, check=True)
        except subprocess.CalledProcessError as err:
            # The exception alone only reports the exit status; surface the
            # converter's own diagnostics before failing.
            if err.stdout:
                print(err.stdout.decode(errors='replace'))
            if err.stderr:
                print(err.stderr.decode(errors='replace'))
            raise
def main():
    """Test ``onnx2ncnn.cpp``.

    When ``--generate-onnx`` is set, ONNX models are first exported from the
    mmcls configs. When ``--run`` is set, the reference ONNX models are then
    downloaded and converted with the ``onnx2ncnn`` binary. The two steps
    run in that order and may be combined in a single invocation.
    """
    args = parse_args()
    steps = (
        (args.generate_onnx, generate_onnx),
        (args.run, run),
    )
    for enabled, step in steps:
        if enabled:
            step(args)


if __name__ == '__main__':
    main()