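# Illustrative usage (all paths below are placeholders; this assumes the
# script is saved as, e.g., pytorch2onnx.py):
#   python pytorch2onnx.py <CONFIG_FILE> --checkpoint <CHECKPOINT_FILE> \
#       --output-file tmp.onnx --verify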
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmcls.models import build_classifier

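# fix the global PyTorch seed so that randomly initialized weights (when no
# checkpoint is given) are reproducible across runs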
torch.manual_seed(3)


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions (N, C, H, W).
        num_classes (int): Number of classes for the dummy ground-truth labels.
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    gt_labels = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1)).astype(np.uint8)
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'gt_labels': torch.LongTensor(gt_labels),
    }
    return mm_inputs


def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a PyTorch model to ONNX and optionally verify that the PyTorch
    and ONNX outputs match.

    Args:
        model (nn.Module): The PyTorch model to export.
        input_shape (tuple): Input shape used to construct the dummy input
            that is fed through the model during export.
        opset_version (int): The ONNX opset version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): The path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs of PyTorch and ONNX.
            Default: False.
    """
    model.cpu().eval()

    num_classes = model.head.num_classes
    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    # split the dummy batch into a list of single-image tensors, each keeping
    # a batch dimension of 1
    img_list = [img[None, :] for img in imgs]

    # replace the original forward with the test-time forward
    # (return_loss=False) so that tracing follows the inference path
    origin_forward = model.forward
    model.forward = partial(model.forward, return_loss=False)

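    # register mmcv's extra symbolic functions so ops without a native
    # torch.onnx symbolic can still be exported for this opset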
    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            verbose=show,
            opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

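    # optional check: run the same dummy images through the original PyTorch
    # model and the exported ONNX model (via onnxruntime) and compare the
    # outputs numerically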
    if verify:
        # validate the exported model with the onnx checker
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # compare the numerical outputs
        # get the pytorch output
        pytorch_result = model(img_list, return_loss=False)[0]

        # get the onnx output
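        # with `keep_initializers_as_inputs=True`, the model weights also show
        # up in `graph.input`, so subtract the initializer names to find the
        # actual input tensor that must be fed at runtime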
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 1
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMCls to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the classifier and load the checkpoint, if one is given
    classifier = build_classifier(cfg.model)

    if args.checkpoint:
        load_checkpoint(classifier, args.checkpoint, map_location='cpu')

    # convert the model to an ONNX file
    pytorch2onnx(
        classifier,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)