mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[UnitTest] mmocr unittest (#130)
* WIP test_mmocr 8 out of 20 * test_mmocr_export * test mmocr apis * add test data * add mmocr model unittest 5 passed 1 failed * finish mmocr unittest * fix lint * fix yapf * fix isort * fix flake8 * fix docformatter * fix docformatter * try to fix unittest after merge master * Change test.py for backend.DEFAULT * fix flake8 * fix ut * fix yapf * fix ut build * fix yapf * fix mmocr_export ut * fix mmocr_apis ort not cuda * remove explicit .forward * remove backendwrapper * simplify the crnn and dbnet config * simplify instance_test.json * add another case of decoder * increase coverage of test_mmocr_models base_recognizer * improve coverage * improve encode_decoder coverage * reply for grimoire codereview * what if not check cuda? * remove image data * reply to runningleon code review * fix fpnc * fix lint * try to fix CI UT error * fix fpnc with and wo custom ops * fix yapf * skip fpnc when cuda is not ready in ci * reply for code review * reply for code review * fix yapf * reply for code review * fix yapf * fix conflict * remove unmatched data path * remove unnecessary comments
This commit is contained in:
parent
c5a87fb1bc
commit
9e227b228b
@ -1,6 +1,7 @@
|
|||||||
from typing import Iterable, Sequence, Union
|
from typing import Iterable, Sequence, Union
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmdet.models.builder import DETECTORS
|
from mmdet.models.builder import DETECTORS
|
||||||
from mmocr.datasets import DATASETS
|
from mmocr.datasets import DATASETS
|
||||||
@ -60,6 +61,8 @@ class DeployBaseTextDetector(TextDetectorMixin, SingleStageTextDetector):
|
|||||||
list: A list of predictions.
|
list: A list of predictions.
|
||||||
"""
|
"""
|
||||||
pred = self.forward_of_backend(img, img_metas, *args, **kwargs)
|
pred = self.forward_of_backend(img, img_metas, *args, **kwargs)
|
||||||
|
if isinstance(pred, np.ndarray):
|
||||||
|
pred = torch.from_numpy(pred[0])
|
||||||
if len(img_metas) > 1:
|
if len(img_metas) > 1:
|
||||||
boundaries = [
|
boundaries = [
|
||||||
self.bbox_head.get_boundary(
|
self.bbox_head.get_boundary(
|
||||||
@ -187,7 +190,6 @@ class ONNXRuntimeDetector(DeployBaseTextDetector):
|
|||||||
np.ndarray: Prediction of input model.
|
np.ndarray: Prediction of input model.
|
||||||
"""
|
"""
|
||||||
onnx_pred = self.model({'input': img})
|
onnx_pred = self.model({'input': img})
|
||||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
|
||||||
return onnx_pred
|
return onnx_pred
|
||||||
|
|
||||||
|
|
||||||
@ -223,7 +225,6 @@ class ONNXRuntimeRecognizer(DeployBaseRecognizer):
|
|||||||
np.ndarray: Prediction of input model.
|
np.ndarray: Prediction of input model.
|
||||||
"""
|
"""
|
||||||
onnx_pred = self.model({'input': img})
|
onnx_pred = self.model({'input': img})
|
||||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
|
||||||
return onnx_pred
|
return onnx_pred
|
||||||
|
|
||||||
|
|
||||||
@ -403,8 +404,8 @@ class PPLDetector(DeployBaseTextDetector):
|
|||||||
"""
|
"""
|
||||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||||
ppl_pred = self.model({'input': img})
|
ppl_pred = self.model({'input': img})
|
||||||
|
if isinstance(ppl_pred[0], np.ndarray):
|
||||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||||
return ppl_pred
|
return ppl_pred
|
||||||
|
|
||||||
|
|
||||||
@ -442,7 +443,8 @@ class PPLRecognizer(DeployBaseRecognizer):
|
|||||||
"""
|
"""
|
||||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||||
ppl_pred = self.model({'input': img})[0]
|
ppl_pred = self.model({'input': img})[0]
|
||||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
if isinstance(ppl_pred[0], np.ndarray):
|
||||||
|
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||||
return ppl_pred
|
return ppl_pred
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,6 +295,8 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
|||||||
backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path)
|
backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path)
|
||||||
|
|
||||||
backend_feats = flatten_model_inputs
|
backend_feats = flatten_model_inputs
|
||||||
|
elif backend == Backend.DEFAULT:
|
||||||
|
return ctx_outputs, False
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f'Unimplemented backend type: {backend.value}')
|
f'Unimplemented backend type: {backend.value}')
|
||||||
|
15
tests/test_mmocr/data/config/crnn.py
Executable file
15
tests/test_mmocr/data/config/crnn.py
Executable file
@ -0,0 +1,15 @@
|
|||||||
|
_base_ = []

# Label convertor shared with the recognizer definition below.
label_convertor = {
    'type': 'CTCConvertor',
    'dict_type': 'DICT36',
    'with_unknown': False,
    'lower': True,
}

# CRNN recognizer: VeryDeepVgg backbone, no encoder, CRNNDecoder and CTC loss.
model = {
    'type': 'CRNNNet',
    'preprocessor': None,
    'backbone': {
        'type': 'VeryDeepVgg',
        'leaky_relu': False,
        'input_channels': 1,
    },
    'encoder': None,
    'decoder': {
        'type': 'CRNNDecoder',
        'in_channels': 512,
        'rnn_flag': True,
    },
    'loss': {
        'type': 'CTCLoss',
    },
    'label_convertor': label_convertor,
    'pretrained': None,
}
|
49
tests/test_mmocr/data/config/dbnet.py
Executable file
49
tests/test_mmocr/data/config/dbnet.py
Executable file
@ -0,0 +1,49 @@
|
|||||||
|
# DBNet detector definition used by the mmocr unit tests.
model = {
    'type': 'DBNet',
    'backbone': {
        'type': 'mmdet.ResNet',
        'depth': 18,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'frozen_stages': -1,
        'norm_cfg': {
            'type': 'BN',
            'requires_grad': True
        },
        'init_cfg': {
            'type': 'Pretrained',
            'checkpoint': 'torchvision://resnet18'
        },
        'norm_eval': False,
        'style': 'caffe',
    },
    'neck': {
        'type': 'FPNC',
        'in_channels': [2, 4, 8, 16],
        'lateral_channels': 8
    },
    'bbox_head': {
        'type': 'DBHead',
        'text_repr_type': 'quad',
        'in_channels': 8,
        'loss': {
            'type': 'DBLoss',
            'alpha': 5.0,
            'beta': 10.0,
            'bbce_loss': True
        },
    },
    'train_cfg': None,
    'test_cfg': None,
}

# Dataset location and image normalisation constants.
dataset_type = 'IcdarDataset'
data_root = 'data/icdar2015'
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}

# Test-time pipeline: load the image, then run the MultiScaleFlipAug wrapper
# around resize / normalize / pad / tensor conversion / collect.
test_pipeline = [
    {
        'type': 'LoadImageFromFile',
        'color_type': 'color_ignore_orientation'
    },
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': (128, 64),
        'flip': False,
        'transforms': [
            {
                'type': 'Resize',
                'img_scale': (256, 128),
                'keep_ratio': True
            },
            {
                'type': 'Normalize',
                **img_norm_cfg
            },
            {
                'type': 'Pad',
                'size_divisor': 32
            },
            {
                'type': 'ImageToTensor',
                'keys': ['img']
            },
            {
                'type': 'Collect',
                'keys': ['img']
            },
        ],
    },
]
data = {
    'samples_per_gpu': 16,
    'test_dataloader': {
        'samples_per_gpu': 1
    },
    'test': {
        'type': dataset_type,
        'ann_file': data_root + '/instances_test.json',
        'img_prefix': data_root + '/imgs',
        'pipeline': test_pipeline,
    },
}
evaluation = {'interval': 100, 'metric': 'hmean-iou'}
|
1
tests/test_mmocr/data/icdar2015/instances_test.json
Executable file
1
tests/test_mmocr/data/icdar2015/instances_test.json
Executable file
@ -0,0 +1 @@
|
|||||||
|
{"images": [], "categories": [], "annotations": []}
|
291
tests/test_mmocr/test_mmocr_apis.py
Executable file
291
tests/test_mmocr/test_mmocr_apis.py
Executable file
@ -0,0 +1,291 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import mmdeploy.apis.ncnn as ncnn_apis
|
||||||
|
import mmdeploy.apis.onnxruntime as ort_apis
|
||||||
|
import mmdeploy.apis.ppl as ppl_apis
|
||||||
|
import mmdeploy.apis.tensorrt as trt_apis
|
||||||
|
from mmdeploy.mmocr.apis.inference import get_classes_from_config
|
||||||
|
from mmdeploy.mmocr.apis.visualize import show_result
|
||||||
|
from mmdeploy.utils.test import SwitchBackendWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTDetector():
    """Check that ``TensorRTDetector.forward_of_backend`` yields an output.

    ``TRTWrapper`` is switched through ``SwitchBackendWrapper`` with preset
    ``outputs``, so an empty engine path can be passed to the detector.
    """
    # force add backend wrapper regardless of plugins
    # make sure TensorRTDetector can use TRTWrapper inside itself
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    outputs = {
        'output': torch.rand(1, 3, 64, 64).cuda(),
    }
    with SwitchBackendWrapper(TRTWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmocr.apis.inference import TensorRTDetector
        model_config = mmcv.Config.fromfile(
            'tests/test_mmocr/data/config/dbnet.py')
        trt_detector = TensorRTDetector('', model_config, 0, False)
        # watch from config
        imgs = [torch.rand(1, 3, 64, 64).cuda()]
        img_metas = [[{
            'ori_shape': [64, 64, 3],
            'img_shape': [64, 64, 3],
            'pad_shape': [64, 64, 3],
            'scale_factor': [1., 1., 1., 1.],
        }]]

        results = trt_detector.forward_of_backend(imgs, img_metas)
        # Implicit string concatenation keeps source indentation out of the
        # message (a backslash inside the literal would embed it).
        assert results is not None, 'failed to get output using '\
            'TensorRTDetector'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeDetector():
    """Run ``ONNXRuntimeDetector.forward_of_backend`` with ``ORTWrapper``
    switched through ``SwitchBackendWrapper`` and preset outputs."""
    # Register ORTWrapper on the onnxruntime api module regardless of plugins
    # so that ONNXRuntimeDetector can use it internally.
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    setattr(ort_apis, 'ORTWrapper', ORTWrapper)

    preset_output = torch.rand(1, 3, 64, 64)
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(outputs=preset_output)

        from mmdeploy.mmocr.apis.inference import ONNXRuntimeDetector
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/dbnet.py')
        ort_detector = ONNXRuntimeDetector('', cfg, 0, False)

        imgs = [torch.rand(1, 3, 64, 64)]
        # One independent shape list per key, followed by the scale factor.
        meta = {
            key: [64, 64, 3]
            for key in ('ori_shape', 'img_shape', 'pad_shape')
        }
        meta['scale_factor'] = [1., 1., 1., 1.]
        img_metas = [[meta]]

        results = ort_detector.forward_of_backend(imgs, img_metas)
        assert results is not None, ('failed to get output '
                                     'using ONNXRuntimeDetector')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLDetector():
    """Run ``PPLDetector.forward_of_backend`` with ``PPLWrapper`` switched
    through ``SwitchBackendWrapper`` and preset outputs."""
    # Register PPLWrapper on the ppl api module regardless of plugins so that
    # PPLDetector can use it internally.
    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
    setattr(ppl_apis, 'PPLWrapper', PPLWrapper)

    preset_output = torch.rand(1, 3, 64, 64)
    with SwitchBackendWrapper(PPLWrapper) as wrapper:
        wrapper.set(outputs=preset_output)

        from mmdeploy.mmocr.apis.inference import PPLDetector
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/dbnet.py')
        ppl_detector = PPLDetector('', cfg, 0, False)

        imgs = [torch.rand(1, 3, 64, 64)]
        meta = {}
        for shape_key in ('ori_shape', 'img_shape', 'pad_shape'):
            meta[shape_key] = [64, 64, 3]
        meta['scale_factor'] = [1., 1., 1., 1.]
        img_metas = [[meta]]

        results = ppl_detector.forward_of_backend(imgs, img_metas)
        message = 'failed to get output using PPLDetector'
        assert results is not None, message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNDetector():
    """Run ``NCNNDetector.forward_of_backend`` with ``NCNNWrapper`` switched
    through ``SwitchBackendWrapper`` and preset outputs."""
    # Register NCNNWrapper on the ncnn api module regardless of plugins so
    # that NCNNDetector can use it internally.
    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
    setattr(ncnn_apis, 'NCNNWrapper', NCNNWrapper)

    preset_outputs = dict(output=torch.rand(1, 3, 64, 64))
    with SwitchBackendWrapper(NCNNWrapper) as wrapper:
        wrapper.set(outputs=preset_outputs)

        from mmdeploy.mmocr.apis.inference import NCNNDetector
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/dbnet.py')
        # The first argument is a two-element list of (empty) file paths.
        ncnn_detector = NCNNDetector(['', ''], cfg, 0, False)

        imgs = [torch.rand(1, 3, 64, 64)]
        meta = dict(
            ori_shape=[64, 64, 3],
            img_shape=[64, 64, 3],
            pad_shape=[64, 64, 3],
            scale_factor=[1., 1., 1., 1.])
        img_metas = [[meta]]

        results = ncnn_detector.forward_of_backend(imgs, img_metas)
        message = 'failed to get output using NCNNDetector'
        assert results is not None, message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTRecognizer():
    """Check that ``TensorRTRecognizer.forward_of_backend`` yields an output.

    ``TRTWrapper`` is switched through ``SwitchBackendWrapper`` with preset
    ``outputs``, so an empty engine path can be passed to the recognizer.
    """
    # force add backend wrapper regardless of plugins
    # make sure TensorRTRecognizer can use TRTWrapper inside itself
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    outputs = {
        'output': torch.rand(1, 9, 37).cuda(),
    }
    with SwitchBackendWrapper(TRTWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmocr.apis.inference import TensorRTRecognizer
        model_config = mmcv.Config.fromfile(
            'tests/test_mmocr/data/config/crnn.py')
        trt_recognizer = TensorRTRecognizer('', model_config, 0, False)
        # watch from config
        imgs = [torch.rand(1, 1, 32, 32).cuda()]
        img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]

        results = trt_recognizer.forward_of_backend(imgs, img_metas)
        # Implicit string concatenation keeps source indentation out of the
        # message (a backslash inside the literal would embed it).
        assert results is not None, 'failed to get output using '\
            'TensorRTRecognizer'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeRecognizer():
    """Run ``ONNXRuntimeRecognizer.forward_of_backend`` with ``ORTWrapper``
    switched through ``SwitchBackendWrapper`` and preset outputs."""
    # Register ORTWrapper on the onnxruntime api module regardless of plugins
    # so that ONNXRuntimeRecognizer can use it internally.
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    setattr(ort_apis, 'ORTWrapper', ORTWrapper)

    preset_output = torch.rand(1, 9, 37)
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(outputs=preset_output)

        from mmdeploy.mmocr.apis.inference import ONNXRuntimeRecognizer
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/crnn.py')
        ort_recognizer = ONNXRuntimeRecognizer('', cfg, 0, False)

        # This test feeds a numpy array rather than a tensor.
        imgs = [torch.rand(1, 1, 32, 32).numpy()]
        meta = dict(resize_shape=[32, 32], valid_ratio=1.0)
        img_metas = [[meta]]

        results = ort_recognizer.forward_of_backend(imgs, img_metas)
        assert results is not None, ('failed to get output '
                                     'using ONNXRuntimeRecognizer')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLRecognizer():
    """Run ``PPLRecognizer.forward_of_backend`` with ``PPLWrapper`` switched
    through ``SwitchBackendWrapper`` and preset outputs."""
    # Register PPLWrapper on the ppl api module regardless of plugins so that
    # PPLRecognizer can use it internally.
    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
    setattr(ppl_apis, 'PPLWrapper', PPLWrapper)

    preset_output = torch.rand(1, 9, 37)
    with SwitchBackendWrapper(PPLWrapper) as wrapper:
        wrapper.set(outputs=preset_output)

        from mmdeploy.mmocr.apis.inference import PPLRecognizer
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/crnn.py')
        ppl_recognizer = PPLRecognizer('', cfg, 0, False)

        imgs = [torch.rand(1, 1, 32, 32)]
        meta = dict(resize_shape=[32, 32], valid_ratio=1.0)
        img_metas = [[meta]]

        results = ppl_recognizer.forward_of_backend(imgs, img_metas)
        message = 'failed to get output using PPLRecognizer'
        assert results is not None, message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNRecognizer():
    """Run ``NCNNRecognizer.forward_of_backend`` with ``NCNNWrapper`` switched
    through ``SwitchBackendWrapper`` and preset outputs."""
    # Register NCNNWrapper on the ncnn api module regardless of plugins so
    # that NCNNRecognizer can use it internally.
    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
    setattr(ncnn_apis, 'NCNNWrapper', NCNNWrapper)

    preset_outputs = dict(output=torch.rand(1, 9, 37))
    with SwitchBackendWrapper(NCNNWrapper) as wrapper:
        wrapper.set(outputs=preset_outputs)

        from mmdeploy.mmocr.apis.inference import NCNNRecognizer
        cfg = mmcv.Config.fromfile('tests/test_mmocr/data/config/crnn.py')
        # The first argument is a two-element list of (empty) file paths.
        ncnn_recognizer = NCNNRecognizer(['', ''], cfg, 0, False)

        imgs = [torch.rand(1, 1, 32, 32)]
        meta = dict(resize_shape=[32, 32], valid_ratio=1.0)
        img_metas = [[meta]]

        results = ncnn_recognizer.forward_of_backend(imgs, img_metas)
        message = 'failed to get output using NCNNRecognizer'
        assert results is not None, message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'task, model',
    [('TextDetection', 'tests/test_mmocr/data/config/dbnet.py'),
     ('TextRecognition', 'tests/test_mmocr/data/config/crnn.py')])
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_ocr_processor(task, model):
    """Check that ``init_backend_model`` returns a model for each mmocr task
    when the onnxruntime backend is selected."""
    model_cfg = mmcv.Config.fromfile(model)
    deploy_cfg = mmcv.Config({
        'backend_config': {
            'type': 'onnxruntime'
        },
        'codebase_config': {
            'type': 'mmocr',
            'task': task
        },
    })

    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    setattr(ort_apis, 'ORTWrapper', ORTWrapper)

    # simplify backend inference
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
        from mmdeploy.apis.utils import init_backend_model
        built_model = init_backend_model([''], model_cfg, deploy_cfg, -1)
        assert built_model is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('model', ['tests/test_mmocr/data/config/dbnet.py'])
def test_get_classes_from_config(model):
    """Smoke test: call ``get_classes_from_config`` on the dbnet config path.

    The return value is not inspected; the test only requires that the call
    completes without raising.
    """
    config_path = model
    get_classes_from_config(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'task, model',
    [('TextDetection', 'tests/test_mmocr/data/config/dbnet.py')])
@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_show_result(task, model):
    """Check that ``show_result`` writes the visualisation to the given file.

    The detector is built through ``init_backend_model`` while ``ORTWrapper``
    is switched via ``SwitchBackendWrapper``. The output image is written
    inside a temporary directory so nothing is left behind after the test.
    """
    model_cfg = mmcv.Config.fromfile(model)
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='onnxruntime'),
            codebase_config=dict(type='mmocr', task=task)))

    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
        from mmdeploy.apis.utils import init_backend_model
        detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
        img = np.random.random((64, 64, 3))
        result = {'boundary_result': [[1, 2, 3, 4, 5], [2, 2, 0, 4, 5]]}
        import os.path
        import tempfile

        from mmdeploy.utils.constants import Backend
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Join with a separator: plain concatenation would place the file
            # next to (not inside) the temporary directory and leak it.
            filename = os.path.join(tmp_dir, 'tmp.jpg')
            show_result(
                detector,
                img,
                result,
                filename,
                Backend.ONNXRUNTIME,
                show=False)
            # Must run before the directory is cleaned up on exit.
            assert os.path.exists(filename)
|
135
tests/test_mmocr/test_mmocr_export.py
Normal file
135
tests/test_mmocr/test_mmocr_export.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
|
||||||
|
from mmdeploy.utils.constants import Codebase, Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
def test_create_input(task):
    """Check that ``create_input`` returns inputs for both mmocr tasks.

    A minimal test pipeline and one random uint8 image are built per task and
    passed to ``create_input`` together with the image's height and width.
    """
    if task == Task.TEXT_DETECTION:
        test = dict(
            type='IcdarDataset',
            pipeline=[{
                'type': 'LoadImageFromFile',
                'color_type': 'color_ignore_orientation'
            }, {
                'type':
                'MultiScaleFlipAug',
                'img_scale': (128, 64),
                'flip':
                False,
                'transforms': [
                    {
                        'type': 'Resize',
                        'img_scale': (256, 128),
                        'keep_ratio': True
                    },
                    {
                        'type': 'Normalize',
                        'mean': [123.675, 116.28, 103.53],
                        'std': [58.395, 57.12, 57.375],
                        'to_rgb': True
                    },
                    {
                        'type': 'Pad',
                        'size_divisor': 32
                    },
                    {
                        'type': 'ImageToTensor',
                        'keys': ['img']
                    },
                    {
                        'type': 'Collect',
                        'keys': ['img']
                    },
                ]
            }])
        # randint gives real pixel values; casting rand() output in [0, 1)
        # to uint8 would truncate every pixel to 0.
        imgs = [np.random.randint(0, 256, size=(128, 64, 3), dtype=np.uint8)]
    elif task == Task.TEXT_RECOGNITION:
        test = dict(
            type='UniformConcatDataset',
            pipeline=[
                {
                    'type': 'LoadImageFromFile',
                    'color_type': 'grayscale'
                },
                {
                    'type': 'ResizeOCR',
                    'height': 32,
                    'min_width': 32,
                    'max_width': None,
                    'keep_aspect_ratio': True
                },
                {
                    'type': 'Normalize',
                    'mean': [127],
                    'std': [127]
                },
                {
                    'type': 'DefaultFormatBundle'
                },
                {
                    'type': 'Collect',
                    'keys': ['img'],
                    'meta_keys': ['filename', 'resize_shape', 'valid_ratio']
                },
            ])
        imgs = [np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)]
    data = dict(test=test)
    model_cfg = mmcv.Config(dict(data=data))
    inputs = create_input(
        Codebase.MMOCR,
        task,
        model_cfg,
        imgs,
        input_shape=imgs[0].shape[0:2],
        device='cpu')
    assert inputs is not None, 'Failed to create input'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
def test_build_dataset(task):
    """Check that a dataset and a dataloader can be built for both tasks.

    The recognition case uses an empty temporary annotation file; the file is
    removed again when the test finishes, whether it passes or fails.
    """
    import os
    import tempfile

    # mkstemp returns an open descriptor plus the path. Only the path is
    # used, so the descriptor is closed immediately instead of leaking when
    # a later assertion fails.
    ann_fd, ann_path = tempfile.mkstemp()
    os.close(ann_fd)
    try:
        if task == Task.TEXT_DETECTION:
            data = dict(
                test={
                    'type': 'IcdarDataset',
                    'ann_file':
                    'tests/test_mmocr/data/icdar2015/instances_test.json',
                    'img_prefix': 'tests/test_mmocr/data/icdar2015/imgs',
                    'pipeline': [
                        {
                            'type': 'LoadImageFromFile'
                        },
                    ]
                })
        elif task == Task.TEXT_RECOGNITION:
            data = dict(
                test={
                    'type': 'OCRDataset',
                    'ann_file': ann_path,
                    'img_prefix': '',
                    'loader': {
                        'type': 'HardDiskLoader',
                        'repeat': 1,
                        'parser': {
                            'type': 'LineStrParser',
                            'keys': ['filename', 'text'],
                            'keys_idx': [0, 1],
                            'separator': ' '
                        }
                    },
                    'pipeline': [],
                    'test_mode': True
                })
        dataset_cfg = mmcv.Config(dict(data=data))
        dataset = build_dataset(
            Codebase.MMOCR, dataset_cfg=dataset_cfg, dataset_type='test')
        assert dataset is not None, 'Failed to build dataset'
        dataloader = build_dataloader(Codebase.MMOCR, dataset, 1, 1)
        assert dataloader is not None, 'Failed to build dataloader'
    finally:
        # mkstemp never deletes the file itself.
        os.remove(ann_path)
|
409
tests/test_mmocr/test_mmocr_models.py
Normal file
409
tests/test_mmocr/test_mmocr_models.py
Normal file
@ -0,0 +1,409 @@
|
|||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from mmocr.models.textdet.necks import FPNC
|
||||||
|
|
||||||
|
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
|
||||||
|
get_rewrite_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class FPNCNeckModel(FPNC):
    """FPNC subclass that drives an inner ``FPNC`` neck with all-ones
    feature maps sized after its input."""

    def __init__(self, in_channels, init_cfg=None):
        super().__init__(in_channels, init_cfg=init_cfg)
        # Kept so that forward() knows how many channels each level needs.
        self.in_channels = in_channels
        self.neck = FPNC(in_channels, init_cfg=init_cfg)

    def forward(self, inputs):
        """Create one ``torch.ones`` map per entry of ``in_channels`` with the
        last two dimensions of ``inputs`` and pass them to ``self.neck``."""
        level_feats = []
        for num_channels in self.in_channels:
            level_feats.append(
                torch.ones(1, num_channels, inputs.shape[-2],
                           inputs.shape[-1]))
        return self.neck.forward(level_feats)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bidirectionallstm_model():
    """Return a ``BidirectionalLSTM(32, 16, 16)`` with gradients disabled."""
    from mmocr.models.textrecog.layers.lstm_layer import BidirectionalLSTM

    lstm = BidirectionalLSTM(32, 16, 16)
    return lstm.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_single_stage_text_detector_model():
    """Return a ``SingleStageTextDetector`` (ResNet-18 backbone, FPNC neck,
    DBHead) with gradients disabled."""
    from mmocr.models.textdet import SingleStageTextDetector

    backbone_cfg = {
        'type': 'mmdet.ResNet',
        'depth': 18,
        'num_stages': 4,
        'out_indices': (0, 1, 2, 3),
        'frozen_stages': -1,
        'norm_cfg': {
            'type': 'BN',
            'requires_grad': True
        },
        'init_cfg': {
            'type': 'Pretrained',
            'checkpoint': 'torchvision://resnet18'
        },
        'norm_eval': False,
        'style': 'caffe',
    }
    neck_cfg = {
        'type': 'FPNC',
        'in_channels': [64, 128, 256, 512],
        'lateral_channels': 4,
        'out_channels': 4,
    }
    head_cfg = {
        'type': 'DBHead',
        'text_repr_type': 'quad',
        'in_channels': 16,
        'loss': {
            'type': 'DBLoss',
            'alpha': 5.0,
            'beta': 10.0,
            'bbce_loss': True
        },
    }

    detector = SingleStageTextDetector(backbone_cfg, neck_cfg, head_cfg)
    detector.requires_grad_(False)
    return detector
|
||||||
|
|
||||||
|
|
||||||
|
def get_encode_decode_recognizer_model():
    """Return an ``EncodeDecodeRecognizer`` (VeryDeepVgg backbone, TFEncoder,
    CRNNDecoder, CTC loss) with gradients disabled."""
    from mmocr.models.textrecog import EncodeDecodeRecognizer

    # Only the sub-configs the constructor call receives are spelled out.
    recognizer = EncodeDecodeRecognizer(
        backbone={
            'type': 'VeryDeepVgg',
            'leaky_relu': False,
            'input_channels': 1
        },
        encoder={'type': 'TFEncoder'},
        decoder={
            'type': 'CRNNDecoder',
            'in_channels': 512,
            'rnn_flag': True
        },
        loss={'type': 'CTCLoss'},
        label_convertor={
            'type': 'CTCConvertor',
            'dict_type': 'DICT36',
            'with_unknown': False,
            'lower': True
        })
    recognizer.requires_grad_(False)
    return recognizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_crnn_decoder_model(rnn_flag):
    """Return a ``CRNNDecoder(32, 4)`` with gradients disabled.

    Args:
        rnn_flag: Forwarded unchanged as the decoder's ``rnn_flag`` argument.
    """
    from mmocr.models.textrecog.decoders import CRNNDecoder

    decoder = CRNNDecoder(32, 4, rnn_flag=rnn_flag)
    return decoder.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fpnc_neck_model():
    """Return an ``FPNCNeckModel`` for channels ``[2, 4, 8, 16]`` with
    gradients disabled."""
    level_channels = [2, 4, 8, 16]
    neck_model = FPNCNeckModel(level_channels)
    return neck_model.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_recognizer_model():
    """Return a ``CRNNNet`` (VeryDeepVgg backbone, CRNNDecoder, CTC loss)
    with gradients disabled."""
    from mmocr.models.textrecog import CRNNNet

    # Only the sub-configs the constructor call receives are spelled out.
    recognizer = CRNNNet(
        backbone={
            'type': 'VeryDeepVgg',
            'leaky_relu': False,
            'input_channels': 1
        },
        decoder={
            'type': 'CRNNDecoder',
            'in_channels': 512,
            'rnn_flag': True
        },
        loss={'type': 'CTCLoss'},
        label_convertor={
            'type': 'CTCConvertor',
            'dict_type': 'DICT36',
            'with_unknown': False,
            'lower': True
        })
    recognizer.requires_grad_(False)
    return recognizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('backend_type', ['ncnn'])
def test_bidirectionallstm(backend_type):
    """Test forward rewrite of bidirectionallstm."""
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    bilstm = get_bidirectionallstm_model()
    bilstm.cpu().eval()

    deploy_cfg = mmcv.Config({
        'backend_config': {
            'type': backend_type
        },
        'onnx_config': {
            'input_shape': None
        },
        'codebase_config': {
            'type': 'mmocr',
            'task': 'TextRecognition'
        },
    })

    sequence = torch.rand(1, 1, 32)

    # reference outputs taken directly from the pytorch module
    model_outputs = get_model_outputs(bilstm, 'forward', {'input': sequence})

    # outputs of the onnx model produced after the rewrite
    rewrite_outputs = get_rewrite_outputs(
        wrapped_model=WrapModel(bilstm, 'forward'),
        model_inputs={'input': sequence},
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(model_outputs, rewrite_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_test_of_single_stage_text_detector():
    """Compare ``simple_test`` of the text detector with its raw head output.

    The reference is computed by calling ``extract_feat`` followed by
    ``bbox_head`` on the model; the value under test comes from
    ``get_rewrite_outputs`` on the model wrapped around ``simple_test`` with
    a ``'default'`` backend config.
    """
    detector = get_single_stage_text_detector_model()
    detector.eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='default'),
            onnx_config=dict(input_shape=None),
            codebase_config=dict(type='mmocr', task='TextDetection')))

    image = torch.rand(1, 3, 64, 64)
    # Meta information for the single image in the batch.
    meta = {
        'ori_shape': [64, 64, 3],
        'img_shape': [64, 64, 3],
        'pad_shape': [64, 64, 3],
        'scale_factor': [1., 1., 1., 1],
    }

    # Reference path: features -> bbox head, without going through the wrapper.
    expected_outputs = detector.bbox_head(detector.extract_feat(image))

    # NOTE(review): test_forward_of_fpnc in this file unpacks the return of
    # get_rewrite_outputs as (outputs, flag) while this test iterates it
    # directly -- confirm which usage matches the helper's contract.
    actual_outputs = get_rewrite_outputs(
        wrapped_model=WrapModel(detector, 'simple_test'),
        model_inputs={'img': image, 'img_metas': meta},
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(expected_outputs, actual_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('backend_type', ['ncnn'])
@pytest.mark.parametrize('rnn_flag', [True, False])
def test_crnndecoder(backend_type, rnn_flag):
    """Compare the CRNN decoder ``forward_train`` before and after rewrite.

    The test is skipped when the backend package named by ``backend_type``
    cannot be imported.

    Args:
        backend_type (str): Name of the backend package, also used as the
            ``type`` of ``backend_config``.
        rnn_flag (bool): Forwarded to ``get_crnn_decoder_model``.
    """
    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
    decoder = get_crnn_decoder_model(rnn_flag)
    decoder.cpu().eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(input_shape=None),
            codebase_config=dict(type='mmocr', task='TextRecognition')))

    feature = torch.rand(1, 32, 1, 64)
    # Arguments of forward_train other than the feature map; all are None
    # here.
    unused_kwargs = dict(out_enc=None, targets_dict=None, img_metas=None)

    # Reference results from the unmodified PyTorch module.
    expected_outputs = get_model_outputs(decoder, 'forward_train', {
        'feat': feature,
        **unused_kwargs
    })

    # Results obtained through the rewritten forward_train; the None-valued
    # arguments are bound on the wrapper so only 'feat' is a model input.
    # NOTE(review): test_forward_of_fpnc in this file unpacks the return of
    # get_rewrite_outputs as (outputs, flag) while this test iterates it
    # directly -- confirm which usage matches the helper's contract.
    actual_outputs = get_rewrite_outputs(
        wrapped_model=WrapModel(decoder, 'forward_train', **unused_kwargs),
        model_inputs={'feat': feature},
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(expected_outputs, actual_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'img_metas', [[None], [{
        'resize_shape': [32, 32],
        'valid_ratio': 1.0
    }]])
@pytest.mark.parametrize('is_dynamic', [True, False])
def test_forward_of_base_recognizer(img_metas, is_dynamic):
    """Compare ``forward`` of the base recognizer with its decoder output.

    The reference is computed by running ``extract_feat``, the optional
    encoder and the decoder (``train_mode=False``) by hand; the value under
    test comes from ``get_rewrite_outputs`` on the model wrapped around
    ``forward``.

    Args:
        img_metas (list): One-element list holding either ``None`` or a meta
            dict; its first element is bound on the wrapper.
        is_dynamic (bool): When true, ``dynamic_axes`` for ``input`` and
            ``output`` are added to ``onnx_config``.
    """
    recognizer = get_base_recognizer_model()
    recognizer.eval()

    # Both variants share everything except the optional dynamic axes.
    onnx_cfg = dict(input_shape=None)
    if is_dynamic:
        onnx_cfg['dynamic_axes'] = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='ncnn'),
            onnx_config=onnx_cfg,
            codebase_config=dict(type='mmocr', task='TextRecognition')))

    image = torch.rand(1, 1, 32, 32)

    # Reference path: backbone features -> optional encoder -> decoder.
    feat = recognizer.extract_feat(image)
    out_enc = (
        recognizer.encoder(feat, img_metas)
        if recognizer.encoder is not None else None)
    expected_outputs = recognizer.decoder(
        feat, out_enc, None, img_metas, train_mode=False)

    # NOTE(review): test_forward_of_fpnc in this file unpacks the return of
    # get_rewrite_outputs as (outputs, flag) while this test iterates it
    # directly -- confirm which usage matches the helper's contract.
    actual_outputs = get_rewrite_outputs(
        wrapped_model=WrapModel(
            recognizer, 'forward', img_metas=img_metas[0]),
        model_inputs={'img': image},
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(expected_outputs, actual_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_test_of_encode_decode_recognizer():
    """Compare ``simple_test`` of the recognizer with its decoder output.

    The reference is computed by running ``extract_feat``, the optional
    encoder and the decoder (``train_mode=False``) by hand; the value under
    test comes from ``get_rewrite_outputs`` on the model wrapped around
    ``simple_test`` with a ``'default'`` backend config.
    """
    recognizer = get_encode_decode_recognizer_model()
    recognizer.eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type='default'),
            onnx_config=dict(input_shape=None),
            codebase_config=dict(type='mmocr', task='TextRecognition')))

    image = torch.rand(1, 1, 32, 32)
    img_metas = [{'resize_shape': [32, 32], 'valid_ratio': 1.0}]

    # Reference path: backbone features -> optional encoder -> decoder.
    feat = recognizer.extract_feat(image)
    out_enc = (
        recognizer.encoder(feat, img_metas)
        if recognizer.encoder is not None else None)
    expected_outputs = recognizer.decoder(
        feat, out_enc, None, img_metas, train_mode=False)

    # NOTE(review): test_forward_of_fpnc in this file unpacks the return of
    # get_rewrite_outputs as (outputs, flag) while this test iterates it
    # directly -- confirm which usage matches the helper's contract.
    actual_outputs = get_rewrite_outputs(
        wrapped_model=WrapModel(
            recognizer, 'simple_test', img_metas=img_metas),
        model_inputs={'img': image},
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(expected_outputs, actual_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize('backend_type', ['tensorrt'])
def test_forward_of_fpnc(backend_type):
    """Compare the FPNC neck ``forward`` before and after rewrite.

    Skipped when CUDA is not available. ``get_rewrite_outputs`` is unpacked
    here into the outputs and a flag; when the flag is truthy the outputs
    are indexed by the name ``'output'``, otherwise they are iterated in
    parallel with the reference outputs.

    Args:
        backend_type (str): Used as the ``type`` of ``backend_config``.
    """
    neck = get_fpnc_neck_model()
    neck.eval()

    # min/opt/max are all the same fixed shape; each key gets its own list.
    shape_range = {
        key: [1, 3, 64, 64]
        for key in ('min_shape', 'opt_shape', 'max_shape')
    }
    backend_cfg = dict(
        type=backend_type,
        common_config=dict(max_workspace_size=1 << 30),
        model_inputs=[dict(input_shapes=dict(input=shape_range))])
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=backend_cfg,
            onnx_config=dict(input_shape=[64, 64], output_names=['output']),
            codebase_config=dict(type='mmocr', task='TextDetection')))

    image = torch.rand(1, 3, 64, 64).cuda()

    # Reference results from the unmodified PyTorch module.
    expected_outputs = get_model_outputs(neck, 'forward', {'inputs': image})

    actual_outputs, is_need_name = get_rewrite_outputs(
        wrapped_model=WrapModel(neck, 'forward'),
        model_inputs={'inputs': image},
        deploy_cfg=deploy_cfg)

    if is_need_name:
        # Outputs are addressed by the name given in onnx_config.
        expected = expected_outputs[0].squeeze().cpu().numpy()
        actual = actual_outputs['output'].squeeze().cpu().numpy()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
        return

    for expected, actual in zip(expected_outputs, actual_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze().cpu().numpy()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|
Loading…
x
Reference in New Issue
Block a user