From f1b003ddb122355118016160dabac512326b5d3f Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Fri, 18 Jun 2021 12:27:10 +0800
Subject: [PATCH] [Feature] Add deployment evaluation (#291)
* add deployment evaluation
* fix lint
* remove cpu unit tests for trt and onnx
* use pytest.mark to skip cpu unit test
* move to mmocr/core
* emm... renamed to wrappers
* renamed to deploy_utils
* renamed unit test to test_deploy_utils
* fix lint
* using pytest.mark.importorskip
---
docs/deployment.md | 187 ++++++++++++++-
mmocr/core/deployment/__init__.py | 7 +
.../core/deployment/deploy_utils.py | 0
setup.cfg | 2 +-
tests/test_core/test_deploy_utils.py | 220 ++++++++++++++++++
tools/deployment/deploy_test.py | 92 ++++++++
tools/deployment/onnx2tensorrt.py | 6 +-
tools/deployment/pytorch2onnx.py | 3 +-
8 files changed, 509 insertions(+), 8 deletions(-)
create mode 100644 mmocr/core/deployment/__init__.py
rename tools/deployment/deploy_helper.py => mmocr/core/deployment/deploy_utils.py (100%)
create mode 100644 tests/test_core/test_deploy_utils.py
create mode 100644 tools/deployment/deploy_test.py
diff --git a/docs/deployment.md b/docs/deployment.md
index 040dd34f..1c16b136 100644
--- a/docs/deployment.md
+++ b/docs/deployment.md
@@ -48,7 +48,7 @@ The table below lists the models that are guaranteed to be exportable to ONNX an
| PSENet | [psenet_r50_fpnf_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | Y | Y | |
| PANet | [panet_r18_fpem_ffm_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | Y | Y | |
| PANet | [panet_r18_fpem_ffm_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | Y | Y | |
-| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | |
+| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | CRNN only accepts input with height 32 |
**Notes**:
@@ -112,3 +112,188 @@ The table below lists the models that are guaranteed to be exportable to TensorR
- *All models above are tested with Pytorch==1.8.1, onnxruntime==1.7.0 and tensorrt==7.2.1.6*
- If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to solve them by yourself.
- Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmocr`.
+
+
+### Evaluate ONNX and TensorRT Models (experimental)
+
+We provide methods to evaluate TensorRT and ONNX models in `tools/deployment/deploy_test.py`.
+
+#### Prerequisite
+To evaluate ONNX and TensorRT models, onnx, onnxruntime and TensorRT should be installed first. Install `mmcv-full` with ONNXRuntime custom ops and TensorRT plugins follow [ONNXRuntime in mmcv](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) and [TensorRT plugin in mmcv](https://github.com/open-mmlab/mmcv/blob/master/docs/tensorrt_plugin.md).
+
+#### Usage
+
+```bash
+python tools/deploy_test.py \
+ ${CONFIG_FILE} \
+ ${MODEL_PATH} \
+ ${MODEL_TYPE} \
+ ${BACKEND} \
+ --eval ${METRICS} \
+ --device ${DEVICE}
+```
+
+#### Description of all arguments
+
+- `model_config`: The path of a model config file.
+- `model_file`: The path of a TensorRT or an ONNX model file.
+- `model_type`: Detection or recognition model to deploy. Choose `recog` or `det`.
+- `backend`: The backend for testing, choose TensorRT or ONNXRuntime.
+- `--eval`: The evaluation metrics. `acc` for recognition models, `hmean-iou` for detection models.
+- `--device`: Device for evaluation, `cuda:0` as default.
+
+#### Results and Models
+
+
+
+
+
+ Model |
+ Config |
+ Dataset |
+ Metric |
+ PyTorch |
+ ONNX Runtime |
+ TensorRT FP32 |
+ TensorRT FP16 |
+
+
+
+
+ DBNet |
+ dbnet_r18_fpnc_1200e_icdar2015.py
|
+ icdar2015 |
+ Recall
|
+ 0.731 |
+ 0.731 |
+ 0.678 |
+ 0.679 |
+
+
+ Precision |
+ 0.871 |
+ 0.871 |
+ 0.844 |
+ 0.842 |
+
+
+ Hmean |
+ 0.795 |
+ 0.795 |
+ 0.752 |
+ 0.752 |
+
+
+ DBNet* |
+ dbnet_r18_fpnc_1200e_icdar2015.py
|
+ icdar2015 |
+ Recall
|
+ 0.720 |
+ 0.720 |
+ 0.720 |
+ 0.718 |
+
+
+ Precision |
+ 0.868 |
+ 0.868 |
+ 0.868 |
+ 0.868 |
+
+
+ Hmean |
+ 0.787 |
+ 0.787 |
+ 0.787 |
+ 0.786 |
+
+
+ PSENet |
+ psenet_r50_fpnf_600e_icdar2015.py
|
+ icdar2015 |
+ Recall
|
+ 0.753 |
+ 0.753 |
+ 0.753 |
+ 0.752 |
+
+
+ Precision |
+ 0.867 |
+ 0.867 |
+ 0.867 |
+ 0.867 |
+
+
+ Hmean |
+ 0.806 |
+ 0.806 |
+ 0.806 |
+ 0.805 |
+
+
+ PANet |
+ panet_r18_fpem_ffm_600e_icdar2015.py
|
+ icdar2015 |
+ Recall
|
+ 0.740 |
+ 0.740 |
+ 0.687 |
+ N/A |
+
+
+ Precision |
+ 0.860 |
+ 0.860 |
+ 0.815 |
+ N/A |
+
+
+ Hmean |
+ 0.796 |
+ 0.796 |
+ 0.746 |
+ N/A |
+
+
+ PANet* |
+ panet_r18_fpem_ffm_600e_icdar2015.py
|
+ icdar2015 |
+ Recall
|
+ 0.736 |
+ 0.736 |
+ 0.736 |
+ N/A |
+
+
+ Precision |
+ 0.857 |
+ 0.857 |
+ 0.857 |
+ N/A |
+
+
+ Hmean |
+ 0.792 |
+ 0.792 |
+ 0.792 |
+ N/A |
+
+
+ CRNN |
+ crnn_academic_dataset.py
|
+ IIIT5K |
+ Acc |
+ 0.806 |
+ 0.806 |
+ 0.806 |
+ 0.806 |
+
+
+
+
+**Notes**:
+- TensorRT upsampling operation is a little different from pytorch. For DBNet and PANet, we suggest replacing upsampling operations with neast mode to operations with bilinear mode. [Here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpem_ffm.py#L33) for PANet, [here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpn_cat.py#L111) and [here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpn_cat.py#L121) for DBNet. As is shown in the above table, networks with tag * means the upsampling mode is changed.
+- Note that, changing upsampling mode reduces less performance compared with using nearst mode. However, the weights of networks are trained through nearst mode. To persue best performance, using bilinear mode for both training and TensorRT deployment is recommanded.
+- All ONNX and TensorRT models are evaluated with dynamic shape on the datasets and images are preprocessed according to the original config file.
+- This tool is still experimental, and we only support `detection` and `recognition` for now.
diff --git a/mmocr/core/deployment/__init__.py b/mmocr/core/deployment/__init__.py
new file mode 100644
index 00000000..643d20a8
--- /dev/null
+++ b/mmocr/core/deployment/__init__.py
@@ -0,0 +1,7 @@
+from .deploy_utils import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
+ TensorRTDetector, TensorRTRecognizer)
+
+__all__ = [
+ 'ONNXRuntimeRecognizer', 'ONNXRuntimeDetector', 'TensorRTDetector',
+ 'TensorRTRecognizer'
+]
diff --git a/tools/deployment/deploy_helper.py b/mmocr/core/deployment/deploy_utils.py
similarity index 100%
rename from tools/deployment/deploy_helper.py
rename to mmocr/core/deployment/deploy_utils.py
diff --git a/setup.cfg b/setup.cfg
index 21cbc39e..c368a1d7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,7 +20,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmocr
-known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
+known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
diff --git a/tests/test_core/test_deploy_utils.py b/tests/test_core/test_deploy_utils.py
new file mode 100644
index 00000000..748a581c
--- /dev/null
+++ b/tests/test_core/test_deploy_utils.py
@@ -0,0 +1,220 @@
+from functools import partial
+
+import mmcv
+import numpy as np
+import pytest
+import torch
+from mmdet.models import build_detector
+from packaging import version
+
+from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
+ TensorRTDetector, TensorRTRecognizer)
+
+
+@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
+@pytest.mark.skipif(
+ version.parse(torch.__version__) < version.parse('1.4.0'),
+ reason='skip if torch=1.3.x')
+@pytest.mark.skipif(
+ not torch.cuda.is_available(), reason='skip if on cpu device')
+@pytest.mark.importorskip('onnxruntime')
+@pytest.mark.importorskip('tensorrt')
+@pytest.mark.importorskip('mmcv.tensorrt')
+def test_detector_wraper():
+ import onnxruntime as ort # noqa: F401
+ import tensorrt as trt
+ from mmcv.tensorrt import (onnx2trt, save_trt_engine)
+
+ onnx_path = 'tmp.onnx'
+ cfg = dict(
+ model=dict(
+ type='DBNet',
+ pretrained='torchvision://resnet18',
+ backbone=dict(
+ type='ResNet',
+ depth=18,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ style='caffe'),
+ neck=dict(
+ type='FPNC',
+ in_channels=[64, 128, 256, 512],
+ lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ text_repr_type='quad',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0,
+ bbce_loss=True)),
+ train_cfg=None,
+ test_cfg=None))
+
+ cfg = mmcv.Config(cfg)
+
+ pytorch_model = build_detector(cfg.model, None, None)
+
+ # prepare data
+ inputs = torch.rand(1, 3, 224, 224)
+ img_metas = [{
+ 'img_shape': [1, 3, 224, 224],
+ 'ori_shape': [1, 3, 224, 224],
+ 'pad_shape': [1, 3, 224, 224],
+ 'filename': None,
+ 'scale_factor': np.array([1, 1, 1, 1])
+ }]
+
+ pytorch_model.forward = pytorch_model.forward_dummy
+ with torch.no_grad():
+ torch.onnx.export(
+ pytorch_model,
+ inputs,
+ onnx_path,
+ input_names=['input'],
+ output_names=['output'],
+ export_params=True,
+ keep_initializers_as_inputs=False,
+ verbose=False,
+ opset_version=11)
+
+ # TensorRT part
+ def get_GiB(x: int):
+ """return x GiB."""
+ return x * (1 << 30)
+
+ trt_path = onnx_path.replace('.onnx', '.trt')
+ min_shape = [1, 3, 224, 224]
+ max_shape = [1, 3, 224, 224]
+ # create trt engine and wraper
+ opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
+ max_workspace_size = get_GiB(1)
+ trt_engine = onnx2trt(
+ onnx_path,
+ opt_shape_dict,
+ log_level=trt.Logger.ERROR,
+ fp16_mode=False,
+ max_workspace_size=max_workspace_size)
+ save_trt_engine(trt_engine, trt_path)
+ print(f'Successfully created TensorRT engine: {trt_path}')
+
+ wrap_onnx = ONNXRuntimeDetector(onnx_path, cfg, 0)
+ wrap_trt = TensorRTDetector(trt_path, cfg, 0)
+ # os.remove(onnx_path)
+ assert isinstance(wrap_onnx, ONNXRuntimeDetector)
+ assert isinstance(wrap_trt, TensorRTDetector)
+
+ with torch.no_grad():
+ onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
+ trt_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
+
+ assert isinstance(onnx_outputs[0], dict)
+ assert isinstance(trt_outputs[0], dict)
+ assert 'boundary_result' in onnx_outputs[0]
+ assert 'boundary_result' in trt_outputs[0]
+
+
+@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
+@pytest.mark.skipif(
+ version.parse(torch.__version__) < version.parse('1.4.0'),
+ reason='skip if torch=1.3.x')
+@pytest.mark.skipif(
+ not torch.cuda.is_available(), reason='skip if on cpu device')
+@pytest.mark.importorskip('onnxruntime')
+@pytest.mark.importorskip('tensorrt')
+@pytest.mark.importorskip('mmcv.tensorrt')
+def test_recognizer_wraper():
+ import onnxruntime as ort # noqa: F401
+ import tensorrt as trt
+ from mmcv.tensorrt import (onnx2trt, save_trt_engine)
+
+ onnx_path = 'tmp.onnx'
+ cfg = dict(
+ label_convertor=dict(
+ type='CTCConvertor',
+ dict_type='DICT36',
+ with_unknown=False,
+ lower=True),
+ model=dict(
+ type='CRNNNet',
+ preprocessor=None,
+ backbone=dict(
+ type='VeryDeepVgg', leaky_relu=False, input_channels=1),
+ encoder=None,
+ decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+ loss=dict(type='CTCLoss'),
+ label_convertor=dict(
+ type='CTCConvertor',
+ dict_type='DICT36',
+ with_unknown=False,
+ lower=True),
+ pretrained=None),
+ train_cfg=None,
+ test_cfg=None)
+
+ cfg = mmcv.Config(cfg)
+
+ pytorch_model = build_detector(cfg.model, None, None)
+
+ # prepare data
+ inputs = torch.rand(1, 1, 32, 32)
+ img_metas = [{
+ 'img_shape': [1, 1, 32, 32],
+ 'ori_shape': [1, 1, 32, 32],
+ 'pad_shape': [1, 1, 32, 32],
+ 'filename': None,
+ 'scale_factor': np.array([1, 1, 1, 1])
+ }]
+
+ pytorch_model.forward = partial(
+ pytorch_model.forward,
+ img_metas=img_metas,
+ return_loss=False,
+ rescale=True)
+ with torch.no_grad():
+ torch.onnx.export(
+ pytorch_model,
+ inputs,
+ onnx_path,
+ input_names=['input'],
+ output_names=['output'],
+ export_params=True,
+ keep_initializers_as_inputs=False,
+ verbose=False,
+ opset_version=11)
+
+ # TensorRT part
+ def get_GiB(x: int):
+ """return x GiB."""
+ return x * (1 << 30)
+
+ trt_path = onnx_path.replace('.onnx', '.trt')
+ min_shape = [1, 1, 32, 32]
+ max_shape = [1, 1, 32, 32]
+ # create trt engine and wraper
+ opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
+ max_workspace_size = get_GiB(1)
+ trt_engine = onnx2trt(
+ onnx_path,
+ opt_shape_dict,
+ log_level=trt.Logger.ERROR,
+ fp16_mode=False,
+ max_workspace_size=max_workspace_size)
+ save_trt_engine(trt_engine, trt_path)
+ print(f'Successfully created TensorRT engine: {trt_path}')
+
+ wrap_onnx = ONNXRuntimeRecognizer(onnx_path, cfg, 0)
+ wrap_trt = TensorRTRecognizer(trt_path, cfg, 0)
+ # os.remove(onnx_path)
+ assert isinstance(wrap_onnx, ONNXRuntimeRecognizer)
+ assert isinstance(wrap_trt, TensorRTRecognizer)
+
+ with torch.no_grad():
+ onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
+ trt_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
+
+ assert isinstance(onnx_outputs[0], dict)
+ assert isinstance(trt_outputs[0], dict)
+ assert 'text' in onnx_outputs[0]
+ assert 'text' in trt_outputs[0]
diff --git a/tools/deployment/deploy_test.py b/tools/deployment/deploy_test.py
new file mode 100644
index 00000000..ca387f6d
--- /dev/null
+++ b/tools/deployment/deploy_test.py
@@ -0,0 +1,92 @@
+import argparse
+
+from mmcv import Config
+from mmcv.parallel import MMDataParallel
+from mmcv.runner import get_dist_info
+from mmdet.apis import single_gpu_test
+
+from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
+ TensorRTDetector, TensorRTRecognizer)
+from mmocr.datasets import build_dataloader, build_dataset
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='MMOCR test (and eval) a onnx or tensorrt model.')
+ parser.add_argument('model_config', type=str, help='Config file.')
+ parser.add_argument(
+ 'model_file', type=str, help='Input file name for evaluation.')
+ parser.add_argument(
+ 'model_type',
+ type=str,
+ help='Detection or recognition model to deploy.',
+ choices=['recog', 'det'])
+ parser.add_argument(
+ 'backend',
+ type=str,
+ help='Which backend to test, TensorRT or ONNXRuntime.',
+ choices=['TensorRT', 'ONNXRuntime'])
+ parser.add_argument(
+ '--eval',
+ type=str,
+ nargs='+',
+ help='The evaluation metrics, which depends on the dataset, e.g.,'
+ '"bbox", "seg", "proposal" for COCO, and "mAP", "recall" for'
+ 'PASCAL VOC.')
+ parser.add_argument(
+ '--device', default='cuda:0', help='Device used for inference.')
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+ if args.device == 'cpu':
+ args.device = None
+
+ cfg = Config.fromfile(args.model_config)
+
+ # build the model
+ if args.model_type == 'det':
+ if args.backend == 'TensorRT':
+ model = TensorRTDetector(args.model_file, cfg, 0)
+ else:
+ model = ONNXRuntimeDetector(args.model_file, cfg, 0)
+ else:
+ if args.backend == 'TensorRT':
+ model = TensorRTRecognizer(args.model_file, cfg, 0)
+ else:
+ model = ONNXRuntimeRecognizer(args.model_file, cfg, 0)
+
+ # build the dataloader
+ samples_per_gpu = 1
+ dataset = build_dataset(cfg.data.test)
+ data_loader = build_dataloader(
+ dataset,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=False,
+ shuffle=False)
+
+ model = MMDataParallel(model, device_ids=[0])
+ outputs = single_gpu_test(model, data_loader)
+
+ rank, _ = get_dist_info()
+ if rank == 0:
+ kwargs = {}
+ if args.eval:
+ eval_kwargs = cfg.get('evaluation', {}).copy()
+ # hard-code way to remove EvalHook args
+ for key in [
+ 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
+ 'rule'
+ ]:
+ eval_kwargs.pop(key, None)
+ eval_kwargs.update(dict(metric=args.eval, **kwargs))
+ print(dataset.evaluate(outputs, **eval_kwargs))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py
index bac2d941..2cc94319 100644
--- a/tools/deployment/onnx2tensorrt.py
+++ b/tools/deployment/onnx2tensorrt.py
@@ -11,11 +11,9 @@ from mmcv.parallel import collate
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
-from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
- ONNXRuntimeRecognizer,
- TensorRTDetector,
- TensorRTRecognizer)
+from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
+ TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py
index 0b2a7142..f0c6fc8c 100644
--- a/tools/deployment/pytorch2onnx.py
+++ b/tools/deployment/pytorch2onnx.py
@@ -9,10 +9,9 @@ from mmcv.parallel import collate
from mmdet.apis import init_detector
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
-from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
- ONNXRuntimeRecognizer)
from torch import nn
+from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401