From edffbcc486bfc8234749b94e40e1cc158172ad17 Mon Sep 17 00:00:00 2001 From: Lei Yang Date: Wed, 30 Sep 2020 20:13:36 +0800 Subject: [PATCH] Add pytorch2onnx (#20) * add pytorch2onnx * add pytorch2onnx API --- tools/pytorch2onnx.py | 155 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 tools/pytorch2onnx.py diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py new file mode 100644 index 000000000..9adba04f9 --- /dev/null +++ b/tools/pytorch2onnx.py @@ -0,0 +1,155 @@ +import argparse +from functools import partial + +import mmcv +import numpy as np +import onnxruntime as rt +import torch +from mmcv.onnx import register_extra_symbolics +from mmcv.runner import load_checkpoint + +from mmcls.models import build_classifier + +torch.manual_seed(3) + + +def _demo_mm_inputs(input_shape, num_classes): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + rng = np.random.RandomState(0) + imgs = rng.rand(*input_shape) + gt_labels = rng.randint( + low=0, high=num_classes - 1, size=(N, 1)).astype(np.uint8) + mm_inputs = { + 'imgs': torch.FloatTensor(imgs).requires_grad_(True), + 'gt_labels': torch.LongTensor(gt_labels), + } + return mm_inputs + + +def pytorch2onnx(model, + input_shape, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False): + """Export Pytorch model to ONNX model and verify the outputs are same + between Pytorch and ONNX. + + Args: + model (nn.Module): Pytorch model we want to export. + input_shape (tuple): Use this input shape to construct + the corresponding dummy input and execute the model. + opset_version (int): The onnx op version. Default: 11. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the output ONNX model. + Default: `tmp.onnx`. + verify (bool): Whether compare the outputs between Pytorch and ONNX. + Default: False. + """ + model.cpu().eval() + + num_classes = model.head.num_classes + mm_inputs = _demo_mm_inputs(input_shape, num_classes) + + imgs = mm_inputs.pop('imgs') + img_list = [img[None, :] for img in imgs] + + # replace original forward function + origin_forward = model.forward + model.forward = partial(model.forward, return_loss=False) + + register_extra_symbolics(opset_version) + with torch.no_grad(): + torch.onnx.export( + model, (img_list, ), + output_file, + export_params=True, + keep_initializers_as_inputs=True, + verbose=show, + opset_version=opset_version) + print(f'Successfully exported ONNX model: {output_file}') + model.forward = origin_forward + + if verify: + # check by onnx + import onnx + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # check the numerical value + # get pytorch output + pytorch_result = model(img_list, return_loss=False)[0] + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + sess = rt.InferenceSession(output_file) + onnx_result = sess.run( + None, {net_feed_input[0]: img_list[0].detach().numpy()})[0] + if not np.allclose(pytorch_result, onnx_result): + raise ValueError( + 'The outputs are different between Pytorch and ONNX') + print('The outputs are same between Pytorch and ONNX') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert MMCls to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument('--show', action='store_true', help='show onnx graph') + parser.add_argument( + '--verify', action='store_true', help='verify the onnx model') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[224, 224], + help='input image size') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + cfg = mmcv.Config.fromfile(args.config) + cfg.model.pretrained = None + + # build the model and load checkpoint + classifier = build_classifier(cfg.model) + + if args.checkpoint: + load_checkpoint(classifier, args.checkpoint, map_location='cpu') + + # conver model to onnx file + pytorch2onnx( + classifier, + input_shape, + opset_version=args.opset_version, + show=args.show, + output_file=args.output_file, + verify=args.verify)