add dynamic export and visualize to pytorch2onnx (#463)

* add dynamic export and visualize to pytorch2onnx

* update document

* fix lint

* fix dynamic error and add visualization

* fix lint

* update docstring

* update doc

* Update help info for --show

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* fix lint

Co-authored-by: maningsheng <maningsheng@sensetime.com>
Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
q.yao 2021-04-13 02:54:59 +08:00 committed by GitHub
parent e0e985fa85
commit 789d1a142b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 202 additions and 26 deletions

View File

@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model. We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
```shell ```bash
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify] python tools/pytorch2onnx.py \
${CONFIG_FILE} \
--checkpoint ${CHECKPOINT_FILE} \
--output-file ${ONNX_FILE} \
--input-img ${INPUT_IMG} \
--shape ${INPUT_SHAPE} \
--show \
--verify \
--dynamic-export \
--cfg-options \
model.test_cfg.mode="whole"
``` ```
Description of arguments:
- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
- `--input-img` : The path of an input image for conversion and visualize.
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
- `--cfg-options`:Update config options.
**Note**: This tool is still experimental. Some customized operators are not supported for now. **Note**: This tool is still experimental. Some customized operators are not supported for now.
## Miscellaneous ## Miscellaneous

View File

@ -103,7 +103,9 @@ def show_result_pyplot(model,
result, result,
palette=None, palette=None,
fig_size=(15, 10), fig_size=(15, 10),
opacity=0.5): opacity=0.5,
title='',
block=True):
"""Visualize the segmentation results on the image. """Visualize the segmentation results on the image.
Args: Args:
@ -117,6 +119,10 @@ def show_result_pyplot(model,
opacity(float): Opacity of painted segmentation map. opacity(float): Opacity of painted segmentation map.
Default 0.5. Default 0.5.
Must be in (0, 1] range. Must be in (0, 1] range.
title (str): The title of pyplot figure.
Default is ''.
block (bool): Whether to block the pyplot figure.
Default is True.
""" """
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
@ -124,4 +130,6 @@ def show_result_pyplot(model,
img, result, palette=palette, show=False, opacity=opacity) img, result, palette=palette, show=False, opacity=opacity)
plt.figure(figsize=fig_size) plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img)) plt.imshow(mmcv.bgr2rgb(img))
plt.show() plt.title(title)
plt.tight_layout()
plt.show(block=block)

View File

@ -216,9 +216,14 @@ class EncoderDecoder(BaseSegmentor):
seg_logit = self.encode_decode(img, img_meta) seg_logit = self.encode_decode(img, img_meta)
if rescale: if rescale:
# support dynamic shape for onnx
if torch.onnx.is_in_onnx_export():
size = img.shape[2:]
else:
size = img_meta[0]['ori_shape'][:2]
seg_logit = resize( seg_logit = resize(
seg_logit, seg_logit,
size=img_meta[0]['ori_shape'][:2], size=size,
mode='bilinear', mode='bilinear',
align_corners=self.align_corners, align_corners=self.align_corners,
warning=False) warning=False)

View File

@ -1,6 +1,5 @@
import warnings import warnings
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -24,8 +23,6 @@ def resize(input,
'the output would more aligned if ' 'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and ' f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`') f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners) return F.interpolate(input, size, scale_factor, mode, align_corners)

View File

@ -7,10 +7,14 @@ import onnxruntime as rt
import torch import torch
import torch._C import torch._C
import torch.serialization import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from torch import nn from torch import nn
from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
torch.manual_seed(3) torch.manual_seed(3)
@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
return mm_inputs return mm_inputs
def _prepare_input_img(img_path, test_pipeline, shape=None):
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = shape
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
return mm_inputs
def _update_input_img(img_list, img_meta_list):
# update img and its meta list
N, C, H, W = img_list[0].shape
img_meta = img_meta_list[0][0]
new_img_meta_list = [[{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': img_meta['filename'],
'scale_factor': 1.,
'flip': False,
} for _ in range(N)]]
return img_list, new_img_meta_list
def pytorch2onnx(model, def pytorch2onnx(model,
input_shape, mm_inputs,
opset_version=11, opset_version=11,
show=False, show=False,
output_file='tmp.onnx', output_file='tmp.onnx',
verify=False): verify=False,
dynamic_export=False):
"""Export Pytorch model to ONNX model and verify the outputs are same """Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX. between Pytorch and ONNX.
Args: Args:
model (nn.Module): Pytorch model we want to export. model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct mm_inputs (dict): Contain the input tensors and img_metas information.
the corresponding dummy input and execute the model.
opset_version (int): The onnx op version. Default: 11. opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False. show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model. output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`. Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX. verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False. Default: False.
dynamic_export (bool): Whether to export ONNX with dynamic axis.
Default: False.
""" """
model.cpu().eval() model.cpu().eval()
@ -94,28 +134,45 @@ def pytorch2onnx(model,
else: else:
num_classes = model.decode_head.num_classes num_classes = model.decode_head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs') imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas') img_metas = mm_inputs.pop('img_metas')
ori_shape = img_metas[0]['ori_shape']
img_list = [img[None, :] for img in imgs] img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas] img_meta_list = [[img_meta] for img_meta in img_metas]
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
# replace original forward function # replace original forward function
origin_forward = model.forward origin_forward = model.forward
model.forward = partial( model.forward = partial(
model.forward, img_metas=img_meta_list, return_loss=False) model.forward, img_metas=img_meta_list, return_loss=False)
dynamic_axes = None
if dynamic_export:
dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
1: 'batch',
2: 'height',
3: 'width'
}
}
register_extra_symbolics(opset_version) register_extra_symbolics(opset_version)
with torch.no_grad(): with torch.no_grad():
torch.onnx.export( torch.onnx.export(
model, (img_list, ), model, (img_list, ),
output_file, output_file,
input_names=['input'],
output_names=['output'],
export_params=True, export_params=True,
keep_initializers_as_inputs=True, keep_initializers_as_inputs=False,
verbose=show, verbose=show,
opset_version=opset_version) opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}') print(f'Successfully exported ONNX model: {output_file}')
model.forward = origin_forward model.forward = origin_forward
@ -125,9 +182,28 @@ def pytorch2onnx(model,
onnx_model = onnx.load(output_file) onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model) onnx.checker.check_model(onnx_model)
if dynamic_export:
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]
# update img_meta
img_list, img_meta_list = _update_input_img(
img_list, img_meta_list)
# check the numerical value # check the numerical value
# get pytorch output # get pytorch output
pytorch_result = model(img_list, img_meta_list, return_loss=False)[0] with torch.no_grad():
pytorch_result = model(img_list, img_meta_list, return_loss=False)
pytorch_result = np.stack(pytorch_result, 0)
# get onnx output # get onnx output
input_all = [node.name for node in onnx_model.graph.input] input_all = [node.name for node in onnx_model.graph.input]
@ -138,10 +214,42 @@ def pytorch2onnx(model,
assert (len(net_feed_input) == 1) assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file) sess = rt.InferenceSession(output_file)
onnx_result = sess.run( onnx_result = sess.run(
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0] None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
if not np.allclose(pytorch_result, onnx_result): # show segmentation results
raise ValueError( if show:
'The outputs are different between Pytorch and ONNX') import cv2
import os.path as osp
img = img_meta_list[0][0]['filename']
if not osp.exists(img):
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
img = img.detach().numpy().astype(np.uint8)
# resize onnx_result to ori_shape
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (onnx_result_, ),
palette=model.PALETTE,
block=False,
title='ONNXRuntime',
opacity=0.5)
# resize pytorch_result to ori_shape
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (pytorch_result_, ),
title='PyTorch',
palette=model.PALETTE,
opacity=0.5)
# compare results
np.testing.assert_allclose(
pytorch_result.astype(np.float32) / num_classes,
onnx_result.astype(np.float32) / num_classes,
rtol=1e-5,
atol=1e-5,
err_msg='The outputs are different between Pytorch and ONNX')
print('The outputs are same between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX')
@ -149,7 +257,12 @@ def parse_args():
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
parser.add_argument('config', help='test config file path') parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', default=None) parser.add_argument('--checkpoint', help='checkpoint file', default=None)
parser.add_argument('--show', action='store_true', help='show onnx graph') parser.add_argument(
'--input-img', type=str, help='Images for input', default=None)
parser.add_argument(
'--show',
action='store_true',
help='show onnx graph and segmentation results')
parser.add_argument( parser.add_argument(
'--verify', action='store_true', help='verify the onnx model') '--verify', action='store_true', help='verify the onnx model')
parser.add_argument('--output-file', type=str, default='tmp.onnx') parser.add_argument('--output-file', type=str, default='tmp.onnx')
@ -160,6 +273,20 @@ def parse_args():
nargs='+', nargs='+',
default=[256, 256], default=[256, 256],
help='input image size') help='input image size')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--dynamic-export',
action='store_true',
help='Whether to export onnx with dynamic axis.')
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -178,6 +305,8 @@ if __name__ == '__main__':
raise ValueError('invalid input shape') raise ValueError('invalid input shape')
cfg = mmcv.Config.fromfile(args.config) cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
cfg.model.pretrained = None cfg.model.pretrained = None
# build the model and load checkpoint # build the model and load checkpoint
@ -188,13 +317,28 @@ if __name__ == '__main__':
segmentor = _convert_batchnorm(segmentor) segmentor = _convert_batchnorm(segmentor)
if args.checkpoint: if args.checkpoint:
load_checkpoint(segmentor, args.checkpoint, map_location='cpu') checkpoint = load_checkpoint(
segmentor, args.checkpoint, map_location='cpu')
segmentor.CLASSES = checkpoint['meta']['CLASSES']
segmentor.PALETTE = checkpoint['meta']['PALETTE']
# conver model to onnx file # read input or create dummpy input
if args.input_img is not None:
mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
(input_shape[3], input_shape[2]))
else:
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
# convert model to onnx file
pytorch2onnx( pytorch2onnx(
segmentor, segmentor,
input_shape, mm_inputs,
opset_version=args.opset_version, opset_version=args.opset_version,
show=args.show, show=args.show,
output_file=args.output_file, output_file=args.output_file,
verify=args.verify) verify=args.verify,
dynamic_export=args.dynamic_export)