mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Add onnx and tensorrt export tool (#278)
* add onnx and tensorrt export * fix lint * delete batch input to avoid dbnet error * resolve unittest * fix lint * export unittestpull/287/head
parent
07a21fd716
commit
0131b3290f
|
@ -0,0 +1,114 @@
|
|||
## Deployment
|
||||
|
||||
We provide deployment tools under `tools/deployment` directory.
|
||||
|
||||
### Convert to ONNX (experimental)
|
||||
|
||||
We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
|
||||
|
||||
```bash
|
||||
python tools/deployment/pytorch2onnx.py
|
||||
${MODEL_CONFIG_PATH} \
|
||||
${MODEL_CKPT_PATH} \
|
||||
${MODEL_TYPE} \
|
||||
${IMAGE_PATH} \
|
||||
--output-file ${OUTPUT_FILE} \
|
||||
--device-id ${DEVICE_ID} \
|
||||
--opset-version ${OPSET_VERSION} \
|
||||
--verify \
|
||||
--verbose \
|
||||
--show \
|
||||
--dynamic-export
|
||||
```
|
||||
|
||||
Description of arguments:
|
||||
|
||||
- `model_config` : The path of a model config file.
|
||||
- `model_ckpt` : The path of a model checkpoint file.
|
||||
- `model_type` : The model type of the config file, options: `recog`, `det`.
|
||||
- `image_path` : The path to input image file.
|
||||
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
|
||||
- `--device-id`: Which gpu to use. If not specified, it will be set to 0.
|
||||
- `--opset-version` : ONNX opset version, default to 11.
|
||||
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
||||
- `--verbose`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
|
||||
- `--show`: Determines whether to visualize outputs of ONNXRuntime and pytorch. If not specified, it will be set to `False`.
|
||||
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
|
||||
|
||||
**Note**: This tool is still experimental. Some customized operators are not supported for now. And we only support `detection` and `recognition` for now.
|
||||
|
||||
#### List of supported models exportable to ONNX
|
||||
|
||||
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
|
||||
|
||||
| Model | Config | Dynamic Shape | Batch Inference | Note |
|
||||
|:------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|:-------------:|:---------------:|:----:|
|
||||
| DBNet | [dbnet_r18_fpnc_1200e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | Y | N | |
|
||||
| PSENet | [psenet_r50_fpnf_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py) | Y | Y | |
|
||||
| PSENet | [psenet_r50_fpnf_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | Y | Y | |
|
||||
| PANet | [panet_r18_fpem_ffm_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | Y | Y | |
|
||||
| PANet | [panet_r18_fpem_ffm_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | Y | Y | |
|
||||
| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | |
|
||||
|
||||
**Notes**:
|
||||
|
||||
- *All models above are tested with Pytorch==1.8.1 and onnxruntime==1.7.0*
|
||||
- If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to solve them by yourself.
|
||||
- Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmocr`.
|
||||
|
||||
### Convert ONNX to TensorRT (experimental)
|
||||
|
||||
We also provide a script to convert [ONNX](https://github.com/onnx/onnx) model to [TensorRT](https://github.com/NVIDIA/TensorRT) format. Besides, we support comparing the output results between ONNX and TensorRT model.
|
||||
|
||||
|
||||
```bash
|
||||
python tools/deployment/onnx2tensorrt.py
|
||||
${MODEL_CONFIG_PATH} \
|
||||
${MODEL_TYPE} \
|
||||
${IMAGE_PATH} \
|
||||
${ONNX_FILE} \
|
||||
--trt-file ${OUT_TENSORRT} \
|
||||
--max-shape INT INT INT INT \
|
||||
--min-shape INT INT INT INT \
|
||||
--workspace-size INT \
|
||||
--fp16 \
|
||||
--verify \
|
||||
--show \
|
||||
--verbose
|
||||
```
|
||||
|
||||
Description of arguments:
|
||||
|
||||
- `model_config` : The path of a model config file.
|
||||
- `model_type` :The model type of the config file, options:
|
||||
- `image_path` : The path to input image file.
|
||||
- `onnx_file` : The path to input ONNX file.
|
||||
- `--trt-file` : The path of output TensorRT model. If not specified, it will be set to `tmp.trt`.
|
||||
- `--max-shape` : Maximum shape of model input.
|
||||
- `--min-shape` : Minimum shape of model input.
|
||||
- `--workspace-size`: Max workspace size in GiB. If not specified, it will be set to 1 GiB.
|
||||
- `--fp16`: Determines whether to export TensorRT with fp16 mode. If not specified, it will be set to `False`.
|
||||
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
||||
- `--show`: Determines whether to show the output of ONNX and TensorRT. If not specified, it will be set to `False`.
|
||||
- `--verbose`: Determines whether to verbose logging messages while creating TensorRT engine. If not specified, it will be set to `False`.
|
||||
|
||||
**Note**: This tool is still experimental. Some customized operators are not supported for now. We only support `detection` and `recognition` for now.
|
||||
|
||||
#### List of supported models exportable to TensorRT
|
||||
|
||||
The table below lists the models that are guaranteed to be exportable to TensorRT engine and runnable in TensorRT.
|
||||
|
||||
| Model | Config | Dynamic Shape | Batch Inference | Note |
|
||||
|:------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|:-------------:|:---------------:|:----:|
|
||||
| DBNet | [dbnet_r18_fpnc_1200e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | Y | N | |
|
||||
| PSENet | [psenet_r50_fpnf_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py) | Y | Y | |
|
||||
| PSENet | [psenet_r50_fpnf_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | Y | Y | |
|
||||
| PANet | [panet_r18_fpem_ffm_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | Y | Y | |
|
||||
| PANet | [panet_r18_fpem_ffm_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | Y | Y | |
|
||||
| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | CRNN only accepts input with height 32 |
|
||||
|
||||
**Notes**:
|
||||
|
||||
- *All models above are tested with Pytorch==1.8.1, onnxruntime==1.7.0 and tensorrt==7.2.1.6*
|
||||
- If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to solve them by yourself.
|
||||
- Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmocr`.
|
|
@ -1,3 +1,5 @@
|
|||
import torch
|
||||
|
||||
from mmdet.models.builder import DETECTORS
|
||||
from mmdet.models.detectors import SingleStageDetector
|
||||
|
||||
|
@ -41,6 +43,10 @@ class SingleStageTextDetector(SingleStageDetector):
|
|||
x = self.extract_feat(img)
|
||||
outs = self.bbox_head(x)
|
||||
|
||||
# early return to avoid post processing
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return outs
|
||||
|
||||
if len(img_metas) > 1:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(*(outs[i].unsqueeze(0)),
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import torch
|
||||
|
||||
from mmdet.models.builder import DETECTORS, build_backbone, build_loss
|
||||
from mmocr.models.builder import (build_convertor, build_decoder,
|
||||
build_encoder, build_preprocessor)
|
||||
|
@ -136,6 +138,10 @@ class EncodeDecodeRecognizer(BaseRecognizer):
|
|||
out_dec = self.decoder(
|
||||
feat, out_enc, None, img_metas, train_mode=False)
|
||||
|
||||
# early return to avoid post processing
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return out_dec
|
||||
|
||||
label_indexes, label_scores = self.label_convertor.tensor2idx(
|
||||
out_dec, img_metas)
|
||||
label_strings = self.label_convertor.idx2str(label_indexes)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""pytest tests/test_detector.py."""
|
||||
import copy
|
||||
from functools import partial
|
||||
from os.path import dirname, exists, join
|
||||
|
||||
import numpy as np
|
||||
|
@ -218,6 +219,17 @@ def test_panet(cfg_file):
|
|||
return_loss=False)
|
||||
batch_results.append(result)
|
||||
|
||||
# Test onnx export
|
||||
detector.forward = partial(
|
||||
detector.simple_test, img_metas=img_metas, rescale=True)
|
||||
torch.onnx.export(
|
||||
detector, (img_list[0], ),
|
||||
'tmp.onnx',
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=False)
|
||||
|
||||
# Test show result
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
img = np.random.rand(5, 5)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os.path as osp
|
||||
import tempfile
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -71,6 +72,20 @@ def test_base_recognizer():
|
|||
assert 'text' in results[0]
|
||||
assert 'score' in results[0]
|
||||
|
||||
# test onnx export
|
||||
recognizer.forward = partial(
|
||||
recognizer.simple_test,
|
||||
img_metas=img_metas,
|
||||
return_loss=False,
|
||||
rescale=True)
|
||||
torch.onnx.export(
|
||||
recognizer, (imgs, ),
|
||||
'tmp.onnx',
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=False)
|
||||
|
||||
# test aug_test
|
||||
aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
|
||||
assert isinstance(aug_results, list)
|
||||
|
|
|
@ -0,0 +1,317 @@
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmdet.models.builder import DETECTORS
|
||||
from mmocr.models.textdet.detectors.single_stage_text_detector import \
|
||||
SingleStageTextDetector
|
||||
from mmocr.models.textdet.detectors.text_detector_mixin import \
|
||||
TextDetectorMixin
|
||||
from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \
|
||||
EncodeDecodeRecognizer
|
||||
|
||||
|
||||
def inference_with_session(sess, io_binding, input_name, output_names,
|
||||
input_tensor):
|
||||
device_type = input_tensor.device.type
|
||||
device_id = input_tensor.device.index
|
||||
device_id = 0 if device_id is None else device_id
|
||||
io_binding.bind_input(
|
||||
name=input_name,
|
||||
device_type=device_type,
|
||||
device_id=device_id,
|
||||
element_type=np.float32,
|
||||
shape=input_tensor.shape,
|
||||
buffer_ptr=input_tensor.data_ptr())
|
||||
for name in output_names:
|
||||
io_binding.bind_output(name)
|
||||
sess.run_with_iobinding(io_binding)
|
||||
pred = io_binding.copy_outputs_to_cpu()
|
||||
return pred
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for evaluating onnx file of detection."""
|
||||
|
||||
def __init__(self,
|
||||
onnx_file: str,
|
||||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
SingleStageTextDetector.__init__(self, cfg.model.backbone,
|
||||
cfg.model.neck, cfg.model.bbox_head,
|
||||
cfg.model.train_cfg,
|
||||
cfg.model.test_cfg,
|
||||
cfg.model.pretrained)
|
||||
TextDetectorMixin.__init__(self, show_score)
|
||||
import onnxruntime as ort
|
||||
# get the custom op path
|
||||
ort_custom_op_path = ''
|
||||
try:
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with ONNXRuntime from source.')
|
||||
session_options = ort.SessionOptions()
|
||||
# register custom op for onnxruntime
|
||||
if osp.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = ort.InferenceSession(onnx_file, session_options)
|
||||
providers = ['CPUExecutionProvider']
|
||||
options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU'
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
options.insert(0, {'device_id': device_id})
|
||||
|
||||
sess.set_providers(providers, options)
|
||||
|
||||
self.sess = sess
|
||||
self.device_id = device_id
|
||||
self.io_binding = sess.io_binding()
|
||||
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||
for name in self.output_names:
|
||||
self.io_binding.bind_output(name)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self,
|
||||
img: torch.Tensor,
|
||||
img_metas: Iterable,
|
||||
rescale: bool = False):
|
||||
onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
|
||||
self.output_names, img)
|
||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
||||
if len(img_metas) > 1:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(*(onnx_pred[i].unsqueeze(0)),
|
||||
[img_metas[i]], rescale)
|
||||
for i in range(len(img_metas))
|
||||
]
|
||||
|
||||
else:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(*onnx_pred, img_metas, rescale)
|
||||
]
|
||||
|
||||
return boundaries
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
|
||||
"""The class for evaluating onnx file of recognition."""
|
||||
|
||||
def __init__(self,
|
||||
onnx_file: str,
|
||||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
EncodeDecodeRecognizer.__init__(self, cfg.model.preprocessor,
|
||||
cfg.model.backbone, cfg.model.encoder,
|
||||
cfg.model.decoder, cfg.model.loss,
|
||||
cfg.model.label_convertor,
|
||||
cfg.train_cfg, cfg.test_cfg, 40,
|
||||
cfg.model.pretrained)
|
||||
import onnxruntime as ort
|
||||
# get the custom op path
|
||||
ort_custom_op_path = ''
|
||||
try:
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with ONNXRuntime from source.')
|
||||
session_options = ort.SessionOptions()
|
||||
# register custom op for onnxruntime
|
||||
if osp.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = ort.InferenceSession(onnx_file, session_options)
|
||||
providers = ['CPUExecutionProvider']
|
||||
options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU'
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
options.insert(0, {'device_id': device_id})
|
||||
|
||||
sess.set_providers(providers, options)
|
||||
|
||||
self.sess = sess
|
||||
self.device_id = device_id
|
||||
self.io_binding = sess.io_binding()
|
||||
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||
for name in self.output_names:
|
||||
self.io_binding.bind_output(name)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self,
|
||||
img: torch.Tensor,
|
||||
img_metas: Iterable,
|
||||
rescale: bool = False):
|
||||
"""Test function.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): Image input tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
|
||||
Returns:
|
||||
list[str]: Text label result of each image.
|
||||
"""
|
||||
onnx_pred = inference_with_session(self.sess, self.io_binding, 'input',
|
||||
self.output_names, img)
|
||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
||||
|
||||
label_indexes, label_scores = self.label_convertor.tensor2idx(
|
||||
onnx_pred, img_metas)
|
||||
label_strings = self.label_convertor.idx2str(label_indexes)
|
||||
|
||||
# flatten batch results
|
||||
results = []
|
||||
for string, score in zip(label_strings, label_scores):
|
||||
results.append(dict(text=string, score=score))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
|
||||
"""The class for evaluating TensorRT file of detection."""
|
||||
|
||||
def __init__(self,
|
||||
trt_file: str,
|
||||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
SingleStageTextDetector.__init__(self, cfg.model.backbone,
|
||||
cfg.model.neck, cfg.model.bbox_head,
|
||||
cfg.model.train_cfg,
|
||||
cfg.model.test_cfg,
|
||||
cfg.model.pretrained)
|
||||
TextDetectorMixin.__init__(self, show_score)
|
||||
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
||||
try:
|
||||
load_tensorrt_plugin()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with TensorRT from source.')
|
||||
model = TRTWraper(
|
||||
trt_file, input_names=['input'], output_names=['output'])
|
||||
|
||||
self.model = model
|
||||
self.device_id = device_id
|
||||
self.cfg = cfg
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self,
|
||||
img: torch.Tensor,
|
||||
img_metas: Iterable,
|
||||
rescale: bool = False):
|
||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||
trt_pred = self.model({'input': img})['output']
|
||||
if len(img_metas) > 1:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(*(trt_pred[i].unsqueeze(0)),
|
||||
[img_metas[i]], rescale)
|
||||
for i in range(len(img_metas))
|
||||
]
|
||||
|
||||
else:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(*trt_pred, img_metas, rescale)
|
||||
]
|
||||
|
||||
return boundaries
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class TensorRTRecognizer(EncodeDecodeRecognizer):
|
||||
"""The class for evaluating TensorRT file of recognition."""
|
||||
|
||||
def __init__(self,
|
||||
trt_file: str,
|
||||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
EncodeDecodeRecognizer.__init__(self, cfg.model.preprocessor,
|
||||
cfg.model.backbone, cfg.model.encoder,
|
||||
cfg.model.decoder, cfg.model.loss,
|
||||
cfg.model.label_convertor,
|
||||
cfg.train_cfg, cfg.test_cfg, 40,
|
||||
cfg.model.pretrained)
|
||||
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
||||
try:
|
||||
load_tensorrt_plugin()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
warnings.warn('If input model has custom op from mmcv, \
|
||||
you may have to build mmcv with TensorRT from source.')
|
||||
model = TRTWraper(
|
||||
trt_file, input_names=['input'], output_names=['output'])
|
||||
|
||||
self.model = model
|
||||
self.device_id = device_id
|
||||
self.cfg = cfg
|
||||
|
||||
def forward_train(self, img, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def simple_test(self,
|
||||
img: torch.Tensor,
|
||||
img_metas: Iterable,
|
||||
rescale: bool = False):
|
||||
"""Test function.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): Image input tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
|
||||
Returns:
|
||||
list[str]: Text label result of each image.
|
||||
"""
|
||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||
trt_pred = self.model({'input': img})['output']
|
||||
|
||||
label_indexes, label_scores = self.label_convertor.tensor2idx(
|
||||
trt_pred, img_metas)
|
||||
label_strings = self.label_convertor.idx2str(label_indexes)
|
||||
|
||||
# flatten batch results
|
||||
results = []
|
||||
for string, score in zip(label_strings, label_scores):
|
||||
results.append(dict(text=string, score=score))
|
||||
|
||||
return results
|
|
@ -0,0 +1,305 @@
|
|||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import Iterable
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import collate
|
||||
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
|
||||
from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
|
||||
ONNXRuntimeRecognizer,
|
||||
TensorRTDetector,
|
||||
TensorRTRecognizer)
|
||||
|
||||
from mmdet.datasets import replace_ImageToTensor
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
|
||||
|
||||
def get_GiB(x: int):
|
||||
"""return x GiB."""
|
||||
return x * (1 << 30)
|
||||
|
||||
|
||||
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
|
||||
"""update img and its meta list."""
|
||||
N, C, H, W = img_list[0].shape
|
||||
img_meta = img_meta_list[0][0]
|
||||
img_shape = (H, W, C)
|
||||
if update_ori_shape:
|
||||
ori_shape = img_shape
|
||||
else:
|
||||
ori_shape = img_meta['ori_shape']
|
||||
pad_shape = img_shape
|
||||
new_img_meta_list = [[{
|
||||
'img_shape':
|
||||
img_shape,
|
||||
'ori_shape':
|
||||
ori_shape,
|
||||
'pad_shape':
|
||||
pad_shape,
|
||||
'filename':
|
||||
img_meta['filename'],
|
||||
'scale_factor':
|
||||
np.array(
|
||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2),
|
||||
'flip':
|
||||
False,
|
||||
} for _ in range(N)]]
|
||||
|
||||
return img_list, new_img_meta_list
|
||||
|
||||
|
||||
def _prepare_input_img(imgs, test_pipeline: Iterable[dict]):
|
||||
"""Inference image(s) with the detector.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The loaded detector.
|
||||
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
|
||||
Either image files or loaded images.
|
||||
Returns:
|
||||
result (dict): Predicted results.
|
||||
"""
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
if not isinstance(imgs[0], (np.ndarray, str)):
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
|
||||
elif isinstance(imgs, (np.ndarray, str)):
|
||||
imgs = [imgs]
|
||||
else:
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
|
||||
test_pipeline = replace_ImageToTensor(test_pipeline)
|
||||
test_pipeline = Compose(test_pipeline)
|
||||
|
||||
datas = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
# add information into dict
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
# get tensor from list to stack for batch mode (text detection)
|
||||
datas.append(data)
|
||||
|
||||
if isinstance(datas[0]['img'], list) and len(datas) > 1:
|
||||
raise Exception('aug test does not support '
|
||||
f'inference with batch size '
|
||||
f'{len(datas)}')
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
|
||||
# process img_metas
|
||||
if isinstance(data['img_metas'], list):
|
||||
data['img_metas'] = [
|
||||
img_metas.data[0] for img_metas in data['img_metas']
|
||||
]
|
||||
else:
|
||||
data['img_metas'] = data['img_metas'].data
|
||||
|
||||
if isinstance(data['img'], list):
|
||||
data['img'] = [img.data for img in data['img']]
|
||||
if isinstance(data['img'][0], list):
|
||||
data['img'] = [img[0] for img in data['img']]
|
||||
else:
|
||||
data['img'] = data['img'].data
|
||||
return data
|
||||
|
||||
|
||||
def onnx2tensorrt(onnx_file: str,
|
||||
model_type: str,
|
||||
trt_file: str,
|
||||
config: dict,
|
||||
input_config: dict,
|
||||
fp16: bool = False,
|
||||
verify: bool = False,
|
||||
show: bool = False,
|
||||
workspace_size: int = 1,
|
||||
verbose: bool = False):
|
||||
import tensorrt as trt
|
||||
min_shape = input_config['min_shape']
|
||||
max_shape = input_config['max_shape']
|
||||
# create trt engine and wraper
|
||||
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
|
||||
max_workspace_size = get_GiB(workspace_size)
|
||||
trt_engine = onnx2trt(
|
||||
onnx_file,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
|
||||
fp16_mode=fp16,
|
||||
max_workspace_size=max_workspace_size)
|
||||
save_dir, _ = osp.split(trt_file)
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_trt_engine(trt_engine, trt_file)
|
||||
print(f'Successfully created TensorRT engine: {trt_file}')
|
||||
|
||||
if verify:
|
||||
mm_inputs = _prepare_input_img(input_config['input_path'],
|
||||
config.data.test.pipeline)
|
||||
|
||||
imgs = mm_inputs.pop('img')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
|
||||
if isinstance(imgs, list):
|
||||
imgs = imgs[0]
|
||||
|
||||
img_list = [img[None, :] for img in imgs]
|
||||
# update img_meta
|
||||
img_list, img_metas = _update_input_img(img_list, img_metas)
|
||||
|
||||
# Get results from ONNXRuntime
|
||||
if model_type == 'det':
|
||||
onnx_model = ONNXRuntimeDetector(onnx_file, config, 0)
|
||||
else:
|
||||
onnx_model = ONNXRuntimeRecognizer(onnx_file, config, 0)
|
||||
onnx_out = onnx_model.simple_test(
|
||||
img_list[0], img_metas[0], rescale=True)
|
||||
|
||||
# Get results from TensorRT
|
||||
if model_type == 'det':
|
||||
trt_model = TensorRTDetector(trt_file, config, 0)
|
||||
else:
|
||||
trt_model = TensorRTRecognizer(trt_file, config, 0)
|
||||
img_list[0] = img_list[0].to(torch.device('cuda:0'))
|
||||
trt_out = trt_model.simple_test(
|
||||
img_list[0], img_metas[0], rescale=True)
|
||||
|
||||
# compare results
|
||||
same_diff = 'same'
|
||||
if model_type == 'recog':
|
||||
for onnx_result, trt_result in zip(onnx_out, trt_out):
|
||||
if onnx_result['text'] != trt_result['text'] or \
|
||||
not np.allclose(
|
||||
np.array(onnx_result['score']),
|
||||
np.array(trt_result['score']),
|
||||
rtol=1e-4,
|
||||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
else:
|
||||
for onnx_result, trt_result in zip(onnx_out[0]['boundary_result'],
|
||||
trt_out[0]['boundary_result']):
|
||||
if not np.allclose(
|
||||
np.array(onnx_result),
|
||||
np.array(trt_result),
|
||||
rtol=1e-4,
|
||||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
print('The outputs are {} between TensorRT and ONNX'.format(same_diff))
|
||||
|
||||
if show:
|
||||
onnx_img = onnx_model.show_result(
|
||||
input_config['input_path'],
|
||||
onnx_out[0],
|
||||
out_file='onnx.jpg',
|
||||
show=False)
|
||||
trt_img = trt_model.show_result(
|
||||
input_config['input_path'],
|
||||
trt_out[0],
|
||||
out_file='tensorrt.jpg',
|
||||
show=False)
|
||||
if onnx_img is None:
|
||||
onnx_img = cv2.imread(input_config['input_path'])
|
||||
if trt_img is None:
|
||||
trt_img = cv2.imread(input_config['input_path'])
|
||||
|
||||
cv2.imshow('TensorRT', trt_img)
|
||||
cv2.imshow('ONNXRuntime', onnx_img)
|
||||
cv2.waitKey()
|
||||
return
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert MMOCR models from ONNX to TensorRT')
|
||||
parser.add_argument('model_config', help='Config file of the model')
|
||||
parser.add_argument(
|
||||
'model_type',
|
||||
type=str,
|
||||
help='Detection or recognition model to deploy.',
|
||||
choices=['recog', 'det'])
|
||||
parser.add_argument('image_path', type=str, help='Image for test')
|
||||
parser.add_argument('onnx_file', help='Path to the input ONNX model')
|
||||
parser.add_argument(
|
||||
'--trt-file',
|
||||
type=str,
|
||||
help='Path to the output TensorRT engine',
|
||||
default='tmp.trt')
|
||||
parser.add_argument(
|
||||
'--max-shape',
|
||||
type=int,
|
||||
nargs=4,
|
||||
default=[1, 3, 400, 600],
|
||||
help='Maximum shape of model input.')
|
||||
parser.add_argument(
|
||||
'--min-shape',
|
||||
type=int,
|
||||
nargs=4,
|
||||
default=[1, 3, 400, 600],
|
||||
help='Minimum shape of model input.')
|
||||
parser.add_argument(
|
||||
'--workspace-size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Max workspace size in GiB.')
|
||||
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
|
||||
parser.add_argument(
|
||||
'--verify',
|
||||
action='store_true',
|
||||
help='Whether Verify the outputs of ONNXRuntime and TensorRT.',
|
||||
default=True)
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
help='Whether visiualize outputs of ONNXRuntime and TensorRT.',
|
||||
default=True)
|
||||
parser.add_argument(
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
help='Whether to verbose logging messages while creating \
|
||||
TensorRT engine.')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
|
||||
args = parse_args()
|
||||
|
||||
# check arguments
|
||||
assert osp.exists(args.model_config), 'Config {} not found.'.format(
|
||||
args.model_config)
|
||||
assert osp.exists(args.onnx_file), \
|
||||
'ONNX model {} not found.'.format(args.onnx_file)
|
||||
assert args.workspace_size >= 0, 'Workspace size less than 0.'
|
||||
for max_value, min_value in zip(args.max_shape, args.min_shape):
|
||||
assert max_value >= min_value, \
|
||||
'max_shape sould be larger than min shape'
|
||||
|
||||
input_config = {
|
||||
'min_shape': args.min_shape,
|
||||
'max_shape': args.max_shape,
|
||||
'input_path': args.image_path
|
||||
}
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.model_config)
|
||||
if cfg.data.test['type'] == 'ConcatDataset':
|
||||
cfg.data.test.pipeline = \
|
||||
cfg.data.test['datasets'][0].pipeline
|
||||
onnx2tensorrt(
|
||||
args.onnx_file,
|
||||
args.model_type,
|
||||
args.trt_file,
|
||||
cfg,
|
||||
input_config,
|
||||
fp16=args.fp16,
|
||||
verify=args.verify,
|
||||
show=args.show,
|
||||
workspace_size=args.workspace_size,
|
||||
verbose=args.verbose)
|
|
@ -0,0 +1,378 @@
|
|||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.onnx import register_extra_symbolics
|
||||
from mmcv.parallel import collate
|
||||
from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
|
||||
ONNXRuntimeRecognizer)
|
||||
from torch import nn
|
||||
|
||||
from mmdet.apis import init_detector
|
||||
from mmdet.datasets import replace_ImageToTensor
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
||||
|
||||
|
||||
def _convert_batchnorm(module):
|
||||
module_output = module
|
||||
if isinstance(module, torch.nn.SyncBatchNorm):
|
||||
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
|
||||
module.momentum, module.affine,
|
||||
module.track_running_stats)
|
||||
if module.affine:
|
||||
module_output.weight.data = module.weight.data.clone().detach()
|
||||
module_output.bias.data = module.bias.data.clone().detach()
|
||||
# keep requires_grad unchanged
|
||||
module_output.weight.requires_grad = module.weight.requires_grad
|
||||
module_output.bias.requires_grad = module.bias.requires_grad
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, _convert_batchnorm(child))
|
||||
del module
|
||||
return module_output
|
||||
|
||||
|
||||
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
|
||||
"""update img and its meta list."""
|
||||
N, C, H, W = img_list[0].shape
|
||||
img_meta = img_meta_list[0][0]
|
||||
img_shape = (H, W, C)
|
||||
if update_ori_shape:
|
||||
ori_shape = img_shape
|
||||
else:
|
||||
ori_shape = img_meta['ori_shape']
|
||||
pad_shape = img_shape
|
||||
new_img_meta_list = [[{
|
||||
'img_shape':
|
||||
img_shape,
|
||||
'ori_shape':
|
||||
ori_shape,
|
||||
'pad_shape':
|
||||
pad_shape,
|
||||
'filename':
|
||||
img_meta['filename'],
|
||||
'scale_factor':
|
||||
np.array(
|
||||
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2),
|
||||
'flip':
|
||||
False,
|
||||
} for _ in range(N)]]
|
||||
|
||||
return img_list, new_img_meta_list
|
||||
|
||||
|
||||
def _prepare_data(cfg, imgs):
|
||||
"""Inference image(s) with the detector.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The loaded detector.
|
||||
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
|
||||
Either image files or loaded images.
|
||||
Returns:
|
||||
result (dict): Predicted results.
|
||||
"""
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
if not isinstance(imgs[0], (np.ndarray, str)):
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
|
||||
elif isinstance(imgs, (np.ndarray, str)):
|
||||
imgs = [imgs]
|
||||
else:
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
|
||||
is_ndarray = isinstance(imgs[0], np.ndarray)
|
||||
|
||||
if is_ndarray:
|
||||
cfg = cfg.copy()
|
||||
# set loading pipeline type
|
||||
cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
|
||||
|
||||
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
|
||||
datas = []
|
||||
for img in imgs:
|
||||
# prepare data
|
||||
if is_ndarray:
|
||||
# directly add img
|
||||
data = dict(img=img)
|
||||
else:
|
||||
# add information into dict
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
# get tensor from list to stack for batch mode (text detection)
|
||||
datas.append(data)
|
||||
|
||||
if isinstance(datas[0]['img'], list) and len(datas) > 1:
|
||||
raise Exception('aug test does not support '
|
||||
f'inference with batch size '
|
||||
f'{len(datas)}')
|
||||
|
||||
data = collate(datas, samples_per_gpu=len(imgs))
|
||||
|
||||
# process img_metas
|
||||
if isinstance(data['img_metas'], list):
|
||||
data['img_metas'] = [
|
||||
img_metas.data[0] for img_metas in data['img_metas']
|
||||
]
|
||||
else:
|
||||
data['img_metas'] = data['img_metas'].data
|
||||
|
||||
if isinstance(data['img'], list):
|
||||
data['img'] = [img.data for img in data['img']]
|
||||
if isinstance(data['img'][0], list):
|
||||
data['img'] = [img[0] for img in data['img']]
|
||||
else:
|
||||
data['img'] = data['img'].data
|
||||
return data
|
||||
|
||||
|
||||
def pytorch2onnx(model: nn.Module,
|
||||
model_type: str,
|
||||
img_path: str,
|
||||
verbose: bool = False,
|
||||
show: bool = False,
|
||||
opset_version: int = 11,
|
||||
output_file: str = 'tmp.onnx',
|
||||
verify: bool = False,
|
||||
dynamic_export: bool = False,
|
||||
device_id: int = 0):
|
||||
"""Export Pytorch model to ONNX model and verify the outputs are same
|
||||
between Pytorch and ONNX.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Pytorch model we want to export.
|
||||
model_type (str): Model type, detection or recognition model.
|
||||
img_path (str): We need to use this input to execute the model.
|
||||
opset_version (int): The onnx op version. Default: 11.
|
||||
verbose (bool): Whether print the computation graph. Default: False.
|
||||
show (bool): Whether visialize final results. Default: False.
|
||||
output_file (string): The path to where we store the output ONNX model.
|
||||
Default: `tmp.onnx`.
|
||||
verify (bool): Whether compare the outputs between Pytorch and ONNX.
|
||||
Default: False.
|
||||
device_id (id): Device id to place model and data.
|
||||
Default: 0
|
||||
"""
|
||||
device = torch.device(type='cuda', index=device_id)
|
||||
model.to(device).eval()
|
||||
_convert_batchnorm(model)
|
||||
# model.forward = model.simple_test
|
||||
end2end_res = {'filename': img_path}
|
||||
end2end_res['result'] = []
|
||||
|
||||
# mm_inputs = _demo_mm_inputs([1,3,512,512], 20)
|
||||
mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
|
||||
imgs = mm_inputs.pop('img')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
|
||||
if isinstance(imgs, list):
|
||||
imgs = imgs[0]
|
||||
|
||||
img_list = [img[None, :].to(device) for img in imgs]
|
||||
# update img_meta
|
||||
img_list, img_metas = _update_input_img(img_list, img_metas)
|
||||
|
||||
origin_forward = model.forward
|
||||
if (model_type == 'det'):
|
||||
model.forward = partial(
|
||||
model.simple_test, img_metas=img_metas, rescale=True)
|
||||
else:
|
||||
model.forward = partial(
|
||||
model.forward,
|
||||
img_metas=img_metas,
|
||||
return_loss=False,
|
||||
rescale=True)
|
||||
|
||||
# pytorch has some bug in pytorch1.3, we have to fix it
|
||||
# by replacing these existing op
|
||||
register_extra_symbolics(opset_version)
|
||||
dynamic_axes = None
|
||||
if dynamic_export and model_type == 'det':
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'output': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
}
|
||||
}
|
||||
elif dynamic_export and model_type == 'recog':
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
3: 'width'
|
||||
},
|
||||
'output': {
|
||||
0: 'batch',
|
||||
3: 'width'
|
||||
}
|
||||
}
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
model, (img_list[0], ),
|
||||
output_file,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=False,
|
||||
verbose=verbose,
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes)
|
||||
print(f'Successfully exported ONNX model: {output_file}')
|
||||
if verify:
|
||||
# check by onnx
|
||||
import onnx
|
||||
onnx_model = onnx.load(output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
|
||||
if dynamic_export:
|
||||
# scale image for dynamic shape test
|
||||
img_list = [
|
||||
nn.functional.interpolate(_, scale_factor=scale_factor)
|
||||
for _ in img_list
|
||||
]
|
||||
|
||||
# update img_meta
|
||||
img_list, img_metas = _update_input_img(img_list, img_metas)
|
||||
|
||||
# check the numerical value
|
||||
# get pytorch output
|
||||
with torch.no_grad():
|
||||
model.forward = origin_forward
|
||||
pytorch_out = model.simple_test(
|
||||
img_list[0], img_metas[0], rescale=True)
|
||||
|
||||
# get onnx output
|
||||
if model_type == 'det':
|
||||
onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
|
||||
else:
|
||||
onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
|
||||
device_id)
|
||||
onnx_out = onnx_model.simple_test(
|
||||
img_list[0], img_metas[0], rescale=True)
|
||||
|
||||
# compare results
|
||||
same_diff = 'same'
|
||||
if model_type == 'recog':
|
||||
for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
|
||||
if onnx_result['text'] != pytorch_result[
|
||||
'text'] or not np.allclose(
|
||||
np.array(onnx_result['score']),
|
||||
np.array(pytorch_result['score']),
|
||||
rtol=1e-4,
|
||||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
else:
|
||||
for onnx_result, pytorch_result in zip(
|
||||
onnx_out[0]['boundary_result'],
|
||||
pytorch_out[0]['boundary_result']):
|
||||
if not np.allclose(
|
||||
np.array(onnx_result),
|
||||
np.array(pytorch_result),
|
||||
rtol=1e-4,
|
||||
atol=1e-4):
|
||||
same_diff = 'different'
|
||||
break
|
||||
print('The outputs are {} between Pytorch and ONNX'.format(same_diff))
|
||||
|
||||
if show:
|
||||
onnx_img = onnx_model.show_result(
|
||||
img_path, onnx_out[0], out_file='onnx.jpg', show=False)
|
||||
pytorch_img = model.show_result(
|
||||
img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
|
||||
if onnx_img is None:
|
||||
onnx_img = cv2.imread(img_path)
|
||||
if pytorch_img is None:
|
||||
pytorch_img = cv2.imread(img_path)
|
||||
|
||||
cv2.imshow('PyTorch', pytorch_img)
|
||||
cv2.imshow('ONNXRuntime', onnx_img)
|
||||
cv2.waitKey()
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(
|
||||
description='Convert MMOCR models from pytorch to ONNX')
|
||||
parser.add_argument('model_config', type=str, help='Config file.')
|
||||
parser.add_argument(
|
||||
'model_ckpt', type=str, help='Checkpint file (local or url).')
|
||||
parser.add_argument(
|
||||
'model_type',
|
||||
type=str,
|
||||
help='Detection or recognition model to deploy.',
|
||||
choices=['recog', 'det'])
|
||||
parser.add_argument('image_path', type=str, help='Input Image file.')
|
||||
parser.add_argument(
|
||||
'--output-file',
|
||||
type=str,
|
||||
help='Output file name of the onnx model.',
|
||||
default='tmp.onnx')
|
||||
parser.add_argument(
|
||||
'--device-id', default=0, help='Device used for inference.')
|
||||
parser.add_argument(
|
||||
'--opset-version',
|
||||
type=int,
|
||||
help='ONNX opset version, default to 11.',
|
||||
default=11)
|
||||
parser.add_argument(
|
||||
'--verify',
|
||||
action='store_true',
|
||||
help='Whether verify the outputs of onnx and pytorch are same.',
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
help='Whether print the computation graph.',
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--show',
|
||||
action='store_true',
|
||||
help='Whether visualize final output.',
|
||||
default=False)
|
||||
parser.add_argument(
|
||||
'--dynamic-export',
|
||||
action='store_true',
|
||||
help='Whether dynamicly export onnx model.',
|
||||
default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device(type='cuda', index=args.device_id)
|
||||
|
||||
# build model
|
||||
model = init_detector(args.model_config, args.model_ckpt, device=device)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = \
|
||||
model.cfg.data.test['datasets'][0].pipeline
|
||||
|
||||
pytorch2onnx(
|
||||
model,
|
||||
model_type=args.model_type,
|
||||
output_file=args.output_file,
|
||||
img_path=args.image_path,
|
||||
opset_version=args.opset_version,
|
||||
verify=args.verify,
|
||||
verbose=args.verbose,
|
||||
show=args.show,
|
||||
device_id=args.device_id,
|
||||
dynamic_export=args.dynamic_export)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue