3.5 KiB
3.5 KiB
Pytorch 转 ONNX (试验性的)
如何将模型从 PyTorch 转换到 ONNX
准备工作
- 请参照 安装指南 从源码安装 MMClassification。
- 安装 onnx 和 onnxruntime。
pip install onnx onnxruntime==1.5.1
使用方法
python tools/deployment/pytorch2onnx.py \
${CONFIG_FILE} \
--checkpoint ${CHECKPOINT_FILE} \
--output-file ${OUTPUT_FILE} \
--shape ${IMAGE_SHAPE} \
--opset-version ${OPSET_VERSION} \
--dynamic-shape \
--show \
--simplify \
--verify \
所有参数的说明:
config
: 模型配置文件的路径。--checkpoint
: 模型权重文件的路径。--output-file
: ONNX 模型的输出路径。如果没有指定,默认为当前脚本执行路径下的tmp.onnx
。--shape
: 模型输入的高度和宽度。如果没有指定,默认为224 224
。--opset-version
: ONNX 的 opset 版本。如果没有指定,默认为11
。--dynamic-shape
: 是否以动态输入尺寸导出 ONNX。 如果没有指定,默认为False
。--show
: 是否打印导出模型的架构。如果没有指定,默认为False
。--simplify
: 是否精简导出的 ONNX 模型。如果没有指定,默认为False
。--verify
: 是否验证导出模型的正确性。如果没有指定,默认为False
。
示例:
python tools/deployment/pytorch2onnx.py \
configs/resnet/resnet18_8xb16_cifar10.py \
--checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
--output-file checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
--dynamic-shape \
--show \
--simplify \
--verify \
支持导出至 ONNX 的模型列表
下表列出了保证可导出至 ONNX,并在 ONNX Runtime 中运行的模型。
模型 | 配置文件 | 批推理 | 动态输入尺寸 | 备注 |
---|---|---|---|---|
MobileNetV2 | configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py |
Y | Y | |
ResNet | configs/resnet/resnet18_8xb16_cifar10.py |
Y | Y | |
ResNeXt | configs/resnext/resnext50-32x4d_8xb32_in1k.py |
Y | Y | |
SE-ResNet | configs/seresnet/seresnet50_8xb32_in1k.py |
Y | Y | |
ShuffleNetV1 | configs/shufflenet_v1/shufflenet-v1-1x_16xb64_in1k.py |
Y | Y | |
ShuffleNetV2 | configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py |
Y | Y |
注:
- 以上所有模型转换测试基于 Pytorch==1.6.0 进行
提示
- 如果你在上述模型的转换中遇到问题,请在 GitHub 中创建一个 issue,我们会尽快处理。未在上表中列出的模型,由于资源限制,我们可能无法提供很多帮助,如果遇到问题,请尝试自行解决。
常见问题
- 无