[UnitTest] mmocr unittest (#130)

* WIP test_mmocr 8 out of 20

* test_mmocr_export

* test mmocr apis

* add test data

* add mmocr model unittest 5 passed 1 failed

* finish mmocr unittest

* fix lint

* fix yapf

* fix isort

* fix flake8

* fix docformatter

* fix docformatter

* try to fix unittest after merge master

* Change test.py for backend.DEFAULT

* fix flake8

* fix ut

* fix yapf

* fix ut build

* fix yapf

* fix mmocr_export ut

* fix mmocr_apis ort not cuda

* remove explicit .forward

* remove backendwrapper

* simplify the crnn and dbnet config

* simplify instance_test.json

* add another case of decoder

* increase coverage of test_mmocr_models base_recognizer

* improve coverage

* improve encode_decoder coverage

* reply for grimoire codereview

* what if not check cuda?

* remove image data

* reply to runningleon code review

* fix fpnc

* fix lint

* try to fix CI UT error

* fix fpnc with and wo custom ops

* fix yapf

* skip fpnc when cuda is not ready in ci

* reply for code review

* reply for code review

* fix yapf

* reply for code review

* fix yapf

* fix conflict

* remove unmatched data path

* remove unnecessary comments
This commit is contained in:
hanrui1sensetime 2021-10-25 10:15:57 +08:00 committed by GitHub
parent c5a87fb1bc
commit 9e227b228b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 909 additions and 5 deletions

View File

@ -1,6 +1,7 @@
from typing import Iterable, Sequence, Union from typing import Iterable, Sequence, Union
import mmcv import mmcv
import numpy as np
import torch import torch
from mmdet.models.builder import DETECTORS from mmdet.models.builder import DETECTORS
from mmocr.datasets import DATASETS from mmocr.datasets import DATASETS
@ -60,6 +61,8 @@ class DeployBaseTextDetector(TextDetectorMixin, SingleStageTextDetector):
list: A list of predictions. list: A list of predictions.
""" """
pred = self.forward_of_backend(img, img_metas, *args, **kwargs) pred = self.forward_of_backend(img, img_metas, *args, **kwargs)
if isinstance(pred, np.ndarray):
pred = torch.from_numpy(pred[0])
if len(img_metas) > 1: if len(img_metas) > 1:
boundaries = [ boundaries = [
self.bbox_head.get_boundary( self.bbox_head.get_boundary(
@ -187,7 +190,6 @@ class ONNXRuntimeDetector(DeployBaseTextDetector):
np.ndarray: Prediction of input model. np.ndarray: Prediction of input model.
""" """
onnx_pred = self.model({'input': img}) onnx_pred = self.model({'input': img})
onnx_pred = torch.from_numpy(onnx_pred[0])
return onnx_pred return onnx_pred
@ -223,7 +225,6 @@ class ONNXRuntimeRecognizer(DeployBaseRecognizer):
np.ndarray: Prediction of input model. np.ndarray: Prediction of input model.
""" """
onnx_pred = self.model({'input': img}) onnx_pred = self.model({'input': img})
onnx_pred = torch.from_numpy(onnx_pred[0])
return onnx_pred return onnx_pred
@ -403,8 +404,8 @@ class PPLDetector(DeployBaseTextDetector):
""" """
with torch.cuda.device(self.device_id), torch.no_grad(): with torch.cuda.device(self.device_id), torch.no_grad():
ppl_pred = self.model({'input': img}) ppl_pred = self.model({'input': img})
if isinstance(ppl_pred[0], np.ndarray):
ppl_pred = torch.from_numpy(ppl_pred[0]) ppl_pred = torch.from_numpy(ppl_pred[0])
return ppl_pred return ppl_pred
@ -442,7 +443,8 @@ class PPLRecognizer(DeployBaseRecognizer):
""" """
with torch.cuda.device(self.device_id), torch.no_grad(): with torch.cuda.device(self.device_id), torch.no_grad():
ppl_pred = self.model({'input': img})[0] ppl_pred = self.model({'input': img})[0]
ppl_pred = torch.from_numpy(ppl_pred[0]) if isinstance(ppl_pred[0], np.ndarray):
ppl_pred = torch.from_numpy(ppl_pred[0])
return ppl_pred return ppl_pred

View File

@ -295,6 +295,8 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path) backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path)
backend_feats = flatten_model_inputs backend_feats = flatten_model_inputs
elif backend == Backend.DEFAULT:
return ctx_outputs, False
else: else:
raise NotImplementedError( raise NotImplementedError(
f'Unimplemented backend type: {backend.value}') f'Unimplemented backend type: {backend.value}')

View File

@ -0,0 +1,15 @@
_base_ = []
# model
label_convertor = dict(
type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
model = dict(
type='CRNNNet',
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),
label_convertor=label_convertor,
pretrained=None)

View File

@ -0,0 +1,49 @@
model = dict(
type='DBNet',
backbone=dict(
type='mmdet.ResNet',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
norm_eval=False,
style='caffe'),
neck=dict(type='FPNC', in_channels=[2, 4, 8, 16], lateral_channels=8),
bbox_head=dict(
type='DBHead',
text_repr_type='quad',
in_channels=8,
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
train_cfg=None,
test_cfg=None)
dataset_type = 'IcdarDataset'
data_root = 'data/icdar2015'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
dict(
type='MultiScaleFlipAug',
img_scale=(128, 64),
flip=False,
transforms=[
dict(type='Resize', img_scale=(256, 128), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=16,
test_dataloader=dict(samples_per_gpu=1),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))
evaluation = dict(interval=100, metric='hmean-iou')

View File

@ -0,0 +1 @@
{"images": [], "categories": [], "annotations": []}

View File

@ -0,0 +1,291 @@
import importlib
import mmcv
import numpy as np
import pytest
import torch
import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.ppl as ppl_apis
import mmdeploy.apis.tensorrt as trt_apis
from mmdeploy.mmocr.apis.inference import get_classes_from_config
from mmdeploy.mmocr.apis.visualize import show_result
from mmdeploy.utils.test import SwitchBackendWrapper
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTDetector():
# force add backend wrapper regardless of plugins
# make sure TensorRTDetector can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
outputs = {
'output': torch.rand(1, 3, 64, 64).cuda(),
}
with SwitchBackendWrapper(TRTWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import TensorRTDetector
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/dbnet.py')
trt_detector = TensorRTDetector('', model_config, 0, False)
# watch from config
imgs = [torch.rand(1, 3, 64, 64).cuda()]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1.],
}]]
results = trt_detector.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output \
using TensorRTDetector'
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeDetector():
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimeDetector can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
outputs = (torch.rand(1, 3, 64, 64))
with SwitchBackendWrapper(ORTWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import ONNXRuntimeDetector
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/dbnet.py')
ort_detector = ONNXRuntimeDetector('', model_config, 0, False)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1.],
}]]
results = ort_detector.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using '\
'ONNXRuntimeDetector'
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLDetector():
# force add backend wrapper regardless of plugins
# make sure PPLDetector can use PPLWrapper inside itself
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
outputs = (torch.rand(1, 3, 64, 64))
with SwitchBackendWrapper(PPLWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import PPLDetector
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/dbnet.py')
ppl_detector = PPLDetector('', model_config, 0, False)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1.],
}]]
results = ppl_detector.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using PPLDetector'
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNDetector():
# force add backend wrapper regardless of plugins
# make sure NCNNDetector can use NCNNWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
outputs = {'output': torch.rand(1, 3, 64, 64)}
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import NCNNDetector
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/dbnet.py')
ncnn_detector = NCNNDetector(['', ''], model_config, 0, False)
imgs = [torch.rand(1, 3, 64, 64)]
img_metas = [[{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1.],
}]]
results = ncnn_detector.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using NCNNDetector'
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTRecognizer():
# force add backend wrapper regardless of plugins
# make sure TensorRTRecognizer can use TRTWrapper inside itself
from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
outputs = {
'output': torch.rand(1, 9, 37).cuda(),
}
with SwitchBackendWrapper(TRTWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import TensorRTRecognizer
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/crnn.py')
trt_recognizer = TensorRTRecognizer('', model_config, 0, False)
# watch from config
imgs = [torch.rand(1, 1, 32, 32).cuda()]
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
results = trt_recognizer.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using \
TensorRTRecognizer'
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeRecognizer():
# force add backend wrapper regardless of plugins
# make sure ONNXRuntimeRecognizer can use ORTWrapper inside itself
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
outputs = (torch.rand(1, 9, 37))
with SwitchBackendWrapper(ORTWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import ONNXRuntimeRecognizer
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/crnn.py')
ort_recognizer = ONNXRuntimeRecognizer('', model_config, 0, False)
imgs = [torch.rand(1, 1, 32, 32).numpy()]
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
results = ort_recognizer.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using '\
'ONNXRuntimeRecognizer'
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLRecognizer():
# force add backend wrapper regardless of plugins
# make sure PPLRecognizer can use PPLWrapper inside itself
from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
outputs = (torch.rand(1, 9, 37))
with SwitchBackendWrapper(PPLWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import PPLRecognizer
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/crnn.py')
ppl_recognizer = PPLRecognizer('', model_config, 0, False)
imgs = [torch.rand(1, 1, 32, 32)]
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
results = ppl_recognizer.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using PPLRecognizer'
@pytest.mark.skipif(
not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNRecognizer():
# force add backend wrapper regardless of plugins
# make sure NCNNPSSDetector can use NCNNWrapper inside itself
from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
outputs = {'output': torch.rand(1, 9, 37)}
with SwitchBackendWrapper(NCNNWrapper) as wrapper:
wrapper.set(outputs=outputs)
from mmdeploy.mmocr.apis.inference import NCNNRecognizer
model_config = mmcv.Config.fromfile(
'tests/test_mmocr/data/config/crnn.py')
ncnn_recognizer = NCNNRecognizer(['', ''], model_config, 0, False)
imgs = [torch.rand(1, 1, 32, 32)]
img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
results = ncnn_recognizer.forward_of_backend(imgs, img_metas)
assert results is not None, 'failed to get output using NCNNRecognizer'
@pytest.mark.parametrize(
'task, model',
[('TextDetection', 'tests/test_mmocr/data/config/dbnet.py'),
('TextRecognition', 'tests/test_mmocr/data/config/crnn.py')])
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_build_ocr_processor(task, model):
model_cfg = mmcv.Config.fromfile(model)
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmocr', task=task)))
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
with SwitchBackendWrapper(ORTWrapper) as wrapper:
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import init_backend_model
result_model = init_backend_model([''], model_cfg, deploy_cfg, -1)
assert result_model is not None
@pytest.mark.parametrize('model', ['tests/test_mmocr/data/config/dbnet.py'])
def test_get_classes_from_config(model):
get_classes_from_config(model)
@pytest.mark.parametrize(
'task, model',
[('TextDetection', 'tests/test_mmocr/data/config/dbnet.py')])
@pytest.mark.skipif(
not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_show_result(task, model):
model_cfg = mmcv.Config.fromfile(model)
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmocr', task=task)))
from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
# simplify backend inference
with SwitchBackendWrapper(ORTWrapper) as wrapper:
wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
from mmdeploy.apis.utils import init_backend_model
detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
img = np.random.random((64, 64, 3))
result = {'boundary_result': [[1, 2, 3, 4, 5], [2, 2, 0, 4, 5]]}
import os.path
import tempfile
from mmdeploy.utils.constants import Backend
with tempfile.TemporaryDirectory() as dir:
filename = dir + 'tmp.jpg'
show_result(
detector,
img,
result,
filename,
Backend.ONNXRUNTIME,
show=False)
assert os.path.exists(filename)

View File

@ -0,0 +1,135 @@
import mmcv
import numpy as np
import pytest
from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
from mmdeploy.utils.constants import Codebase, Task
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
def test_create_input(task):
if task == Task.TEXT_DETECTION:
test = dict(
type='IcdarDataset',
pipeline=[{
'type': 'LoadImageFromFile',
'color_type': 'color_ignore_orientation'
}, {
'type':
'MultiScaleFlipAug',
'img_scale': (128, 64),
'flip':
False,
'transforms': [
{
'type': 'Resize',
'img_scale': (256, 128),
'keep_ratio': True
},
{
'type': 'Normalize',
'mean': [123.675, 116.28, 103.53],
'std': [58.395, 57.12, 57.375],
'to_rgb': True
},
{
'type': 'Pad',
'size_divisor': 32
},
{
'type': 'ImageToTensor',
'keys': ['img']
},
{
'type': 'Collect',
'keys': ['img']
},
]
}])
imgs = [np.random.rand(128, 64, 3).astype(np.uint8)]
elif task == Task.TEXT_RECOGNITION:
test = dict(
type='UniformConcatDataset',
pipeline=[
{
'type': 'LoadImageFromFile',
'color_type': 'grayscale'
},
{
'type': 'ResizeOCR',
'height': 32,
'min_width': 32,
'max_width': None,
'keep_aspect_ratio': True
},
{
'type': 'Normalize',
'mean': [127],
'std': [127]
},
{
'type': 'DefaultFormatBundle'
},
{
'type': 'Collect',
'keys': ['img'],
'meta_keys': ['filename', 'resize_shape', 'valid_ratio']
},
])
imgs = [np.random.random((32, 32, 3)).astype(np.uint8)]
data = dict(test=test)
model_cfg = mmcv.Config(dict(data=data))
inputs = create_input(
Codebase.MMOCR,
task,
model_cfg,
imgs,
input_shape=imgs[0].shape[0:2],
device='cpu')
assert inputs is not None, 'Failed to create input'
@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
def test_build_dataset(task):
import tempfile
import os
ann_file, ann_path = tempfile.mkstemp()
if task == Task.TEXT_DETECTION:
data = dict(
test={
'type': 'IcdarDataset',
'ann_file':
'tests/test_mmocr/data/icdar2015/instances_test.json',
'img_prefix': 'tests/test_mmocr/data/icdar2015/imgs',
'pipeline': [
{
'type': 'LoadImageFromFile'
},
]
})
elif task == Task.TEXT_RECOGNITION:
data = dict(
test={
'type': 'OCRDataset',
'ann_file': ann_path,
'img_prefix': '',
'loader': {
'type': 'HardDiskLoader',
'repeat': 1,
'parser': {
'type': 'LineStrParser',
'keys': ['filename', 'text'],
'keys_idx': [0, 1],
'separator': ' '
}
},
'pipeline': [],
'test_mode': True
})
dataset_cfg = mmcv.Config(dict(data=data))
dataset = build_dataset(
Codebase.MMOCR, dataset_cfg=dataset_cfg, dataset_type='test')
assert dataset is not None, 'Failed to build dataset'
dataloader = build_dataloader(Codebase.MMOCR, dataset, 1, 1)
os.close(ann_file)
assert dataloader is not None, 'Failed to build dataloader'

View File

@ -0,0 +1,409 @@
import mmcv
import numpy as np
import pytest
import torch
from mmocr.models.textdet.necks import FPNC
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
get_rewrite_outputs)
class FPNCNeckModel(FPNC):
def __init__(self, in_channels, init_cfg=None):
super().__init__(in_channels, init_cfg=init_cfg)
self.in_channels = in_channels
self.neck = FPNC(in_channels, init_cfg=init_cfg)
def forward(self, inputs):
neck_inputs = [
torch.ones(1, channel, inputs.shape[-2], inputs.shape[-1])
for channel in self.in_channels
]
output = self.neck.forward(neck_inputs)
return output
def get_bidirectionallstm_model():
from mmocr.models.textrecog.layers.lstm_layer import BidirectionalLSTM
model = BidirectionalLSTM(32, 16, 16)
model.requires_grad_(False)
return model
def get_single_stage_text_detector_model():
from mmocr.models.textdet import SingleStageTextDetector
backbone = dict(
type='mmdet.ResNet',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
norm_eval=False,
style='caffe')
neck = dict(
type='FPNC',
in_channels=[64, 128, 256, 512],
lateral_channels=4,
out_channels=4)
bbox_head = dict(
type='DBHead',
text_repr_type='quad',
in_channels=16,
loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True))
model = SingleStageTextDetector(backbone, neck, bbox_head)
model.requires_grad_(False)
return model
def get_encode_decode_recognizer_model():
from mmocr.models.textrecog import EncodeDecodeRecognizer
cfg = dict(
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=dict(type='TFEncoder'),
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),
label_convertor=dict(
type='CTCConvertor',
dict_type='DICT36',
with_unknown=False,
lower=True),
pretrained=None)
model = EncodeDecodeRecognizer(
backbone=cfg['backbone'],
encoder=cfg['encoder'],
decoder=cfg['decoder'],
loss=cfg['loss'],
label_convertor=cfg['label_convertor'])
model.requires_grad_(False)
return model
def get_crnn_decoder_model(rnn_flag):
from mmocr.models.textrecog.decoders import CRNNDecoder
model = CRNNDecoder(32, 4, rnn_flag=rnn_flag)
model.requires_grad_(False)
return model
def get_fpnc_neck_model():
model = FPNCNeckModel([2, 4, 8, 16])
model.requires_grad_(False)
return model
def get_base_recognizer_model():
from mmocr.models.textrecog import CRNNNet
cfg = dict(
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),
label_convertor=dict(
type='CTCConvertor',
dict_type='DICT36',
with_unknown=False,
lower=True),
pretrained=None)
model = CRNNNet(
backbone=cfg['backbone'],
decoder=cfg['decoder'],
loss=cfg['loss'],
label_convertor=cfg['label_convertor'])
model.requires_grad_(False)
return model
@pytest.mark.parametrize('backend_type', ['ncnn'])
def test_bidirectionallstm(backend_type):
"""Test forward rewrite of bidirectionallstm."""
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
bilstm = get_bidirectionallstm_model()
bilstm.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
)))
input = torch.rand(1, 1, 32)
# to get outputs of pytorch model
model_inputs = {
'input': input,
}
model_outputs = get_model_outputs(bilstm, 'forward', model_inputs)
# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(bilstm, 'forward')
rewrite_inputs = {'input': input}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
def test_simple_test_of_single_stage_text_detector():
"""Test simple_test single_stage_text_detector."""
single_stage_text_detector = get_single_stage_text_detector_model()
single_stage_text_detector.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='default'),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextDetection',
)))
input = torch.rand(1, 3, 64, 64)
img_metas = [{
'ori_shape': [64, 64, 3],
'img_shape': [64, 64, 3],
'pad_shape': [64, 64, 3],
'scale_factor': [1., 1., 1., 1],
}]
x = single_stage_text_detector.extract_feat(input)
model_outputs = single_stage_text_detector.bbox_head(x)
wrapped_model = WrapModel(single_stage_text_detector, 'simple_test')
rewrite_inputs = {'img': input, 'img_metas': img_metas[0]}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', ['ncnn'])
@pytest.mark.parametrize('rnn_flag', [True, False])
def test_crnndecoder(backend_type, rnn_flag):
"""Test forward rewrite of crnndecoder."""
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
crnn_decoder = get_crnn_decoder_model(rnn_flag)
crnn_decoder.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
)))
input = torch.rand(1, 32, 1, 64)
out_enc = None
targets_dict = None
img_metas = None
# to get outputs of pytorch model
model_inputs = {
'feat': input,
'out_enc': out_enc,
'targets_dict': targets_dict,
'img_metas': img_metas
}
model_outputs = get_model_outputs(crnn_decoder, 'forward_train',
model_inputs)
# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(
crnn_decoder,
'forward_train',
out_enc=out_enc,
targets_dict=targets_dict,
img_metas=img_metas)
rewrite_inputs = {'feat': input}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize(
'img_metas', [[None], [{
'resize_shape': [32, 32],
'valid_ratio': 1.0
}]])
@pytest.mark.parametrize('is_dynamic', [True, False])
def test_forward_of_base_recognizer(img_metas, is_dynamic):
"""Test forward base_recognizer."""
base_recognizer = get_base_recognizer_model()
base_recognizer.eval()
if not is_dynamic:
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='ncnn'),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
)))
else:
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='ncnn'),
onnx_config=dict(
input_shape=None,
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
0: 'batch',
2: 'height',
3: 'width'
}
}),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
)))
input = torch.rand(1, 1, 32, 32)
feat = base_recognizer.extract_feat(input)
out_enc = None
if base_recognizer.encoder is not None:
out_enc = base_recognizer.encoder(feat, img_metas)
model_outputs = base_recognizer.decoder(
feat, out_enc, None, img_metas, train_mode=False)
wrapped_model = WrapModel(
base_recognizer, 'forward', img_metas=img_metas[0])
rewrite_inputs = {
'img': input,
}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
def test_simple_test_of_encode_decode_recognizer():
"""Test simple_test encode_decode_recognizer."""
encode_decode_recognizer = get_encode_decode_recognizer_model()
encode_decode_recognizer.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='default'),
onnx_config=dict(input_shape=None),
codebase_config=dict(
type='mmocr',
task='TextRecognition',
)))
input = torch.rand(1, 1, 32, 32)
img_metas = [{'resize_shape': [32, 32], 'valid_ratio': 1.0}]
feat = encode_decode_recognizer.extract_feat(input)
out_enc = None
if encode_decode_recognizer.encoder is not None:
out_enc = encode_decode_recognizer.encoder(feat, img_metas)
model_outputs = encode_decode_recognizer.decoder(
feat, out_enc, None, img_metas, train_mode=False)
wrapped_model = WrapModel(
encode_decode_recognizer, 'simple_test', img_metas=img_metas)
rewrite_inputs = {'img': input}
rewrite_outputs = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize('backend_type', ['tensorrt'])
def test_forward_of_fpnc(backend_type):
"""Test forward rewrite of fpnc."""
fpnc = get_fpnc_neck_model()
fpnc.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(
type=backend_type,
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 64, 64],
opt_shape=[1, 3, 64, 64],
max_shape=[1, 3, 64, 64])))
]),
onnx_config=dict(input_shape=[64, 64], output_names=['output']),
codebase_config=dict(type='mmocr', task='TextDetection')))
input = torch.rand(1, 3, 64, 64).cuda()
model_inputs = {
'inputs': input,
}
model_outputs = get_model_outputs(fpnc, 'forward', model_inputs)
wrapped_model = WrapModel(fpnc, 'forward')
rewrite_inputs = {
'inputs': input,
}
rewrite_outputs, is_need_name = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_need_name:
model_output = model_outputs[0].squeeze().cpu().numpy()
rewrite_output = rewrite_outputs['output'].squeeze().cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
else:
for model_output, rewrite_output in zip(model_outputs,
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze().cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)