# mmsegmentation/tools/onnx2tensorrt.py
# (original listing metadata: 276 lines, 9.1 KiB, Python)

import argparse
import os
import os.path as osp
from typing import Iterable, Optional, Union
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)
from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose
def get_GiB(x: int):
    """Convert a size expressed in gibibytes into a byte count.

    Args:
        x: Number of GiB.

    Returns:
        ``x`` multiplied by ``2 ** 30``.
    """
    bytes_per_gib = 2 ** 30
    return x * bytes_per_gib
def _prepare_input_img(img_path: str,
                       test_pipeline: Iterable[dict],
                       shape: Optional[Iterable] = None,
                       rescale_shape: Optional[Iterable] = None) -> dict:
    """Run one image through the test pipeline to build model inputs.

    Args:
        img_path: Path of the image to load.
        test_pipeline: Test-time pipeline config, indexed as a list of
            transform dicts. It is deep-copied first, so the caller's config
            object is never modified by the overrides below.
        shape: Optional two-element size. ``shape[1]`` and ``shape[0]`` are
            written to ``img_scale`` in that order (presumably ``shape`` is
            ``(H, W)`` and ``img_scale`` is ``(W, H)`` — confirm), and the
            resize is switched to ``keep_ratio=False``.
        rescale_shape: Optional size written into every meta's ``ori_shape``
            with a trailing ``3`` appended.

    Returns:
        dict: ``{'imgs': ..., 'img_metas': [...]}`` taken from the pipeline
        output; each meta is unwrapped via its ``.data`` attribute.
    """
    import copy

    # Work on a private copy: the overrides below would otherwise leak into
    # the caller's config (e.g. ``config.data.test.pipeline``).
    test_pipeline = copy.deepcopy(test_pipeline)

    # build the data pipeline
    if shape is not None:
        # NOTE(review): assumes step 1 is a multi-scale wrapper whose first
        # inner transform is the resize — matches the indexing below.
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
        test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    # Replace the first (loading) step with an in-memory/path loader.
    test_pipeline = Compose([LoadImage()] + test_pipeline[1:])

    # prepare data
    data = test_pipeline(dict(img=img_path))
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]
    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
    return {'imgs': imgs, 'img_metas': img_metas}
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
# update img and its meta list
N = img_list[0].size(0)
img_meta = img_meta_list[0][0]
img_shape = img_meta['img_shape']
ori_shape = img_meta['ori_shape']
pad_shape = img_meta['pad_shape']
new_img_meta_list = [[{
'img_shape':
img_shape,
'ori_shape':
ori_shape,
'pad_shape':
pad_shape,
'filename':
img_meta['filename'],
'scale_factor':
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
'flip':
False,
} for _ in range(N)]]
return img_list, new_img_meta_list
def show_result_pyplot(img: Union[str, np.ndarray],
                       result: np.ndarray,
                       palette: Optional[Iterable] = None,
                       fig_size: Iterable[int] = (15, 10),
                       opacity: float = 0.5,
                       title: str = '',
                       block: bool = True):
    """Overlay a segmentation label map on an image and show it with pyplot.

    Args:
        img: Image path or array, as accepted by ``mmcv.imread``.
        result: Indexable whose first element is the 2-D label map to draw.
            It is resized to the image size before drawing.
        palette: ``(num_classes, 3)`` RGB colors, one row per label id. When
            ``None``, a reproducible random palette covering every label id
            present in the map is generated.
        fig_size: Matplotlib figure size.
        opacity: Weight of the color overlay; must be in ``(0, 1]``.
        title: Figure title.
        block: Forwarded to ``plt.show``.
    """
    assert 0 < opacity <= 1.0
    img = mmcv.imread(img).copy()
    # Label maps must use nearest-neighbour resizing: the default bilinear
    # mode blends neighbouring class ids and invents labels on boundaries.
    seg = mmcv.imresize(
        result[0], img.shape[:2][::-1], interpolation='nearest')

    if palette is None:
        # No palette supplied: pick a fixed-seed random color per label id.
        num_labels = int(seg.max()) + 1
        palette = np.random.RandomState(42).randint(
            0, 256, size=(num_labels, 3))
    palette = np.array(palette)
    # Check the rank before indexing shape[1]; otherwise a 1-D palette
    # raises IndexError instead of a meaningful assertion.
    assert palette.ndim == 2, 'palette must be a 2-D (num_classes, 3) array'
    assert palette.shape[1] == 3, 'palette rows must be 3-channel colors'

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # Palette colors are RGB; flip to BGR to match the image, which is
    # converted back with bgr2rgb before display.
    color_seg = color_seg[..., ::-1]

    img = (img * (1 - opacity) + color_seg * opacity).astype(np.uint8)
    plt.figure(figsize=fig_size)
    plt.imshow(mmcv.bgr2rgb(img))
    plt.title(title)
    plt.tight_layout()
    plt.show(block=block)
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
    """Build and save a TensorRT engine from an ONNX model, optionally verify.

    Args:
        onnx_file: Path of the source ONNX model.
        trt_file: Output path of the serialized engine. A missing parent
            directory is created.
        config: Loaded model config. Only ``config.data.test.pipeline`` is
            read, and only when ``verify`` is set.
        input_config: Dict with ``'min_shape'`` / ``'max_shape'`` (4-element
            shapes: index 0 is the batch size, indices 2: the spatial size)
            and ``'input_path'`` (image used for verification/display).
        fp16: Build the engine in fp16 mode.
        verify: Compare ONNX Runtime (CPU) and TensorRT outputs on the test
            image; raises ``AssertionError`` if they differ.
        show: When verifying, also plot both results.
        dataset: Registry name of the dataset whose ``PALETTE`` is used for
            plotting.
        workspace_size: Max TensorRT workspace size in GiB.
        verbose: Use verbose TensorRT logging while building.
    """
    import tensorrt as trt

    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']

    # create trt engine and wrapper; the profile is [min, opt, max] and the
    # minimum shape doubles as the optimization target.
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if not verify:
        return

    inputs = _prepare_input_img(
        input_config['input_path'],
        config.data.test.pipeline,
        shape=min_shape[2:])
    img_list = [img[None, :] for img in inputs['imgs']]
    img_meta_list = [[img_meta] for img_meta in inputs['img_metas']]
    # Only the image list is consumed below; the rebuilt metas are unused.
    img_list, _ = _update_input_img(img_list, img_meta_list)

    if max_shape[0] > 1:
        # concatenate a copy flipped along the last axis for batch testing
        img_list = [
            torch.cat((ori_img, ori_img.flip(-1)), 0) for ori_img in img_list
        ]

    # Get results from ONNXRuntime. Providers go to the constructor: since
    # ONNX Runtime 1.9, builds with GPU execution providers refuse to create
    # a session without an explicit ``providers`` list.
    ort_custom_op_path = get_onnxruntime_op_path()
    session_options = ort.SessionOptions()
    if osp.exists(ort_custom_op_path):
        session_options.register_custom_ops_library(ort_custom_op_path)
    sess = ort.InferenceSession(
        onnx_file, session_options, providers=['CPUExecutionProvider'])
    onnx_output = sess.run(['output'],
                           {'input': img_list[0].detach().numpy()})[0][0]

    # Get results from TensorRT
    trt_model = TRTWraper(trt_file, ['input'], ['output'])
    with torch.no_grad():
        trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
    trt_output = trt_outputs['output'][0].cpu().detach().numpy()

    if show:
        dataset_cls = DATASETS.get(dataset)
        assert dataset_cls is not None
        palette = dataset_cls.PALETTE
        show_result_pyplot(
            input_config['input_path'],
            (onnx_output[0].astype(np.uint8), ),
            palette=palette,
            title='ONNXRuntime',
            block=False)
        show_result_pyplot(
            input_config['input_path'], (trt_output[0].astype(np.uint8), ),
            palette=palette,
            title='TensorRT')

    np.testing.assert_allclose(
        onnx_output, trt_output, rtol=1e-03, atol=1e-05)
    print('TensorRT and ONNXRuntime output all close.')
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    cli = argparse.ArgumentParser(
        description='Convert MMSegmentation models from ONNX to TensorRT')

    # (flags, keyword arguments), kept in the order they appear in --help.
    arg_specs = [
        (('config', ), dict(help='Config file of the model')),
        (('model', ), dict(help='Path to the input ONNX model')),
        (('--trt-file', ),
         dict(type=str, help='Path to the output TensorRT engine')),
        (('--max-shape', ),
         dict(type=int, nargs=4, default=[1, 3, 400, 600],
              help='Maximum shape of model input.')),
        (('--min-shape', ),
         dict(type=int, nargs=4, default=[1, 3, 400, 600],
              help='Minimum shape of model input.')),
        (('--fp16', ), dict(action='store_true', help='Enable fp16 mode')),
        (('--workspace-size', ),
         dict(type=int, default=1, help='Max workspace size in GiB')),
        (('--input-img', ),
         dict(type=str, default='', help='Image for test')),
        (('--show', ),
         dict(action='store_true', help='Whether to show output results')),
        (('--dataset', ),
         dict(type=str, default='CityscapesDataset', help='Dataset name')),
        (('--verify', ),
         dict(action='store_true',
              help='Verify the outputs of ONNXRuntime and TensorRT')),
        (('--verbose', ),
         dict(action='store_true',
              help='Whether to verbose logging messages while creating '
              'TensorRT engine.')),
    ]
    for flags, options in arg_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
if __name__ == '__main__':
    # Parse arguments first so ``--help`` still works on machines where the
    # TensorRT plugin is not available.
    args = parse_args()
    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
    if not args.trt_file:
        # ``--trt-file`` has no default; without this the engine would be
        # built and then saved to a ``None`` path. Derive it from the model.
        args.trt_file = osp.splitext(args.model)[0] + '.trt'

    # check arguments
    assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
    assert osp.exists(args.model), \
        'ONNX model {} not found.'.format(args.model)
    assert args.workspace_size >= 0, 'Workspace size less than 0.'
    assert DATASETS.get(args.dataset) is not None, \
        'Dataset {} not found.'.format(args.dataset)
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should be no smaller than min_shape'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.input_img
    }

    cfg = mmcv.Config.fromfile(args.config)
    onnx2tensorrt(
        args.model,
        args.trt_file,
        cfg,
        input_config,
        fp16=args.fp16,
        verify=args.verify,
        show=args.show,
        dataset=args.dataset,
        workspace_size=args.workspace_size,
        verbose=args.verbose)