mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Add deployment evaluation (#291)
* add deployment evaluation
* fix lint
* remove cpu unit tests for trt and onnx
* use pytest.mark to skip cpu unit test
* move to mmocr/core
* emm... renamed to wrappers
* renamed to deploy_utils
* renamed unit test to test_deploy_utils
* fix lint
* using pytest.mark.importorskip
parent d57f279083
commit f1b003ddb1
@@ -48,7 +48,7 @@ The table below lists the models that are guaranteed to be exportable to ONNX an
| PSENet | [psenet_r50_fpnf_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | Y | Y | |
| PANet | [panet_r18_fpem_ffm_600e_ctw1500.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | Y | Y | |
| PANet | [panet_r18_fpem_ffm_600e_icdar2015.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | Y | Y | |
-| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | |
+| CRNN | [crnn_academic_dataset.py](https://github.com/open-mmlab/mmocr/blob/main/configs/textrecog/crnn/crnn_academic_dataset.py) | Y | Y | CRNN only accepts input with height 32 |

**Notes**:
@@ -112,3 +112,188 @@ The table below lists the models that are guaranteed to be exportable to TensorR

- *All models above are tested with PyTorch==1.8.1, onnxruntime==1.7.0 and tensorrt==7.2.1.6*
- If you meet any problem with the listed models above, please create an issue and it will be taken care of soon. For models not included in the list, please try to resolve the problems by yourself.
- Because this feature is experimental and may change quickly, please always try with the latest `mmcv` and `mmocr`.

### Evaluate ONNX and TensorRT Models (experimental)

We provide methods to evaluate TensorRT and ONNX models in `tools/deployment/deploy_test.py`.

#### Prerequisite

To evaluate ONNX and TensorRT models, `onnx`, `onnxruntime` and TensorRT should be installed first. Then install `mmcv-full` with the ONNXRuntime custom ops and TensorRT plugins, following [ONNXRuntime in mmcv](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) and [TensorRT plugin in mmcv](https://github.com/open-mmlab/mmcv/blob/master/docs/tensorrt_plugin.md).
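
For reference, the sketch below condenses the two linked guides into one sequence. It is illustrative only: the version pins, library paths, and the `ONNXRUNTIME_DIR`/`TENSORRT_DIR` and `MMCV_WITH_ORT`/`MMCV_WITH_TRT` variables are assumptions taken from the mmcv documentation at the time of writing, which should be treated as authoritative.

```bash
# Install ONNX and the ONNX Runtime python package
# (1.7.0 matches the version the models in this doc were tested with).
pip install onnx onnxruntime==1.7.0

# Point the build at the local ONNX Runtime and TensorRT libraries.
export ONNXRUNTIME_DIR=/path/to/onnxruntime  # contains include/ and lib/
export TENSORRT_DIR=/path/to/TensorRT        # contains include/ and lib/
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$TENSORRT_DIR/lib:$LD_LIBRARY_PATH

# Build mmcv-full from source with the custom ops and plugins compiled in.
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
MMCV_WITH_OPS=1 MMCV_WITH_ORT=1 MMCV_WITH_TRT=1 pip install -e .
```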

#### Usage

```bash
python tools/deployment/deploy_test.py \
    ${CONFIG_FILE} \
    ${MODEL_PATH} \
    ${MODEL_TYPE} \
    ${BACKEND} \
    --eval ${METRICS} \
    --device ${DEVICE}
```

#### Description of all arguments

- `model_config`: The path of a model config file.
- `model_file`: The path of a TensorRT or an ONNX model file.
- `model_type`: The type of the model to deploy. Choose `det` for detection or `recog` for recognition.
- `backend`: The backend for testing. Choose `TensorRT` or `ONNXRuntime`.
- `--eval`: The evaluation metrics: `acc` for recognition models, `hmean-iou` for detection models.
- `--device`: The device used for evaluation. Defaults to `cuda:0`.
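
For example, evaluating a DBNet model exported to ONNX could look as follows, where `dbnet.onnx` is a placeholder for your exported model file:

```bash
python tools/deployment/deploy_test.py \
    configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py \
    dbnet.onnx \
    det \
    ONNXRuntime \
    --eval hmean-iou \
    --device cuda:0
```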

#### Results and Models

| Model | Config | Dataset | Metric | PyTorch | ONNX Runtime | TensorRT FP32 | TensorRT FP16 |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| DBNet | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Recall | 0.731 | 0.731 | 0.678 | 0.679 |
| DBNet | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Precision | 0.871 | 0.871 | 0.844 | 0.842 |
| DBNet | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Hmean | 0.795 | 0.795 | 0.752 | 0.752 |
| DBNet* | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Recall | 0.720 | 0.720 | 0.720 | 0.718 |
| DBNet* | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Precision | 0.868 | 0.868 | 0.868 | 0.868 |
| DBNet* | dbnet_r18_fpnc_1200e_icdar2015.py | icdar2015 | Hmean | 0.787 | 0.787 | 0.787 | 0.786 |
| PSENet | psenet_r50_fpnf_600e_icdar2015.py | icdar2015 | Recall | 0.753 | 0.753 | 0.753 | 0.752 |
| PSENet | psenet_r50_fpnf_600e_icdar2015.py | icdar2015 | Precision | 0.867 | 0.867 | 0.867 | 0.867 |
| PSENet | psenet_r50_fpnf_600e_icdar2015.py | icdar2015 | Hmean | 0.806 | 0.806 | 0.806 | 0.805 |
| PANet | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Recall | 0.740 | 0.740 | 0.687 | N/A |
| PANet | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Precision | 0.860 | 0.860 | 0.815 | N/A |
| PANet | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Hmean | 0.796 | 0.796 | 0.746 | N/A |
| PANet* | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Recall | 0.736 | 0.736 | 0.736 | N/A |
| PANet* | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Precision | 0.857 | 0.857 | 0.857 | N/A |
| PANet* | panet_r18_fpem_ffm_600e_icdar2015.py | icdar2015 | Hmean | 0.792 | 0.792 | 0.792 | N/A |
| CRNN | crnn_academic_dataset.py | IIIT5K | Acc | 0.806 | 0.806 | 0.806 | 0.806 |

**Notes**:
- The TensorRT upsampling operation differs slightly from PyTorch's. For DBNet and PANet, we suggest replacing upsampling operations in `nearest` mode with `bilinear` mode: see [here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpem_ffm.py#L33) for PANet, and [here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpn_cat.py#L111) and [here](https://github.com/open-mmlab/mmocr/blob/50a25e718a028c8b9d96f497e241767dbe9617d1/mmocr/models/textdet/necks/fpn_cat.py#L121) for DBNet. In the table above, models marked with * have the upsampling mode changed, as sketched after these notes.
- Note that with the bilinear upsampling mode the TensorRT results drop less from the PyTorch baseline than with the nearest mode. However, the released weights were trained with the nearest mode; to pursue the best performance, we recommend using the bilinear mode for both training and TensorRT deployment.
- All ONNX and TensorRT models are evaluated with dynamic input shapes, and images are preprocessed according to the original config files.
- This tool is still experimental. Only `detection` and `recognition` models are supported for now.
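
The change the * rows refer to is of the following kind. This is a hedged sketch only; the actual call sites are the linked lines in `fpem_ffm.py` and `fpn_cat.py`, and the tensor shape below is arbitrary:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 256, 40, 40)

# Before: TensorRT's nearest-neighbour upsample can disagree with PyTorch's.
up_nearest = F.interpolate(x, scale_factor=2, mode='nearest')

# After: bilinear upsampling behaves consistently across both backends.
up_bilinear = F.interpolate(
    x, scale_factor=2, mode='bilinear', align_corners=False)
```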

@@ -0,0 +1,7 @@
from .deploy_utils import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                           TensorRTDetector, TensorRTRecognizer)

__all__ = [
    'ONNXRuntimeRecognizer', 'ONNXRuntimeDetector', 'TensorRTDetector',
    'TensorRTRecognizer'
]
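
These wrappers expose the same test interface as the corresponding PyTorch detectors and recognizers, so they can be constructed directly from a config plus an exported model file, as `deploy_test.py` below does. A minimal sketch, with placeholder file names:

```python
import mmcv

from mmocr.core.deployment import ONNXRuntimeDetector, TensorRTDetector

cfg = mmcv.Config.fromfile(
    'configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py')

# Wrap an exported ONNX model or a serialized TensorRT engine on device 0.
onnx_model = ONNXRuntimeDetector('dbnet.onnx', cfg, 0)
trt_model = TensorRTDetector('dbnet.trt', cfg, 0)
```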

@@ -20,7 +20,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmocr
-known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
+known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

@@ -0,0 +1,220 @@
from functools import partial

import mmcv
import numpy as np
import pytest
import torch
from mmdet.models import build_detector
from packaging import version

from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)


@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('1.4.0'),
    reason='skip if torch=1.3.x')
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='skip if on cpu device')
@pytest.mark.importorskip('onnxruntime')
@pytest.mark.importorskip('tensorrt')
@pytest.mark.importorskip('mmcv.tensorrt')
def test_detector_wrapper():
    import onnxruntime as ort  # noqa: F401
    import tensorrt as trt
    from mmcv.tensorrt import onnx2trt, save_trt_engine

    onnx_path = 'tmp.onnx'
    cfg = dict(
        model=dict(
            type='DBNet',
            pretrained='torchvision://resnet18',
            backbone=dict(
                type='ResNet',
                depth=18,
                num_stages=4,
                out_indices=(0, 1, 2, 3),
                frozen_stages=-1,
                norm_cfg=dict(type='BN', requires_grad=True),
                norm_eval=False,
                style='caffe'),
            neck=dict(
                type='FPNC',
                in_channels=[64, 128, 256, 512],
                lateral_channels=256),
            bbox_head=dict(
                type='DBHead',
                text_repr_type='quad',
                in_channels=256,
                loss=dict(type='DBLoss', alpha=5.0, beta=10.0,
                          bbce_loss=True)),
            train_cfg=None,
            test_cfg=None))

    cfg = mmcv.Config(cfg)

    pytorch_model = build_detector(cfg.model, None, None)

    # prepare data
    inputs = torch.rand(1, 3, 224, 224)
    img_metas = [{
        'img_shape': [1, 3, 224, 224],
        'ori_shape': [1, 3, 224, 224],
        'pad_shape': [1, 3, 224, 224],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]

    # use the dummy forward so onnx.export only needs the image tensor
    pytorch_model.forward = pytorch_model.forward_dummy
    with torch.no_grad():
        torch.onnx.export(
            pytorch_model,
            inputs,
            onnx_path,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=False,
            opset_version=11)

    # TensorRT part
    def get_GiB(x: int):
        """Return x GiB."""
        return x * (1 << 30)

    trt_path = onnx_path.replace('.onnx', '.trt')
    min_shape = [1, 3, 224, 224]
    max_shape = [1, 3, 224, 224]
    # create the TensorRT engine and the wrappers
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(1)
    trt_engine = onnx2trt(
        onnx_path,
        opt_shape_dict,
        log_level=trt.Logger.ERROR,
        fp16_mode=False,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_path)
    print(f'Successfully created TensorRT engine: {trt_path}')

    wrap_onnx = ONNXRuntimeDetector(onnx_path, cfg, 0)
    wrap_trt = TensorRTDetector(trt_path, cfg, 0)
    # os.remove(onnx_path)
    assert isinstance(wrap_onnx, ONNXRuntimeDetector)
    assert isinstance(wrap_trt, TensorRTDetector)

    with torch.no_grad():
        onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
        trt_outputs = wrap_trt.simple_test(inputs, img_metas, rescale=False)

    assert isinstance(onnx_outputs[0], dict)
    assert isinstance(trt_outputs[0], dict)
    assert 'boundary_result' in onnx_outputs[0]
    assert 'boundary_result' in trt_outputs[0]

@pytest.mark.skipif(torch.__version__ == 'parrots', reason='skip parrots.')
@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('1.4.0'),
    reason='skip if torch=1.3.x')
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='skip if on cpu device')
@pytest.mark.importorskip('onnxruntime')
@pytest.mark.importorskip('tensorrt')
@pytest.mark.importorskip('mmcv.tensorrt')
def test_recognizer_wrapper():
    import onnxruntime as ort  # noqa: F401
    import tensorrt as trt
    from mmcv.tensorrt import onnx2trt, save_trt_engine

    onnx_path = 'tmp.onnx'
    cfg = dict(
        label_convertor=dict(
            type='CTCConvertor',
            dict_type='DICT36',
            with_unknown=False,
            lower=True),
        model=dict(
            type='CRNNNet',
            preprocessor=None,
            backbone=dict(
                type='VeryDeepVgg', leaky_relu=False, input_channels=1),
            encoder=None,
            decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
            loss=dict(type='CTCLoss'),
            label_convertor=dict(
                type='CTCConvertor',
                dict_type='DICT36',
                with_unknown=False,
                lower=True),
            pretrained=None),
        train_cfg=None,
        test_cfg=None)

    cfg = mmcv.Config(cfg)

    pytorch_model = build_detector(cfg.model, None, None)

    # prepare data
    inputs = torch.rand(1, 1, 32, 32)
    img_metas = [{
        'img_shape': [1, 1, 32, 32],
        'ori_shape': [1, 1, 32, 32],
        'pad_shape': [1, 1, 32, 32],
        'filename': None,
        'scale_factor': np.array([1, 1, 1, 1])
    }]

    # bind the test-time kwargs so onnx.export can call forward(inputs)
    pytorch_model.forward = partial(
        pytorch_model.forward,
        img_metas=img_metas,
        return_loss=False,
        rescale=True)
    with torch.no_grad():
        torch.onnx.export(
            pytorch_model,
            inputs,
            onnx_path,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=False,
            opset_version=11)

    # TensorRT part
    def get_GiB(x: int):
        """Return x GiB."""
        return x * (1 << 30)

    trt_path = onnx_path.replace('.onnx', '.trt')
    min_shape = [1, 1, 32, 32]
    max_shape = [1, 1, 32, 32]
    # create the TensorRT engine and the wrappers
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(1)
    trt_engine = onnx2trt(
        onnx_path,
        opt_shape_dict,
        log_level=trt.Logger.ERROR,
        fp16_mode=False,
        max_workspace_size=max_workspace_size)
    save_trt_engine(trt_engine, trt_path)
    print(f'Successfully created TensorRT engine: {trt_path}')

    wrap_onnx = ONNXRuntimeRecognizer(onnx_path, cfg, 0)
    wrap_trt = TensorRTRecognizer(trt_path, cfg, 0)
    # os.remove(onnx_path)
    assert isinstance(wrap_onnx, ONNXRuntimeRecognizer)
    assert isinstance(wrap_trt, TensorRTRecognizer)

    with torch.no_grad():
        onnx_outputs = wrap_onnx.simple_test(inputs, img_metas, rescale=False)
        trt_outputs = wrap_trt.simple_test(inputs, img_metas, rescale=False)

    assert isinstance(onnx_outputs[0], dict)
    assert isinstance(trt_outputs[0], dict)
    assert 'text' in onnx_outputs[0]
    assert 'text' in trt_outputs[0]

@@ -0,0 +1,92 @@
import argparse

from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmdet.apis import single_gpu_test

from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets import build_dataloader, build_dataset


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMOCR test (and eval) an ONNX or TensorRT model.')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_file', type=str, help='Input file name for evaluation.')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument(
        'backend',
        type=str,
        help='Which backend to test, TensorRT or ONNXRuntime.',
        choices=['TensorRT', 'ONNXRuntime'])
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='The evaluation metrics, which depends on the model type, '
        'e.g., "acc" for recognition models and "hmean-iou" for '
        'detection models.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    if args.device == 'cpu':
        args.device = None

    cfg = Config.fromfile(args.model_config)

    # build the model
    if args.model_type == 'det':
        if args.backend == 'TensorRT':
            model = TensorRTDetector(args.model_file, cfg, 0)
        else:
            model = ONNXRuntimeDetector(args.model_file, cfg, 0)
    else:
        if args.backend == 'TensorRT':
            model = TensorRTRecognizer(args.model_file, cfg, 0)
        else:
            model = ONNXRuntimeRecognizer(args.model_file, cfg, 0)

    # build the dataloader
    samples_per_gpu = 1
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader)

    rank, _ = get_dist_info()
    if rank == 0:
        kwargs = {}
        if args.eval:
            eval_kwargs = cfg.get('evaluation', {}).copy()
            # hard-coded way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()
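
For instance, the recognition path of this script could be exercised as follows, where `crnn.trt` is a placeholder for a serialized TensorRT engine:

```bash
python tools/deployment/deploy_test.py \
    configs/textrecog/crnn/crnn_academic_dataset.py \
    crnn.trt \
    recog \
    TensorRT \
    --eval acc
```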

@@ -11,11 +11,9 @@ from mmcv.parallel import collate
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
-from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
-                                            ONNXRuntimeRecognizer,
-                                            TensorRTDetector,
-                                            TensorRTRecognizer)

+from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
+                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401

@@ -9,10 +9,9 @@ from mmcv.parallel import collate
from mmdet.apis import init_detector
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
-from tools.deployment.deploy_helper import (ONNXRuntimeDetector,
-                                            ONNXRuntimeRecognizer)
from torch import nn

+from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401