mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add dynamic export and visualize to pytorch2onnx (#463)
* add dynamic export and visualize to pytorch2onnx * update document * fix lint * fix dynamic error and add visualization * fix lint * update docstring * update doc * Update help info for --show Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * fix lint Co-authored-by: maningsheng <maningsheng@sensetime.com> Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
e0e985fa85
commit
789d1a142b
@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
|
||||
|
||||
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
|
||||
|
||||
```shell
|
||||
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
|
||||
```bash
|
||||
python tools/pytorch2onnx.py \
|
||||
${CONFIG_FILE} \
|
||||
--checkpoint ${CHECKPOINT_FILE} \
|
||||
--output-file ${ONNX_FILE} \
|
||||
--input-img ${INPUT_IMG} \
|
||||
--shape ${INPUT_SHAPE} \
|
||||
--show \
|
||||
--verify \
|
||||
--dynamic-export \
|
||||
--cfg-options \
|
||||
model.test_cfg.mode="whole"
|
||||
```
|
||||
|
||||
Description of arguments:
|
||||
|
||||
- `config` : The path of a model config file.
|
||||
- `--checkpoint` : The path of a model checkpoint file.
|
||||
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
|
||||
- `--input-img` : The path of an input image for conversion and visualize.
|
||||
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
|
||||
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
|
||||
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
||||
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
|
||||
- `--cfg-options`:Update config options.
|
||||
|
||||
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
||||
|
||||
## Miscellaneous
|
||||
|
@ -103,7 +103,9 @@ def show_result_pyplot(model,
|
||||
result,
|
||||
palette=None,
|
||||
fig_size=(15, 10),
|
||||
opacity=0.5):
|
||||
opacity=0.5,
|
||||
title='',
|
||||
block=True):
|
||||
"""Visualize the segmentation results on the image.
|
||||
|
||||
Args:
|
||||
@ -117,6 +119,10 @@ def show_result_pyplot(model,
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5.
|
||||
Must be in (0, 1] range.
|
||||
title (str): The title of pyplot figure.
|
||||
Default is ''.
|
||||
block (bool): Whether to block the pyplot figure.
|
||||
Default is True.
|
||||
"""
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
@ -124,4 +130,6 @@ def show_result_pyplot(model,
|
||||
img, result, palette=palette, show=False, opacity=opacity)
|
||||
plt.figure(figsize=fig_size)
|
||||
plt.imshow(mmcv.bgr2rgb(img))
|
||||
plt.show()
|
||||
plt.title(title)
|
||||
plt.tight_layout()
|
||||
plt.show(block=block)
|
||||
|
@ -216,9 +216,14 @@ class EncoderDecoder(BaseSegmentor):
|
||||
|
||||
seg_logit = self.encode_decode(img, img_meta)
|
||||
if rescale:
|
||||
# support dynamic shape for onnx
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
size = img.shape[2:]
|
||||
else:
|
||||
size = img_meta[0]['ori_shape'][:2]
|
||||
seg_logit = resize(
|
||||
seg_logit,
|
||||
size=img_meta[0]['ori_shape'][:2],
|
||||
size=size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners,
|
||||
warning=False)
|
||||
|
@ -1,6 +1,5 @@
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -24,8 +23,6 @@ def resize(input,
|
||||
'the output would more aligned if '
|
||||
f'input size {(input_h, input_w)} is `x+1` and '
|
||||
f'out size {(output_h, output_w)} is `nx+1`')
|
||||
if isinstance(size, torch.Size):
|
||||
size = tuple(int(x) for x in size)
|
||||
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
|
||||
|
@ -7,10 +7,14 @@ import onnxruntime as rt
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.serialization
|
||||
from mmcv import DictAction
|
||||
from mmcv.onnx import register_extra_symbolics
|
||||
from mmcv.runner import load_checkpoint
|
||||
from torch import nn
|
||||
|
||||
from mmseg.apis import show_result_pyplot
|
||||
from mmseg.apis.inference import LoadImage
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
torch.manual_seed(3)
|
||||
@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _prepare_input_img(img_path, test_pipeline, shape=None):
|
||||
# build the data pipeline
|
||||
if shape is not None:
|
||||
test_pipeline[1]['img_scale'] = shape
|
||||
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
|
||||
test_pipeline = [LoadImage()] + test_pipeline[1:]
|
||||
test_pipeline = Compose(test_pipeline)
|
||||
# prepare data
|
||||
data = dict(img=img_path)
|
||||
data = test_pipeline(data)
|
||||
imgs = data['img']
|
||||
img_metas = [i.data for i in data['img_metas']]
|
||||
|
||||
mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _update_input_img(img_list, img_meta_list):
|
||||
# update img and its meta list
|
||||
N, C, H, W = img_list[0].shape
|
||||
img_meta = img_meta_list[0][0]
|
||||
new_img_meta_list = [[{
|
||||
'img_shape': (H, W, C),
|
||||
'ori_shape': (H, W, C),
|
||||
'pad_shape': (H, W, C),
|
||||
'filename': img_meta['filename'],
|
||||
'scale_factor': 1.,
|
||||
'flip': False,
|
||||
} for _ in range(N)]]
|
||||
|
||||
return img_list, new_img_meta_list
|
||||
|
||||
|
||||
def pytorch2onnx(model,
|
||||
input_shape,
|
||||
mm_inputs,
|
||||
opset_version=11,
|
||||
show=False,
|
||||
output_file='tmp.onnx',
|
||||
verify=False):
|
||||
verify=False,
|
||||
dynamic_export=False):
|
||||
"""Export Pytorch model to ONNX model and verify the outputs are same
|
||||
between Pytorch and ONNX.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Pytorch model we want to export.
|
||||
input_shape (tuple): Use this input shape to construct
|
||||
the corresponding dummy input and execute the model.
|
||||
mm_inputs (dict): Contain the input tensors and img_metas information.
|
||||
opset_version (int): The onnx op version. Default: 11.
|
||||
show (bool): Whether print the computation graph. Default: False.
|
||||
output_file (string): The path to where we store the output ONNX model.
|
||||
Default: `tmp.onnx`.
|
||||
verify (bool): Whether compare the outputs between Pytorch and ONNX.
|
||||
Default: False.
|
||||
dynamic_export (bool): Whether to export ONNX with dynamic axis.
|
||||
Default: False.
|
||||
"""
|
||||
model.cpu().eval()
|
||||
|
||||
@ -94,28 +134,45 @@ def pytorch2onnx(model,
|
||||
else:
|
||||
num_classes = model.decode_head.num_classes
|
||||
|
||||
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
||||
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
ori_shape = img_metas[0]['ori_shape']
|
||||
|
||||
img_list = [img[None, :] for img in imgs]
|
||||
img_meta_list = [[img_meta] for img_meta in img_metas]
|
||||
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
|
||||
|
||||
# replace original forward function
|
||||
origin_forward = model.forward
|
||||
model.forward = partial(
|
||||
model.forward, img_metas=img_meta_list, return_loss=False)
|
||||
dynamic_axes = None
|
||||
if dynamic_export:
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'output': {
|
||||
1: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
}
|
||||
}
|
||||
|
||||
register_extra_symbolics(opset_version)
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
model, (img_list, ),
|
||||
output_file,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
keep_initializers_as_inputs=False,
|
||||
verbose=show,
|
||||
opset_version=opset_version)
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes)
|
||||
print(f'Successfully exported ONNX model: {output_file}')
|
||||
model.forward = origin_forward
|
||||
|
||||
@ -125,9 +182,28 @@ def pytorch2onnx(model,
|
||||
onnx_model = onnx.load(output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
if dynamic_export:
|
||||
# scale image for dynamic shape test
|
||||
img_list = [
|
||||
nn.functional.interpolate(_, scale_factor=1.5)
|
||||
for _ in img_list
|
||||
]
|
||||
# concate flip image for batch test
|
||||
flip_img_list = [_.flip(-1) for _ in img_list]
|
||||
img_list = [
|
||||
torch.cat((ori_img, flip_img), 0)
|
||||
for ori_img, flip_img in zip(img_list, flip_img_list)
|
||||
]
|
||||
|
||||
# update img_meta
|
||||
img_list, img_meta_list = _update_input_img(
|
||||
img_list, img_meta_list)
|
||||
|
||||
# check the numerical value
|
||||
# get pytorch output
|
||||
pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
|
||||
with torch.no_grad():
|
||||
pytorch_result = model(img_list, img_meta_list, return_loss=False)
|
||||
pytorch_result = np.stack(pytorch_result, 0)
|
||||
|
||||
# get onnx output
|
||||
input_all = [node.name for node in onnx_model.graph.input]
|
||||
@ -138,10 +214,42 @@ def pytorch2onnx(model,
|
||||
assert (len(net_feed_input) == 1)
|
||||
sess = rt.InferenceSession(output_file)
|
||||
onnx_result = sess.run(
|
||||
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
|
||||
if not np.allclose(pytorch_result, onnx_result):
|
||||
raise ValueError(
|
||||
'The outputs are different between Pytorch and ONNX')
|
||||
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
|
||||
# show segmentation results
|
||||
if show:
|
||||
import cv2
|
||||
import os.path as osp
|
||||
img = img_meta_list[0][0]['filename']
|
||||
if not osp.exists(img):
|
||||
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
|
||||
img = img.detach().numpy().astype(np.uint8)
|
||||
# resize onnx_result to ori_shape
|
||||
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
|
||||
(ori_shape[1], ori_shape[0]))
|
||||
show_result_pyplot(
|
||||
model,
|
||||
img, (onnx_result_, ),
|
||||
palette=model.PALETTE,
|
||||
block=False,
|
||||
title='ONNXRuntime',
|
||||
opacity=0.5)
|
||||
|
||||
# resize pytorch_result to ori_shape
|
||||
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
|
||||
(ori_shape[1], ori_shape[0]))
|
||||
show_result_pyplot(
|
||||
model,
|
||||
img, (pytorch_result_, ),
|
||||
title='PyTorch',
|
||||
palette=model.PALETTE,
|
||||
opacity=0.5)
|
||||
# compare results
|
||||
np.testing.assert_allclose(
|
||||
pytorch_result.astype(np.float32) / num_classes,
|
||||
onnx_result.astype(np.float32) / num_classes,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg='The outputs are different between Pytorch and ONNX')
|
||||
print('The outputs are same between Pytorch and ONNX')
|
||||
|
||||
|
||||
@ -149,7 +257,12 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
|
||||
parser.add_argument('--show', action='store_true', help='show onnx graph')
|
||||
parser.add_argument(
|
||||
'--input-img', type=str, help='Images for input', default=None)
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
help='show onnx graph and segmentation results')
|
||||
parser.add_argument(
|
||||
'--verify', action='store_true', help='verify the onnx model')
|
||||
parser.add_argument('--output-file', type=str, default='tmp.onnx')
|
||||
@ -160,6 +273,20 @@ def parse_args():
|
||||
nargs='+',
|
||||
default=[256, 256],
|
||||
help='input image size')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='Override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file. If the value to '
|
||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--dynamic-export',
|
||||
action='store_true',
|
||||
help='Whether to export onnx with dynamic axis.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -178,6 +305,8 @@ if __name__ == '__main__':
|
||||
raise ValueError('invalid input shape')
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
cfg.model.pretrained = None
|
||||
|
||||
# build the model and load checkpoint
|
||||
@ -188,13 +317,28 @@ if __name__ == '__main__':
|
||||
segmentor = _convert_batchnorm(segmentor)
|
||||
|
||||
if args.checkpoint:
|
||||
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
|
||||
checkpoint = load_checkpoint(
|
||||
segmentor, args.checkpoint, map_location='cpu')
|
||||
segmentor.CLASSES = checkpoint['meta']['CLASSES']
|
||||
segmentor.PALETTE = checkpoint['meta']['PALETTE']
|
||||
|
||||
# conver model to onnx file
|
||||
# read input or create dummpy input
|
||||
if args.input_img is not None:
|
||||
mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
|
||||
(input_shape[3], input_shape[2]))
|
||||
else:
|
||||
if isinstance(segmentor.decode_head, nn.ModuleList):
|
||||
num_classes = segmentor.decode_head[-1].num_classes
|
||||
else:
|
||||
num_classes = segmentor.decode_head.num_classes
|
||||
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
|
||||
|
||||
# convert model to onnx file
|
||||
pytorch2onnx(
|
||||
segmentor,
|
||||
input_shape,
|
||||
mm_inputs,
|
||||
opset_version=args.opset_version,
|
||||
show=args.show,
|
||||
output_file=args.output_file,
|
||||
verify=args.verify)
|
||||
verify=args.verify,
|
||||
dynamic_export=args.dynamic_export)
|
||||
|
Loading…
x
Reference in New Issue
Block a user