mmclassification/tools/deployment/pytorch2onnx.py

219 lines
7.0 KiB
Python
Raw Normal View History

import argparse
from functools import partial
import mmcv
import numpy as np
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from mmcls.models import build_classifier
# Seed PyTorch's global random number generator when the module is loaded.
torch.manual_seed(3)
def _demo_mm_inputs(input_shape, num_classes):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
gt_labels = rng.randint(
low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
'gt_labels': torch.LongTensor(gt_labels),
}
return mm_inputs
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 dynamic_export=False,
                 show=False,
                 output_file='tmp.onnx',
                 do_simplify=False,
                 verify=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        opset_version (int): The onnx op version. Default: 11.
        dynamic_export (bool): Whether to mark the batch and the two
            spatial axes of the input (and the batch axis of the output)
            as dynamic. Default: False.
        show (bool): Whether print the computation graph. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        do_simplify (bool): Whether to run onnx-simplifier on the exported
            file and overwrite it with the simplified model.
            Default: False.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.

    Raises:
        AssertionError: If ``do_simplify`` is set and the installed
            onnx-simplifier is older than the minimum required version, or
            if the exported graph does not have exactly one feed input.
        ValueError: If ``verify`` is set and the Pytorch and ONNX outputs
            differ.
    """
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_list = [img[None, :] for img in imgs]

    # Shape used to exercise a dynamically exported graph at a size different
    # from the one it was traced with. It is derived once from the caller's
    # shape so that simplification and verification agree on it.
    doubled_shape = (input_shape[0], input_shape[1], input_shape[2] * 2,
                     input_shape[3] * 2)

    register_extra_symbolics(opset_version)

    # support dynamic shape export
    if dynamic_export:
        # NOTE(review): with an (N, C, H, W) input, axis 2 is the height and
        # axis 3 the width, so these two names look swapped. They are kept
        # unchanged because they are written into the exported graph.
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'probs': {
                0: 'batch'
            }
        }
    else:
        dynamic_axes = {}

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(model.forward, img_metas={}, return_loss=False)
    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['probs'],
                export_params=True,
                keep_initializers_as_inputs=True,
                dynamic_axes=dynamic_axes,
                verbose=show,
                opset_version=opset_version)
    finally:
        # Always undo the patch, even when the export fails, so the caller
        # never ends up with a model whose forward is permanently bound.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if do_simplify:
        import onnx
        import onnxsim
        from mmcv import digit_version

        # The requirement is on onnx-simplifier, so its own version must be
        # the one compared (not mmcv's).
        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        # Use a separate name so the caller's ``input_shape`` is not rebound
        # and later doubled a second time during verification.
        if dynamic_export:
            simplify_shape = doubled_shape
        else:
            simplify_shape = (input_shape[0], input_shape[1], input_shape[2],
                              input_shape[3])
        simplify_imgs = _demo_mm_inputs(simplify_shape,
                                        num_classes).pop('imgs')
        input_dic = {'input': simplify_imgs.detach().cpu().numpy()}
        input_shape_dic = {'input': list(simplify_shape)}

        # ``simplify`` returns the optimized model together with a flag
        # telling whether its self-check passed; nothing is written to disk
        # unless we save the result ourselves.
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_shapes=input_shape_dic,
            input_data=input_dic,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            print('Failed to simplify ONNX model.')

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # test the dynamic model
        if dynamic_export:
            dynamic_test_inputs = _demo_mm_inputs(doubled_shape, num_classes)
            imgs = dynamic_test_inputs.pop('imgs')
            img_list = [img[None, :] for img in imgs]

        # check the numerical value
        # get pytorch output
        pytorch_result = model(img_list, img_metas={}, return_loss=False)[0]

        # get onnx output
        # Initializers may also be listed as graph inputs, so the real feed
        # input is whatever remains after removing them.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between Pytorch and ONNX')
        print('The outputs are same between Pytorch and ONNX')
def parse_args(argv=None):
    """Parse the command-line options of the ONNX conversion script.

    Args:
        argv (list[str], optional): Argument strings to parse. When left as
            None, argparse reads them from ``sys.argv[1:]``. Default: None.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
            ``checkpoint``, ``show``, ``verify``, ``output_file``,
            ``opset_version``, ``simplify``, ``shape`` and
            ``dynamic_export``.
    """
    parser = argparse.ArgumentParser(description='Convert MMCls to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Whether to simplify onnx model.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    # Adjacent literals are concatenated by the parser; a backslash
    # continuation inside one literal would instead pull the next line's
    # source indentation into the help text.
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export ONNX with dynamic input shape. '
        'Defaults to False.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()

    # One value is taken as the side of a square input and two values are
    # used in the order given; any other count is rejected. The leading
    # (1, 3) fixes the first two dimensions of the dummy input.
    if len(args.shape) not in (1, 2):
        raise ValueError('invalid input shape')
    spatial_dims = tuple(args.shape)
    if len(spatial_dims) == 1:
        spatial_dims = spatial_dims * 2
    input_shape = (1, 3) + spatial_dims

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # Build the model and, if one was given, load the checkpoint on the CPU.
    classifier = build_classifier(cfg.model)
    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # Convert the model to an ONNX file.
    export_options = dict(
        opset_version=args.opset_version,
        show=args.show,
        dynamic_export=args.dynamic_export,
        output_file=args.output_file,
        do_simplify=args.simplify,
        verify=args.verify)
    pytorch2onnx(classifier, input_shape, **export_options)