mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[UnitTest] mmocr unittest (#130)
* WIP test_mmocr 8 out of 20 * test_mmocr_export * test mmocr apis * add test data * add mmocr model unittest 5 passed 1 failed * finish mmocr unittest * fix lint * fix yapf * fix isort * fix flake8 * fix docformatter * fix docformatter * try to fix unittest after merge master * Change test.py for backend.DEFAULT * fix flake8 * fix ut * fix yapf * fix ut build * fix yapf * fix mmocr_export ut * fix mmocr_apis ort not cuda * remove explicit .forward * remove backendwrapper * simplify the crnn and dbnet config * simplify instance_test.json * add another case of decoder * increase coverage of test_mmocr_models base_recognizer * improve coverage * improve encode_decoder coverage * reply for grimoire codereview * what if not check cuda? * remove image data * reply to runningleon code review * fix fpnc * fix lint * try to fix CI UT error * fix fpnc with and wo custom ops * fix yapf * skip fpnc when cuda is not ready in ci * reply for code review * reply for code review * fix yapf * reply for code review * fix yapf * fix conflict * remove unmatched data path * remove unnecessary comments
This commit is contained in:
parent
c5a87fb1bc
commit
9e227b228b
@ -1,6 +1,7 @@
|
||||
from typing import Iterable, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmdet.models.builder import DETECTORS
|
||||
from mmocr.datasets import DATASETS
|
||||
@ -60,6 +61,8 @@ class DeployBaseTextDetector(TextDetectorMixin, SingleStageTextDetector):
|
||||
list: A list of predictions.
|
||||
"""
|
||||
pred = self.forward_of_backend(img, img_metas, *args, **kwargs)
|
||||
if isinstance(pred, np.ndarray):
|
||||
pred = torch.from_numpy(pred[0])
|
||||
if len(img_metas) > 1:
|
||||
boundaries = [
|
||||
self.bbox_head.get_boundary(
|
||||
@ -187,7 +190,6 @@ class ONNXRuntimeDetector(DeployBaseTextDetector):
|
||||
np.ndarray: Prediction of input model.
|
||||
"""
|
||||
onnx_pred = self.model({'input': img})
|
||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
||||
return onnx_pred
|
||||
|
||||
|
||||
@ -223,7 +225,6 @@ class ONNXRuntimeRecognizer(DeployBaseRecognizer):
|
||||
np.ndarray: Prediction of input model.
|
||||
"""
|
||||
onnx_pred = self.model({'input': img})
|
||||
onnx_pred = torch.from_numpy(onnx_pred[0])
|
||||
return onnx_pred
|
||||
|
||||
|
||||
@ -403,8 +404,8 @@ class PPLDetector(DeployBaseTextDetector):
|
||||
"""
|
||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||
ppl_pred = self.model({'input': img})
|
||||
|
||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||
if isinstance(ppl_pred[0], np.ndarray):
|
||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||
return ppl_pred
|
||||
|
||||
|
||||
@ -442,7 +443,8 @@ class PPLRecognizer(DeployBaseRecognizer):
|
||||
"""
|
||||
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||
ppl_pred = self.model({'input': img})[0]
|
||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||
if isinstance(ppl_pred[0], np.ndarray):
|
||||
ppl_pred = torch.from_numpy(ppl_pred[0])
|
||||
return ppl_pred
|
||||
|
||||
|
||||
|
@ -295,6 +295,8 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
|
||||
backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path)
|
||||
|
||||
backend_feats = flatten_model_inputs
|
||||
elif backend == Backend.DEFAULT:
|
||||
return ctx_outputs, False
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unimplemented backend type: {backend.value}')
|
||||
|
15
tests/test_mmocr/data/config/crnn.py
Executable file
15
tests/test_mmocr/data/config/crnn.py
Executable file
@ -0,0 +1,15 @@
|
||||
_base_ = []
|
||||
|
||||
# model
|
||||
label_convertor = dict(
|
||||
type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
|
||||
|
||||
model = dict(
|
||||
type='CRNNNet',
|
||||
preprocessor=None,
|
||||
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
|
||||
encoder=None,
|
||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||
loss=dict(type='CTCLoss'),
|
||||
label_convertor=label_convertor,
|
||||
pretrained=None)
|
49
tests/test_mmocr/data/config/dbnet.py
Executable file
49
tests/test_mmocr/data/config/dbnet.py
Executable file
@ -0,0 +1,49 @@
|
||||
model = dict(
|
||||
type='DBNet',
|
||||
backbone=dict(
|
||||
type='mmdet.ResNet',
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
|
||||
norm_eval=False,
|
||||
style='caffe'),
|
||||
neck=dict(type='FPNC', in_channels=[2, 4, 8, 16], lateral_channels=8),
|
||||
bbox_head=dict(
|
||||
type='DBHead',
|
||||
text_repr_type='quad',
|
||||
in_channels=8,
|
||||
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
|
||||
train_cfg=None,
|
||||
test_cfg=None)
|
||||
|
||||
dataset_type = 'IcdarDataset'
|
||||
data_root = 'data/icdar2015'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(128, 64),
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', img_scale=(256, 128), keep_ratio=True),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=16,
|
||||
test_dataloader=dict(samples_per_gpu=1),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
ann_file=data_root + '/instances_test.json',
|
||||
img_prefix=data_root + '/imgs',
|
||||
pipeline=test_pipeline))
|
||||
evaluation = dict(interval=100, metric='hmean-iou')
|
1
tests/test_mmocr/data/icdar2015/instances_test.json
Executable file
1
tests/test_mmocr/data/icdar2015/instances_test.json
Executable file
@ -0,0 +1 @@
|
||||
{"images": [], "categories": [], "annotations": []}
|
291
tests/test_mmocr/test_mmocr_apis.py
Executable file
291
tests/test_mmocr/test_mmocr_apis.py
Executable file
@ -0,0 +1,291 @@
|
||||
import importlib
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import mmdeploy.apis.ncnn as ncnn_apis
|
||||
import mmdeploy.apis.onnxruntime as ort_apis
|
||||
import mmdeploy.apis.ppl as ppl_apis
|
||||
import mmdeploy.apis.tensorrt as trt_apis
|
||||
from mmdeploy.mmocr.apis.inference import get_classes_from_config
|
||||
from mmdeploy.mmocr.apis.visualize import show_result
|
||||
from mmdeploy.utils.test import SwitchBackendWrapper
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
|
||||
def test_TensorRTDetector():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure TensorRTDetector can use TRTWrapper inside itself
|
||||
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
|
||||
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
|
||||
|
||||
outputs = {
|
||||
'output': torch.rand(1, 3, 64, 64).cuda(),
|
||||
}
|
||||
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import TensorRTDetector
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/dbnet.py')
|
||||
trt_detector = TensorRTDetector('', model_config, 0, False)
|
||||
# watch from config
|
||||
imgs = [torch.rand(1, 3, 64, 64).cuda()]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1.],
|
||||
}]]
|
||||
|
||||
results = trt_detector.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output \
|
||||
using TensorRTDetector'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||
def test_ONNXRuntimeDetector():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure ONNXRuntimeDetector can use ORTWrapper inside itself
|
||||
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||
|
||||
outputs = (torch.rand(1, 3, 64, 64))
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import ONNXRuntimeDetector
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/dbnet.py')
|
||||
ort_detector = ONNXRuntimeDetector('', model_config, 0, False)
|
||||
imgs = [torch.rand(1, 3, 64, 64)]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1.],
|
||||
}]]
|
||||
|
||||
results = ort_detector.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using '\
|
||||
'ONNXRuntimeDetector'
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
|
||||
def test_PPLDetector():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure PPLDetector can use PPLWrapper inside itself
|
||||
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
|
||||
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
|
||||
|
||||
outputs = (torch.rand(1, 3, 64, 64))
|
||||
with SwitchBackendWrapper(PPLWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import PPLDetector
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/dbnet.py')
|
||||
ppl_detector = PPLDetector('', model_config, 0, False)
|
||||
imgs = [torch.rand(1, 3, 64, 64)]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1.],
|
||||
}]]
|
||||
|
||||
results = ppl_detector.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using PPLDetector'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
|
||||
def test_NCNNDetector():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure NCNNDetector can use NCNNWrapper inside itself
|
||||
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
|
||||
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
|
||||
|
||||
outputs = {'output': torch.rand(1, 3, 64, 64)}
|
||||
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import NCNNDetector
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/dbnet.py')
|
||||
ncnn_detector = NCNNDetector(['', ''], model_config, 0, False)
|
||||
imgs = [torch.rand(1, 3, 64, 64)]
|
||||
img_metas = [[{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1.],
|
||||
}]]
|
||||
|
||||
results = ncnn_detector.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using NCNNDetector'
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
|
||||
def test_TensorRTRecognizer():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure TensorRTRecognizer can use TRTWrapper inside itself
|
||||
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
|
||||
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
|
||||
|
||||
outputs = {
|
||||
'output': torch.rand(1, 9, 37).cuda(),
|
||||
}
|
||||
with SwitchBackendWrapper(TRTWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import TensorRTRecognizer
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/crnn.py')
|
||||
trt_recognizer = TensorRTRecognizer('', model_config, 0, False)
|
||||
# watch from config
|
||||
imgs = [torch.rand(1, 1, 32, 32).cuda()]
|
||||
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
|
||||
|
||||
results = trt_recognizer.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using \
|
||||
TensorRTRecognizer'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||
def test_ONNXRuntimeRecognizer():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure ONNXRuntimeRecognizer can use ORTWrapper inside itself
|
||||
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||
|
||||
outputs = (torch.rand(1, 9, 37))
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import ONNXRuntimeRecognizer
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/crnn.py')
|
||||
ort_recognizer = ONNXRuntimeRecognizer('', model_config, 0, False)
|
||||
imgs = [torch.rand(1, 1, 32, 32).numpy()]
|
||||
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
|
||||
|
||||
results = ort_recognizer.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using '\
|
||||
'ONNXRuntimeRecognizer'
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
|
||||
def test_PPLRecognizer():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure PPLRecognizer can use PPLWrapper inside itself
|
||||
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
|
||||
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
|
||||
|
||||
outputs = (torch.rand(1, 9, 37))
|
||||
with SwitchBackendWrapper(PPLWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import PPLRecognizer
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/crnn.py')
|
||||
ppl_recognizer = PPLRecognizer('', model_config, 0, False)
|
||||
imgs = [torch.rand(1, 1, 32, 32)]
|
||||
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
|
||||
|
||||
results = ppl_recognizer.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using PPLRecognizer'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
|
||||
def test_NCNNRecognizer():
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure NCNNPSSDetector can use NCNNWrapper inside itself
|
||||
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
|
||||
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
|
||||
|
||||
outputs = {'output': torch.rand(1, 9, 37)}
|
||||
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
|
||||
wrapper.set(outputs=outputs)
|
||||
from mmdeploy.mmocr.apis.inference import NCNNRecognizer
|
||||
model_config = mmcv.Config.fromfile(
|
||||
'tests/test_mmocr/data/config/crnn.py')
|
||||
ncnn_recognizer = NCNNRecognizer(['', ''], model_config, 0, False)
|
||||
imgs = [torch.rand(1, 1, 32, 32)]
|
||||
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
|
||||
|
||||
results = ncnn_recognizer.forward_of_backend(imgs, img_metas)
|
||||
assert results is not None, 'failed to get output using NCNNRecognizer'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'task, model',
|
||||
[('TextDetection', 'tests/test_mmocr/data/config/dbnet.py'),
|
||||
('TextRecognition', 'tests/test_mmocr/data/config/crnn.py')])
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||
def test_build_ocr_processor(task, model):
|
||||
model_cfg = mmcv.Config.fromfile(model)
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
codebase_config=dict(type='mmocr', task=task)))
|
||||
|
||||
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||
|
||||
# simplify backend inference
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||
from mmdeploy.apis.utils import init_backend_model
|
||||
result_model = init_backend_model([''], model_cfg, deploy_cfg, -1)
|
||||
assert result_model is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model', ['tests/test_mmocr/data/config/dbnet.py'])
|
||||
def test_get_classes_from_config(model):
|
||||
get_classes_from_config(model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'task, model',
|
||||
[('TextDetection', 'tests/test_mmocr/data/config/dbnet.py')])
|
||||
@pytest.mark.skipif(
|
||||
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
|
||||
def test_show_result(task, model):
|
||||
model_cfg = mmcv.Config.fromfile(model)
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
codebase_config=dict(type='mmocr', task=task)))
|
||||
|
||||
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
|
||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||
|
||||
# simplify backend inference
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
|
||||
from mmdeploy.apis.utils import init_backend_model
|
||||
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
|
||||
img = np.random.random((64, 64, 3))
|
||||
result = {'boundary_result': [[1, 2, 3, 4, 5], [2, 2, 0, 4, 5]]}
|
||||
import os.path
|
||||
import tempfile
|
||||
|
||||
from mmdeploy.utils.constants import Backend
|
||||
with tempfile.TemporaryDirectory() as dir:
|
||||
filename = dir + 'tmp.jpg'
|
||||
show_result(
|
||||
detector,
|
||||
img,
|
||||
result,
|
||||
filename,
|
||||
Backend.ONNXRUNTIME,
|
||||
show=False)
|
||||
assert os.path.exists(filename)
|
135
tests/test_mmocr/test_mmocr_export.py
Normal file
135
tests/test_mmocr/test_mmocr_export.py
Normal file
@ -0,0 +1,135 @@
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
|
||||
from mmdeploy.utils.constants import Codebase, Task
|
||||
|
||||
|
||||
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
|
||||
def test_create_input(task):
|
||||
if task == Task.TEXT_DETECTION:
|
||||
test = dict(
|
||||
type='IcdarDataset',
|
||||
pipeline=[{
|
||||
'type': 'LoadImageFromFile',
|
||||
'color_type': 'color_ignore_orientation'
|
||||
}, {
|
||||
'type':
|
||||
'MultiScaleFlipAug',
|
||||
'img_scale': (128, 64),
|
||||
'flip':
|
||||
False,
|
||||
'transforms': [
|
||||
{
|
||||
'type': 'Resize',
|
||||
'img_scale': (256, 128),
|
||||
'keep_ratio': True
|
||||
},
|
||||
{
|
||||
'type': 'Normalize',
|
||||
'mean': [123.675, 116.28, 103.53],
|
||||
'std': [58.395, 57.12, 57.375],
|
||||
'to_rgb': True
|
||||
},
|
||||
{
|
||||
'type': 'Pad',
|
||||
'size_divisor': 32
|
||||
},
|
||||
{
|
||||
'type': 'ImageToTensor',
|
||||
'keys': ['img']
|
||||
},
|
||||
{
|
||||
'type': 'Collect',
|
||||
'keys': ['img']
|
||||
},
|
||||
]
|
||||
}])
|
||||
imgs = [np.random.rand(128, 64, 3).astype(np.uint8)]
|
||||
elif task == Task.TEXT_RECOGNITION:
|
||||
test = dict(
|
||||
type='UniformConcatDataset',
|
||||
pipeline=[
|
||||
{
|
||||
'type': 'LoadImageFromFile',
|
||||
'color_type': 'grayscale'
|
||||
},
|
||||
{
|
||||
'type': 'ResizeOCR',
|
||||
'height': 32,
|
||||
'min_width': 32,
|
||||
'max_width': None,
|
||||
'keep_aspect_ratio': True
|
||||
},
|
||||
{
|
||||
'type': 'Normalize',
|
||||
'mean': [127],
|
||||
'std': [127]
|
||||
},
|
||||
{
|
||||
'type': 'DefaultFormatBundle'
|
||||
},
|
||||
{
|
||||
'type': 'Collect',
|
||||
'keys': ['img'],
|
||||
'meta_keys': ['filename', 'resize_shape', 'valid_ratio']
|
||||
},
|
||||
])
|
||||
imgs = [np.random.random((32, 32, 3)).astype(np.uint8)]
|
||||
data = dict(test=test)
|
||||
model_cfg = mmcv.Config(dict(data=data))
|
||||
inputs = create_input(
|
||||
Codebase.MMOCR,
|
||||
task,
|
||||
model_cfg,
|
||||
imgs,
|
||||
input_shape=imgs[0].shape[0:2],
|
||||
device='cpu')
|
||||
assert inputs is not None, 'Failed to create input'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
|
||||
def test_build_dataset(task):
|
||||
import tempfile
|
||||
import os
|
||||
ann_file, ann_path = tempfile.mkstemp()
|
||||
if task == Task.TEXT_DETECTION:
|
||||
data = dict(
|
||||
test={
|
||||
'type': 'IcdarDataset',
|
||||
'ann_file':
|
||||
'tests/test_mmocr/data/icdar2015/instances_test.json',
|
||||
'img_prefix': 'tests/test_mmocr/data/icdar2015/imgs',
|
||||
'pipeline': [
|
||||
{
|
||||
'type': 'LoadImageFromFile'
|
||||
},
|
||||
]
|
||||
})
|
||||
elif task == Task.TEXT_RECOGNITION:
|
||||
data = dict(
|
||||
test={
|
||||
'type': 'OCRDataset',
|
||||
'ann_file': ann_path,
|
||||
'img_prefix': '',
|
||||
'loader': {
|
||||
'type': 'HardDiskLoader',
|
||||
'repeat': 1,
|
||||
'parser': {
|
||||
'type': 'LineStrParser',
|
||||
'keys': ['filename', 'text'],
|
||||
'keys_idx': [0, 1],
|
||||
'separator': ' '
|
||||
}
|
||||
},
|
||||
'pipeline': [],
|
||||
'test_mode': True
|
||||
})
|
||||
dataset_cfg = mmcv.Config(dict(data=data))
|
||||
dataset = build_dataset(
|
||||
Codebase.MMOCR, dataset_cfg=dataset_cfg, dataset_type='test')
|
||||
assert dataset is not None, 'Failed to build dataset'
|
||||
dataloader = build_dataloader(Codebase.MMOCR, dataset, 1, 1)
|
||||
os.close(ann_file)
|
||||
assert dataloader is not None, 'Failed to build dataloader'
|
409
tests/test_mmocr/test_mmocr_models.py
Normal file
409
tests/test_mmocr/test_mmocr_models.py
Normal file
@ -0,0 +1,409 @@
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmocr.models.textdet.necks import FPNC
|
||||
|
||||
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
|
||||
get_rewrite_outputs)
|
||||
|
||||
|
||||
class FPNCNeckModel(FPNC):
|
||||
|
||||
def __init__(self, in_channels, init_cfg=None):
|
||||
super().__init__(in_channels, init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.neck = FPNC(in_channels, init_cfg=init_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
neck_inputs = [
|
||||
torch.ones(1, channel, inputs.shape[-2], inputs.shape[-1])
|
||||
for channel in self.in_channels
|
||||
]
|
||||
output = self.neck.forward(neck_inputs)
|
||||
return output
|
||||
|
||||
|
||||
def get_bidirectionallstm_model():
|
||||
from mmocr.models.textrecog.layers.lstm_layer import BidirectionalLSTM
|
||||
model = BidirectionalLSTM(32, 16, 16)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_single_stage_text_detector_model():
|
||||
from mmocr.models.textdet import SingleStageTextDetector
|
||||
backbone = dict(
|
||||
type='mmdet.ResNet',
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
|
||||
norm_eval=False,
|
||||
style='caffe')
|
||||
neck = dict(
|
||||
type='FPNC',
|
||||
in_channels=[64, 128, 256, 512],
|
||||
lateral_channels=4,
|
||||
out_channels=4)
|
||||
bbox_head = dict(
|
||||
type='DBHead',
|
||||
text_repr_type='quad',
|
||||
in_channels=16,
|
||||
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True))
|
||||
model = SingleStageTextDetector(backbone, neck, bbox_head)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_encode_decode_recognizer_model():
|
||||
from mmocr.models.textrecog import EncodeDecodeRecognizer
|
||||
|
||||
cfg = dict(
|
||||
preprocessor=None,
|
||||
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
|
||||
encoder=dict(type='TFEncoder'),
|
||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||
loss=dict(type='CTCLoss'),
|
||||
label_convertor=dict(
|
||||
type='CTCConvertor',
|
||||
dict_type='DICT36',
|
||||
with_unknown=False,
|
||||
lower=True),
|
||||
pretrained=None)
|
||||
|
||||
model = EncodeDecodeRecognizer(
|
||||
backbone=cfg['backbone'],
|
||||
encoder=cfg['encoder'],
|
||||
decoder=cfg['decoder'],
|
||||
loss=cfg['loss'],
|
||||
label_convertor=cfg['label_convertor'])
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_crnn_decoder_model(rnn_flag):
|
||||
from mmocr.models.textrecog.decoders import CRNNDecoder
|
||||
model = CRNNDecoder(32, 4, rnn_flag=rnn_flag)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_fpnc_neck_model():
|
||||
model = FPNCNeckModel([2, 4, 8, 16])
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_base_recognizer_model():
|
||||
from mmocr.models.textrecog import CRNNNet
|
||||
|
||||
cfg = dict(
|
||||
preprocessor=None,
|
||||
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
|
||||
encoder=None,
|
||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||
loss=dict(type='CTCLoss'),
|
||||
label_convertor=dict(
|
||||
type='CTCConvertor',
|
||||
dict_type='DICT36',
|
||||
with_unknown=False,
|
||||
lower=True),
|
||||
pretrained=None)
|
||||
|
||||
model = CRNNNet(
|
||||
backbone=cfg['backbone'],
|
||||
decoder=cfg['decoder'],
|
||||
loss=cfg['loss'],
|
||||
label_convertor=cfg['label_convertor'])
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['ncnn'])
|
||||
def test_bidirectionallstm(backend_type):
|
||||
"""Test forward rewrite of bidirectionallstm."""
|
||||
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
|
||||
bilstm = get_bidirectionallstm_model()
|
||||
bilstm.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 1, 32)
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'input': input,
|
||||
}
|
||||
model_outputs = get_model_outputs(bilstm, 'forward', model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
wrapped_model = WrapModel(bilstm, 'forward')
|
||||
rewrite_inputs = {'input': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def test_simple_test_of_single_stage_text_detector():
|
||||
"""Test simple_test single_stage_text_detector."""
|
||||
single_stage_text_detector = get_single_stage_text_detector_model()
|
||||
single_stage_text_detector.eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='default'),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextDetection',
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 3, 64, 64)
|
||||
img_metas = [{
|
||||
'ori_shape': [64, 64, 3],
|
||||
'img_shape': [64, 64, 3],
|
||||
'pad_shape': [64, 64, 3],
|
||||
'scale_factor': [1., 1., 1., 1],
|
||||
}]
|
||||
|
||||
x = single_stage_text_detector.extract_feat(input)
|
||||
model_outputs = single_stage_text_detector.bbox_head(x)
|
||||
|
||||
wrapped_model = WrapModel(single_stage_text_detector, 'simple_test')
|
||||
rewrite_inputs = {'img': input, 'img_metas': img_metas[0]}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', ['ncnn'])
|
||||
@pytest.mark.parametrize('rnn_flag', [True, False])
|
||||
def test_crnndecoder(backend_type, rnn_flag):
|
||||
"""Test forward rewrite of crnndecoder."""
|
||||
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
|
||||
crnn_decoder = get_crnn_decoder_model(rnn_flag)
|
||||
crnn_decoder.cpu().eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 32, 1, 64)
|
||||
out_enc = None
|
||||
targets_dict = None
|
||||
img_metas = None
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'feat': input,
|
||||
'out_enc': out_enc,
|
||||
'targets_dict': targets_dict,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(crnn_decoder, 'forward_train',
|
||||
model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
wrapped_model = WrapModel(
|
||||
crnn_decoder,
|
||||
'forward_train',
|
||||
out_enc=out_enc,
|
||||
targets_dict=targets_dict,
|
||||
img_metas=img_metas)
|
||||
rewrite_inputs = {'feat': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'img_metas', [[None], [{
|
||||
'resize_shape': [32, 32],
|
||||
'valid_ratio': 1.0
|
||||
}]])
|
||||
@pytest.mark.parametrize('is_dynamic', [True, False])
|
||||
def test_forward_of_base_recognizer(img_metas, is_dynamic):
|
||||
"""Test forward base_recognizer."""
|
||||
base_recognizer = get_base_recognizer_model()
|
||||
base_recognizer.eval()
|
||||
|
||||
if not is_dynamic:
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='ncnn'),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
)))
|
||||
else:
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='ncnn'),
|
||||
onnx_config=dict(
|
||||
input_shape=None,
|
||||
dynamic_axes={
|
||||
'input': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
'output': {
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
}
|
||||
}),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 1, 32, 32)
|
||||
|
||||
feat = base_recognizer.extract_feat(input)
|
||||
out_enc = None
|
||||
if base_recognizer.encoder is not None:
|
||||
out_enc = base_recognizer.encoder(feat, img_metas)
|
||||
model_outputs = base_recognizer.decoder(
|
||||
feat, out_enc, None, img_metas, train_mode=False)
|
||||
wrapped_model = WrapModel(
|
||||
base_recognizer, 'forward', img_metas=img_metas[0])
|
||||
rewrite_inputs = {
|
||||
'img': input,
|
||||
}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def test_simple_test_of_encode_decode_recognizer():
|
||||
"""Test simple_test encode_decode_recognizer."""
|
||||
encode_decode_recognizer = get_encode_decode_recognizer_model()
|
||||
encode_decode_recognizer.eval()
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='default'),
|
||||
onnx_config=dict(input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmocr',
|
||||
task='TextRecognition',
|
||||
)))
|
||||
|
||||
input = torch.rand(1, 1, 32, 32)
|
||||
img_metas = [{'resize_shape': [32, 32], 'valid_ratio': 1.0}]
|
||||
|
||||
feat = encode_decode_recognizer.extract_feat(input)
|
||||
out_enc = None
|
||||
if encode_decode_recognizer.encoder is not None:
|
||||
out_enc = encode_decode_recognizer.encoder(feat, img_metas)
|
||||
model_outputs = encode_decode_recognizer.decoder(
|
||||
feat, out_enc, None, img_metas, train_mode=False)
|
||||
|
||||
wrapped_model = WrapModel(
|
||||
encode_decode_recognizer, 'simple_test', img_metas=img_metas)
|
||||
rewrite_inputs = {'img': input}
|
||||
rewrite_outputs = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||
@pytest.mark.parametrize('backend_type', ['tensorrt'])
|
||||
def test_forward_of_fpnc(backend_type):
|
||||
"""Test forward rewrite of fpnc."""
|
||||
fpnc = get_fpnc_neck_model()
|
||||
fpnc.eval()
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(
|
||||
type=backend_type,
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
input=dict(
|
||||
min_shape=[1, 3, 64, 64],
|
||||
opt_shape=[1, 3, 64, 64],
|
||||
max_shape=[1, 3, 64, 64])))
|
||||
]),
|
||||
onnx_config=dict(input_shape=[64, 64], output_names=['output']),
|
||||
codebase_config=dict(type='mmocr', task='TextDetection')))
|
||||
|
||||
input = torch.rand(1, 3, 64, 64).cuda()
|
||||
model_inputs = {
|
||||
'inputs': input,
|
||||
}
|
||||
model_outputs = get_model_outputs(fpnc, 'forward', model_inputs)
|
||||
wrapped_model = WrapModel(fpnc, 'forward')
|
||||
rewrite_inputs = {
|
||||
'inputs': input,
|
||||
}
|
||||
rewrite_outputs, is_need_name = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if is_need_name:
|
||||
model_output = model_outputs[0].squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_outputs['output'].squeeze().cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
for model_output, rewrite_output in zip(model_outputs,
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze().cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
Loading…
x
Reference in New Issue
Block a user