mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)

commit 5ddef979fc
parent a96e2f932d

add pytorch2onnx part (#12)

* add pytorch2onnx part
* Update according to the latest mmcv
* add docstring
* update docs
* update docs

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
docs/getting_started.md
@@ -332,3 +332,18 @@ python tools/publish_model.py work_dirs/pspnet/latest.pth psp_r50_hszhao_200ep.p
 ```
 
 The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pth`.
+
+### Convert to ONNX (experimental)
+
+We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools such as [Netron](https://github.com/lutzroeder/netron). We also support comparing the output results between the PyTorch and ONNX models.
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```
+
+**Note**: This tool is still experimental. Some custom operators are not supported for now.
+
+## Tutorials
+
+Currently, we provide four tutorials for users: [add a new dataset](tutorials/new_dataset.md), [design data pipelines](tutorials/data_pipeline.md), [add new modules](tutorials/new_modules.md), and [use training tricks](tutorials/training_tricks.md).
+
+We also provide a full description of the [config system](config.md).
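For reference, a typical invocation of the new script might look like the following; the config and checkpoint paths are hypothetical placeholders rather than files added by this commit:

```shell
python tools/pytorch2onnx.py configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py \
    --checkpoint work_dirs/pspnet/latest.pth \
    --output-file pspnet.onnx \
    --shape 512 1024 \
    --verify
```

With `--verify`, the script runs the same dummy input through both the PyTorch model and ONNX Runtime and raises an error if the outputs differ.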
mmseg/models/segmentors/encoder_decoder.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -171,6 +172,8 @@ class EncoderDecoder(BaseSegmentor):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
+        assert h_crop <= h_img and w_crop <= w_img, (
+            'crop size should not be greater than image size')
         num_classes = self.num_classes
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
@@ -185,14 +188,15 @@ class EncoderDecoder(BaseSegmentor):
                 y1 = max(y2 - h_crop, 0)
                 x1 = max(x2 - w_crop, 0)
                 crop_img = img[:, :, y1:y2, x1:x2]
-                pad_img = crop_img.new_zeros(
-                    (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
-                pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
-                pad_seg_logit = self.encode_decode(pad_img, img_meta)
-                preds[:, :, y1:y2,
-                      x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
+        # We want to regard count_mat as a constant while exporting to ONNX
+        count_mat = torch.from_numpy(count_mat.detach().numpy())
         preds = preds / count_mat
         if rescale:
             preds = resize(
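The rewrite above is the heart of making `slide_inference` exportable: in-place writes into tensor slices (`preds[:, :, y1:y2, x1:x2] += ...`) trace poorly to ONNX, while zero-padding each crop's logits to full size and adding out-of-place keeps the loop a chain of plain tensor ops. The `count_mat` round-trip through NumPy then freezes the overlap counts as a constant in the exported graph. A minimal sketch of the padding equivalence (toy sizes, not from the commit):

```python
import torch
import torch.nn.functional as F

H, W, h_crop, w_crop = 8, 8, 4, 4  # toy canvas and crop sizes
y1, x1 = 2, 3                      # crop placed at rows 2:6, cols 3:7
y2, x2 = y1 + h_crop, x1 + w_crop

crop_logit = torch.randn(1, 19, h_crop, w_crop)
slice_style = torch.zeros(1, 19, H, W)
pad_style = torch.zeros(1, 19, H, W)

# old style: in-place write into a slice of the canvas
slice_style[:, :, y1:y2, x1:x2] += crop_logit

# new style: pad the crop to canvas size, then add out-of-place;
# F.pad padding order for the last two dims is (left, right, top, bottom)
pad_style = pad_style + F.pad(crop_logit, (x1, W - x2, y1, H - y2))

assert torch.allclose(slice_style, pad_style)
```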
@@ -201,7 +205,6 @@ class EncoderDecoder(BaseSegmentor):
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
-
         return preds
 
     def whole_inference(self, img, img_meta, rescale):
@@ -243,8 +246,8 @@ class EncoderDecoder(BaseSegmentor):
         seg_logit = self.whole_inference(img, img_meta, rescale)
         output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
-        flip_direction = img_meta[0]['flip_direction']
         if flip:
+            flip_direction = img_meta[0]['flip_direction']
             assert flip_direction in ['horizontal', 'vertical']
             if flip_direction == 'horizontal':
                 output = output.flip(dims=(3, ))
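The lookup of `flip_direction` moves inside `if flip:` because the key is only guaranteed to exist when flipping was applied; the metas built by `_demo_mm_inputs` in the new script set `'flip': False` and omit `'flip_direction'` entirely, so the old unconditional lookup would raise a `KeyError` during export. For intuition about the flip itself, a toy example (not from the commit) showing that reversing the width dimension of an NCHW tensor undoes a horizontal flip:

```python
import torch

x = torch.arange(6).reshape(1, 1, 2, 3)  # NCHW tensor, W = 3
x_flipped = x.flip(dims=(3, ))           # horizontal flip = reverse width
assert torch.equal(x_flipped.flip(dims=(3, )), x)  # flipping twice restores x
```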
@@ -257,6 +260,8 @@ class EncoderDecoder(BaseSegmentor):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
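`torch.onnx.is_in_onnx_export()` returns True only while `torch.onnx.export` is tracing the model, so this guard ends the exported graph at the tensor `seg_pred` and skips the NumPy post-processing, which cannot be represented in ONNX. A minimal sketch of the pattern, using a hypothetical toy module rather than the real segmentor:

```python
import torch


class ArgmaxHead(torch.nn.Module):
    """Toy module illustrating the export-time early return."""

    def forward(self, logits):
        pred = logits.argmax(dim=1)
        if torch.onnx.is_in_onnx_export():
            # while tracing, stop here: the ONNX graph must end in tensors
            return pred
        # eager-mode post-processing that ONNX cannot express
        return list(pred.cpu().numpy())
```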
setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
tools/pytorch2onnx.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmseg.models import build_segmentor

torch.manual_seed(3)

def _convert_batchnorm(module):
    """Recursively replace SyncBatchNorm layers with BatchNorm2d.

    SyncBatchNorm requires a distributed process group, which is neither
    available nor needed during CPU-side ONNX export.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions.
        num_classes (int): Number of semantic classes.
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a PyTorch model to ONNX and optionally verify that the
    outputs of the PyTorch and ONNX models match.

    Args:
        model (nn.Module): PyTorch model to export.
        input_shape (tuple): Input shape used to construct the
            corresponding dummy input and execute the model.
        opset_version (int): The ONNX opset version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): Path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs of the PyTorch and
            ONNX models. Default: False.
    """
    model.cpu().eval()

    num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]

    # replace the original forward with a test-mode forward so that
    # tracing sees a single-argument callable
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=img_meta_list, return_loss=False)

    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            verbose=show,
            opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check the exported file with the onnx checker
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical values: get the pytorch output
        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]

        # get the onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')


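Once exported, the file can also be sanity-checked outside this script with ONNX Runtime alone. A minimal sketch with a hypothetical file name and input size; because the export keeps initializers as graph inputs, the real feed input is recovered by subtracting initializers from inputs, mirroring the verify branch above:

```python
import numpy as np
import onnx
import onnxruntime as rt

onnx_model = onnx.load('pspnet.onnx')  # hypothetical exported file
graph_inputs = {node.name for node in onnx_model.graph.input}
initializers = {node.name for node in onnx_model.graph.initializer}
(feed_name, ) = graph_inputs - initializers  # the single real input

sess = rt.InferenceSession('pspnet.onnx')
dummy = np.random.rand(1, 3, 512, 1024).astype(np.float32)
seg_pred = sess.run(None, {feed_name: dummy})[0]
print(seg_pred.shape)  # (N, H, W) class-index map from simple_test
```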
def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[256, 256],
        help='input image size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    num_classes = segmentor.decode_head.num_classes

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')

    # convert the model to an ONNX file
    pytorch2onnx(
        segmentor,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
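For readers who prefer to drive the export from Python rather than the command line, here is a condensed, hypothetical equivalent of the `__main__` block above. It assumes the script's functions are importable (e.g. the repository root is on `PYTHONPATH`), and the config and checkpoint paths are placeholders:

```python
import mmcv
from mmcv.runner import load_checkpoint

from mmseg.models import build_segmentor
from tools.pytorch2onnx import _convert_batchnorm, pytorch2onnx

cfg = mmcv.Config.fromfile(
    'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')
cfg.model.pretrained = None

# build the segmentor exactly as the script does, then swap SyncBN for BN
segmentor = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
segmentor = _convert_batchnorm(segmentor)
load_checkpoint(segmentor, 'work_dirs/pspnet/latest.pth', map_location='cpu')

# (N, C, H, W) dummy-input shape; verify=True also round-trips the model
# through onnxruntime and compares outputs with the PyTorch forward pass
pytorch2onnx(segmentor, (1, 3, 512, 1024), opset_version=11,
             output_file='pspnet.onnx', verify=True)
```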