# mmocr/tools/deployment/pytorch2onnx.py
#
# 378 lines
# 13 KiB
# Python

from argparse import ArgumentParser
from functools import partial
import cv2
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.parallel import collate
from mmdet.apis import init_detector
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
ONNXRuntimeRecognizer)
from torch import nn
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
def _convert_batchnorm(module):
module_output = module
if isinstance(module, torch.nn.SyncBatchNorm):
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
module.momentum, module.affine,
module.track_running_stats)
if module.affine:
module_output.weight.data = module.weight.data.clone().detach()
module_output.bias.data = module.bias.data.clone().detach()
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, _convert_batchnorm(child))
del module
return module_output
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
"""update img and its meta list."""
N, C, H, W = img_list[0].shape
img_meta = img_meta_list[0][0]
img_shape = (H, W, C)
if update_ori_shape:
ori_shape = img_shape
else:
ori_shape = img_meta['ori_shape']
pad_shape = img_shape
new_img_meta_list = [[{
'img_shape':
img_shape,
'ori_shape':
ori_shape,
'pad_shape':
pad_shape,
'filename':
img_meta['filename'],
'scale_factor':
np.array(
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2),
'flip':
False,
} for _ in range(N)]]
return img_list, new_img_meta_list
def _prepare_data(cfg, imgs):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
Returns:
result (dict): Predicted results.
"""
if isinstance(imgs, (list, tuple)):
if not isinstance(imgs[0], (np.ndarray, str)):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
is_ndarray = isinstance(imgs[0], np.ndarray)
if is_ndarray:
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)
datas = []
for img in imgs:
# prepare data
if is_ndarray:
# directly add img
data = dict(img=img)
else:
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
data = test_pipeline(data)
# get tensor from list to stack for batch mode (text detection)
datas.append(data)
if isinstance(datas[0]['img'], list) and len(datas) > 1:
raise Exception('aug test does not support '
f'inference with batch size '
f'{len(datas)}')
data = collate(datas, samples_per_gpu=len(imgs))
# process img_metas
if isinstance(data['img_metas'], list):
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
else:
data['img_metas'] = data['img_metas'].data
if isinstance(data['img'], list):
data['img'] = [img.data for img in data['img']]
if isinstance(data['img'][0], list):
data['img'] = [img[0] for img in data['img']]
else:
data['img'] = data['img'].data
return data
def _results_agree(model_type, onnx_out, pytorch_out, rtol=1e-4, atol=1e-4):
    """Return True when the ONNX Runtime and PyTorch results agree.

    For ``model_type == 'recog'`` every result pair must have the same
    ``'text'`` and numerically close ``'score'``; otherwise the
    ``'boundary_result'`` entries of the first result are compared. A
    differing number of results, or values whose shapes cannot be compared,
    count as a mismatch.
    """

    def _close(lhs, rhs):
        try:
            return bool(
                np.allclose(
                    np.array(lhs), np.array(rhs), rtol=rtol, atol=atol))
        except ValueError:
            # Shapes that cannot be broadcast against each other (or ragged
            # data) cannot be "the same".
            return False

    if model_type == 'recog':
        # zip() alone would silently ignore surplus results on either side.
        if len(onnx_out) != len(pytorch_out):
            return False
        return all(
            onnx_result['text'] == pytorch_result['text']
            and _close(onnx_result['score'], pytorch_result['score'])
            for onnx_result, pytorch_result in zip(onnx_out, pytorch_out))

    onnx_boundaries = onnx_out[0]['boundary_result']
    pytorch_boundaries = pytorch_out[0]['boundary_result']
    if len(onnx_boundaries) != len(pytorch_boundaries):
        return False
    return all(
        _close(onnx_result, pytorch_result) for onnx_result, pytorch_result
        in zip(onnx_boundaries, pytorch_boundaries))


def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    The model is placed on CUDA device ``device_id``. ``model.forward`` is
    temporarily replaced for tracing and restored once the export has
    finished or failed.

    Args:
        model (nn.Module): Pytorch model we want to export.
        model_type (str): Model type, ``'det'`` for a detection model;
            any other value is handled as a recognition model.
        img_path (str): We need to use this input to execute the model.
        verbose (bool): Whether print the computation graph. Default: False.
        show (bool): Whether visualize final results. Only takes effect
            together with ``verify``, since it displays the outputs computed
            during verification. Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether apply dynamic export.
            Default: False.
        device_id (int): Device id to place model and data.
            Default: 0
    """
    device = torch.device(type='cuda', index=device_id)
    # Convert SyncBN first and only then switch to eval mode: the BatchNorm2d
    # layers created by the conversion would otherwise be left in training
    # mode and normalise with batch statistics during export.
    model = _convert_batchnorm(model)
    model.to(device).eval()

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        imgs = imgs[0]

    # add a leading batch dimension to every image
    img_list = [img[None, :].to(device) for img in imgs]
    # update img_meta
    img_list, img_metas = _update_input_img(img_list, img_metas)

    # torch.onnx.export only feeds the image tensor, so bind the remaining
    # arguments into forward for the duration of the export.
    origin_forward = model.forward
    if model_type == 'det':
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)

    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                3: 'width'
            }
        }

    try:
        # pytorch has some bug in pytorch1.3, we have to fix it
        # by replacing these existing op
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list[0], ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=verbose,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
    finally:
        # Never leave the caller's model with the patched forward, whether
        # or not verification follows and even if the export raised.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # check by onnx
    import onnx
    onnx.checker.check_model(onnx.load(output_file))

    if dynamic_export:
        # scale image for dynamic shape test
        scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
        img_list = [
            nn.functional.interpolate(img, scale_factor=scale_factor)
            for img in img_list
        ]
        # update img_meta
        img_list, img_metas = _update_input_img(img_list, img_metas)

    # get pytorch output
    with torch.no_grad():
        pytorch_out = model.simple_test(
            img_list[0], img_metas[0], rescale=True)

    # get onnx output
    if model_type == 'det':
        onnx_runtime_model = ONNXRuntimeDetector(output_file, model.cfg,
                                                 device_id)
    else:
        onnx_runtime_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                                   device_id)
    onnx_out = onnx_runtime_model.simple_test(
        img_list[0], img_metas[0], rescale=True)

    # compare results
    if _results_agree(model_type, onnx_out, pytorch_out):
        same_diff = 'same'
    else:
        same_diff = 'different'
    print('The outputs are {} between Pytorch and ONNX'.format(same_diff))

    if show:
        onnx_img = onnx_runtime_model.show_result(
            img_path, onnx_out[0], out_file='onnx.jpg', show=False)
        pytorch_img = model.show_result(
            img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
        # fall back to the raw input when no rendered image is returned
        if onnx_img is None:
            onnx_img = cv2.imread(img_path)
        if pytorch_img is None:
            pytorch_img = cv2.imread(img_path)

        cv2.imshow('PyTorch', pytorch_img)
        cv2.imshow('ONNXRuntime', onnx_img)
        cv2.waitKey()
def main():
    """Parse the command line, load the model and export it to ONNX."""
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    # type=int is required: without it a value given on the command line
    # arrives as a str, which torch.device rejects as a device index.
    parser.add_argument(
        '--device-id',
        type=int,
        default=0,
        help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether verify the outputs of onnx and pytorch are same.',
        default=False)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether print the computation graph.',
        default=False)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether visualize final output.',
        default=False)
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether dynamicly export onnx model.',
        default=False)
    args = parser.parse_args()

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    # unwrap a wrapped model so the underlying module is exported
    if hasattr(model, 'module'):
        model = model.module
    # a ConcatDataset has no pipeline of its own; use its first dataset's
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = \
            model.cfg.data.test['datasets'][0].pipeline

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)


if __name__ == '__main__':
    main()