diff --git a/mmdeploy/mmocr/apis/inference.py b/mmdeploy/mmocr/apis/inference.py
index 5cbaeea7c..0ba95b3bb 100644
--- a/mmdeploy/mmocr/apis/inference.py
+++ b/mmdeploy/mmocr/apis/inference.py
@@ -1,6 +1,7 @@
 from typing import Iterable, Sequence, Union
 
 import mmcv
+import numpy as np
 import torch
 from mmdet.models.builder import DETECTORS
 from mmocr.datasets import DATASETS
@@ -60,6 +61,8 @@ class DeployBaseTextDetector(TextDetectorMixin, SingleStageTextDetector):
             list: A list of predictions.
         """
         pred = self.forward_of_backend(img, img_metas, *args, **kwargs)
+        if isinstance(pred, np.ndarray):
+            pred = torch.from_numpy(pred[0])
         if len(img_metas) > 1:
             boundaries = [
                 self.bbox_head.get_boundary(
@@ -187,7 +190,6 @@ class ONNXRuntimeDetector(DeployBaseTextDetector):
             np.ndarray: Prediction of input model.
         """
         onnx_pred = self.model({'input': img})
-        onnx_pred = torch.from_numpy(onnx_pred[0])
 
         return onnx_pred
 
@@ -223,7 +225,6 @@ class ONNXRuntimeRecognizer(DeployBaseRecognizer):
             np.ndarray: Prediction of input model.
         """
         onnx_pred = self.model({'input': img})
-        onnx_pred = torch.from_numpy(onnx_pred[0])
 
         return onnx_pred
 
@@ -403,8 +404,8 @@ class PPLDetector(DeployBaseTextDetector):
         """
         with torch.cuda.device(self.device_id), torch.no_grad():
             ppl_pred = self.model({'input': img})
-
-        ppl_pred = torch.from_numpy(ppl_pred[0])
+        if isinstance(ppl_pred[0], np.ndarray):
+            ppl_pred = torch.from_numpy(ppl_pred[0])
 
         return ppl_pred
 
@@ -442,7 +443,8 @@ class PPLRecognizer(DeployBaseRecognizer):
         """
         with torch.cuda.device(self.device_id), torch.no_grad():
            ppl_pred = self.model({'input': img})[0]
-        ppl_pred = torch.from_numpy(ppl_pred[0])
+        if isinstance(ppl_pred[0], np.ndarray):
+            ppl_pred = torch.from_numpy(ppl_pred[0])
 
         return ppl_pred
 
diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py
index b383784e8..453227eaa 100644
--- a/mmdeploy/utils/test.py
+++ b/mmdeploy/utils/test.py
@@ -295,6 +295,8 @@ def get_rewrite_outputs(wrapped_model: nn.Module,
         backend_model = openvino_apis.OpenVINOWrapper(openvino_file_path)
         backend_feats = flatten_model_inputs
 
+    elif backend == Backend.DEFAULT:
+        return ctx_outputs, False
     else:
         raise NotImplementedError(
             f'Unimplemented backend type: {backend.value}')
diff --git a/tests/test_mmocr/data/config/crnn.py b/tests/test_mmocr/data/config/crnn.py
new file mode 100755
index 000000000..b92c1f4fe
--- /dev/null
+++ b/tests/test_mmocr/data/config/crnn.py
@@ -0,0 +1,15 @@
+_base_ = []
+
+# model
+label_convertor = dict(
+    type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
+
+model = dict(
+    type='CRNNNet',
+    preprocessor=None,
+    backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
+    encoder=None,
+    decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+    loss=dict(type='CTCLoss'),
+    label_convertor=label_convertor,
+    pretrained=None)
diff --git a/tests/test_mmocr/data/config/dbnet.py b/tests/test_mmocr/data/config/dbnet.py
new file mode 100755
index 000000000..ce36ced57
--- /dev/null
+++ b/tests/test_mmocr/data/config/dbnet.py
@@ -0,0 +1,49 @@
+model = dict(
+    type='DBNet',
+    backbone=dict(
+        type='mmdet.ResNet',
+        depth=18,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=-1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
+        norm_eval=False,
+        style='caffe'),
+    neck=dict(type='FPNC', in_channels=[2, 4, 8, 16], lateral_channels=8),
+    bbox_head=dict(
+        type='DBHead',
+        text_repr_type='quad',
+        in_channels=8,
+        loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
+    train_cfg=None,
+    test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(128, 64),
+        flip=False,
+        transforms=[
+            dict(type='Resize', img_scale=(256, 128), keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=16,
+    test_dataloader=dict(samples_per_gpu=1),
+    test=dict(
+        type=dataset_type,
+        ann_file=data_root + '/instances_test.json',
+        img_prefix=data_root + '/imgs',
+        pipeline=test_pipeline))
+evaluation = dict(interval=100, metric='hmean-iou')
diff --git a/tests/test_mmocr/data/icdar2015/instances_test.json b/tests/test_mmocr/data/icdar2015/instances_test.json
new file mode 100755
index 000000000..f24ae2a15
--- /dev/null
+++ b/tests/test_mmocr/data/icdar2015/instances_test.json
@@ -0,0 +1 @@
+{"images": [], "categories": [], "annotations": []}
diff --git a/tests/test_mmocr/test_mmocr_apis.py b/tests/test_mmocr/test_mmocr_apis.py
new file mode 100755
index 000000000..6b1aa6808
--- /dev/null
+++ b/tests/test_mmocr/test_mmocr_apis.py
@@ -0,0 +1,291 @@
+import importlib
+
+import mmcv
+import numpy as np
+import pytest
+import torch
+
+import mmdeploy.apis.ncnn as ncnn_apis
+import mmdeploy.apis.onnxruntime as ort_apis
+import mmdeploy.apis.ppl as ppl_apis
+import mmdeploy.apis.tensorrt as trt_apis
+from mmdeploy.mmocr.apis.inference import get_classes_from_config
+from mmdeploy.mmocr.apis.visualize import show_result
+from mmdeploy.utils.test import SwitchBackendWrapper
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.skipif(
+    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
+def test_TensorRTDetector():
+    # force add backend wrapper regardless of plugins
+    # make sure TensorRTDetector can use TRTWrapper inside itself
+    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
+    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
+
+    outputs = {
+        'output': torch.rand(1, 3, 64, 64).cuda(),
+    }
+    with SwitchBackendWrapper(TRTWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import TensorRTDetector
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/dbnet.py')
+        trt_detector = TensorRTDetector('', model_config, 0, False)
+        # input shapes taken from the test config
+        imgs = [torch.rand(1, 3, 64, 64).cuda()]
+        img_metas = [[{
+            'ori_shape': [64, 64, 3],
+            'img_shape': [64, 64, 3],
+            'pad_shape': [64, 64, 3],
+            'scale_factor': [1., 1., 1., 1.],
+        }]]
+
+        results = trt_detector.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using '\
+            'TensorRTDetector'
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
+def test_ONNXRuntimeDetector():
+    # force add backend wrapper regardless of plugins
+    # make sure ONNXRuntimeDetector can use ORTWrapper inside itself
+    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+    outputs = (torch.rand(1, 3, 64, 64))
+    with SwitchBackendWrapper(ORTWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import ONNXRuntimeDetector
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/dbnet.py')
+        ort_detector = ONNXRuntimeDetector('', model_config, 0, False)
+        imgs = [torch.rand(1, 3, 64, 64)]
+        img_metas = [[{
+            'ori_shape': [64, 64, 3],
+            'img_shape': [64, 64, 3],
+            'pad_shape': [64, 64, 3],
+            'scale_factor': [1., 1., 1., 1.],
+        }]]
+
+        results = ort_detector.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using '\
+            'ONNXRuntimeDetector'
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.skipif(
+    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
+def test_PPLDetector():
+    # force add backend wrapper regardless of plugins
+    # make sure PPLDetector can use PPLWrapper inside itself
+    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
+    ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
+
+    outputs = (torch.rand(1, 3, 64, 64))
+    with SwitchBackendWrapper(PPLWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import PPLDetector
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/dbnet.py')
+        ppl_detector = PPLDetector('', model_config, 0, False)
+        imgs = [torch.rand(1, 3, 64, 64)]
+        img_metas = [[{
+            'ori_shape': [64, 64, 3],
+            'img_shape': [64, 64, 3],
+            'pad_shape': [64, 64, 3],
+            'scale_factor': [1., 1., 1., 1.],
+        }]]
+
+        results = ppl_detector.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using PPLDetector'
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
+def test_NCNNDetector():
+    # force add backend wrapper regardless of plugins
+    # make sure NCNNDetector can use NCNNWrapper inside itself
+    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
+    ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
+
+    outputs = {'output': torch.rand(1, 3, 64, 64)}
+    with SwitchBackendWrapper(NCNNWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import NCNNDetector
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/dbnet.py')
+        ncnn_detector = NCNNDetector(['', ''], model_config, 0, False)
+        imgs = [torch.rand(1, 3, 64, 64)]
+        img_metas = [[{
+            'ori_shape': [64, 64, 3],
+            'img_shape': [64, 64, 3],
+            'pad_shape': [64, 64, 3],
+            'scale_factor': [1., 1., 1., 1.],
+        }]]
+
+        results = ncnn_detector.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using NCNNDetector'
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.skipif(
+    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
+def test_TensorRTRecognizer():
+    # force add backend wrapper regardless of plugins
+    # make sure TensorRTRecognizer can use TRTWrapper inside itself
+    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
+    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})
+
+    outputs = {
+        'output': torch.rand(1, 9, 37).cuda(),
+    }
+    with SwitchBackendWrapper(TRTWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import TensorRTRecognizer
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/crnn.py')
+        trt_recognizer = TensorRTRecognizer('', model_config, 0, False)
+        # input shapes taken from the test config
+        imgs = [torch.rand(1, 1, 32, 32).cuda()]
+        img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
+
+        results = trt_recognizer.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using '\
+            'TensorRTRecognizer'
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
+def test_ONNXRuntimeRecognizer():
+    # force add backend wrapper regardless of plugins
+    # make sure ONNXRuntimeRecognizer can use ORTWrapper inside itself
+    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+    outputs = (torch.rand(1, 9, 37))
+    with SwitchBackendWrapper(ORTWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import ONNXRuntimeRecognizer
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/crnn.py')
+        ort_recognizer = ONNXRuntimeRecognizer('', model_config, 0, False)
+        imgs = [torch.rand(1, 1, 32, 32).numpy()]
+        img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
+
+        results = ort_recognizer.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using '\
+            'ONNXRuntimeRecognizer'
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.skipif(
+    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
+def test_PPLRecognizer():
+    # force add backend wrapper regardless of plugins
+    # make sure PPLRecognizer can use PPLWrapper inside itself
+    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
+    ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})
+
+    outputs = (torch.rand(1, 9, 37))
+    with SwitchBackendWrapper(PPLWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import PPLRecognizer
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/crnn.py')
+        ppl_recognizer = PPLRecognizer('', model_config, 0, False)
+        imgs = [torch.rand(1, 1, 32, 32)]
+        img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
+
+        results = ppl_recognizer.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using PPLRecognizer'
+
+
+@pytest.mark.skipif(
+    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
+def test_NCNNRecognizer():
+    # force add backend wrapper regardless of plugins
+    # make sure NCNNRecognizer can use NCNNWrapper inside itself
+    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
+    ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})
+
+    outputs = {'output': torch.rand(1, 9, 37)}
+    with SwitchBackendWrapper(NCNNWrapper) as wrapper:
+        wrapper.set(outputs=outputs)
+        from mmdeploy.mmocr.apis.inference import NCNNRecognizer
+        model_config = mmcv.Config.fromfile(
+            'tests/test_mmocr/data/config/crnn.py')
+        ncnn_recognizer = NCNNRecognizer(['', ''], model_config, 0, False)
+        imgs = [torch.rand(1, 1, 32, 32)]
+        img_metas = [[{'resize_shape': [32, 32], 'valid_ratio': 1.0}]]
+
+        results = ncnn_recognizer.forward_of_backend(imgs, img_metas)
+        assert results is not None, 'failed to get output using '\
+            'NCNNRecognizer'
+
+
+@pytest.mark.parametrize(
+    'task, model',
+    [('TextDetection', 'tests/test_mmocr/data/config/dbnet.py'),
+     ('TextRecognition', 'tests/test_mmocr/data/config/crnn.py')])
+@pytest.mark.skipif(
+    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
+def test_build_ocr_processor(task, model):
+    model_cfg = mmcv.Config.fromfile(model)
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type='onnxruntime'),
+            codebase_config=dict(type='mmocr', task=task)))
+
+    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+    # simplify backend inference
+    with SwitchBackendWrapper(ORTWrapper) as wrapper:
+        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
+        from mmdeploy.apis.utils import init_backend_model
+        result_model = init_backend_model([''], model_cfg, deploy_cfg, -1)
+        assert result_model is not None
+
+
+@pytest.mark.parametrize('model', ['tests/test_mmocr/data/config/dbnet.py'])
+def test_get_classes_from_config(model):
+    get_classes_from_config(model)
+
+
+@pytest.mark.parametrize(
+    'task, model',
+    [('TextDetection', 'tests/test_mmocr/data/config/dbnet.py')])
+@pytest.mark.skipif(
+    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
+def test_show_result(task, model):
+    model_cfg = mmcv.Config.fromfile(model)
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type='onnxruntime'),
+            codebase_config=dict(type='mmocr', task=task)))
+
+    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+    # simplify backend inference
+    with SwitchBackendWrapper(ORTWrapper) as wrapper:
+        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
+        from mmdeploy.apis.utils import init_backend_model
+        detector = init_backend_model([''], model_cfg, deploy_cfg, -1)
+        img = np.random.random((64, 64, 3))
+        result = {'boundary_result': [[1, 2, 3, 4, 5], [2, 2, 0, 4, 5]]}
+        import os.path
+        import tempfile
+
+        from mmdeploy.utils.constants import Backend
+        with tempfile.TemporaryDirectory() as dir:
+            filename = os.path.join(dir, 'tmp.jpg')
+            show_result(
+                detector,
+                img,
+                result,
+                filename,
+                Backend.ONNXRUNTIME,
+                show=False)
+            assert os.path.exists(filename)
diff --git a/tests/test_mmocr/test_mmocr_export.py b/tests/test_mmocr/test_mmocr_export.py
new file mode 100644
index 000000000..6160c4d65
--- /dev/null
+++ b/tests/test_mmocr/test_mmocr_export.py
@@ -0,0 +1,135 @@
+import mmcv
+import numpy as np
+import pytest
+
+from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
+from mmdeploy.utils.constants import Codebase, Task
+
+
+@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
+def test_create_input(task):
+    if task == Task.TEXT_DETECTION:
+        test = dict(
+            type='IcdarDataset',
+            pipeline=[{
+                'type': 'LoadImageFromFile',
+                'color_type': 'color_ignore_orientation'
+            }, {
+                'type': 'MultiScaleFlipAug',
+                'img_scale': (128, 64),
+                'flip': False,
+                'transforms': [
+                    {
+                        'type': 'Resize',
+                        'img_scale': (256, 128),
+                        'keep_ratio': True
+                    },
+                    {
+                        'type': 'Normalize',
+                        'mean': [123.675, 116.28, 103.53],
+                        'std': [58.395, 57.12, 57.375],
+                        'to_rgb': True
+                    },
+                    {
+                        'type': 'Pad',
+                        'size_divisor': 32
+                    },
+                    {
+                        'type': 'ImageToTensor',
+                        'keys': ['img']
+                    },
+                    {
+                        'type': 'Collect',
+                        'keys': ['img']
+                    },
+                ]
+            }])
+        imgs = [np.random.rand(128, 64, 3).astype(np.uint8)]
+    elif task == Task.TEXT_RECOGNITION:
+        test = dict(
+            type='UniformConcatDataset',
+            pipeline=[
+                {
+                    'type': 'LoadImageFromFile',
+                    'color_type': 'grayscale'
+                },
+                {
+                    'type': 'ResizeOCR',
+                    'height': 32,
+                    'min_width': 32,
+                    'max_width': None,
+                    'keep_aspect_ratio': True
+                },
+                {
+                    'type': 'Normalize',
+                    'mean': [127],
+                    'std': [127]
+                },
+                {
+                    'type': 'DefaultFormatBundle'
+                },
+                {
+                    'type': 'Collect',
+                    'keys': ['img'],
+                    'meta_keys': ['filename', 'resize_shape', 'valid_ratio']
+                },
+            ])
+        imgs = [np.random.random((32, 32, 3)).astype(np.uint8)]
+    data = dict(test=test)
+    model_cfg = mmcv.Config(dict(data=data))
+    inputs = create_input(
+        Codebase.MMOCR,
+        task,
+        model_cfg,
+        imgs,
+        input_shape=imgs[0].shape[0:2],
+        device='cpu')
+    assert inputs is not None, 'Failed to create input'
+
+
+@pytest.mark.parametrize('task', [Task.TEXT_DETECTION, Task.TEXT_RECOGNITION])
+def test_build_dataset(task):
+    import tempfile
+    import os
+    # mkstemp returns an open file descriptor plus the path
+    ann_fd, ann_path = tempfile.mkstemp()
+    if task == Task.TEXT_DETECTION:
+        data = dict(
+            test={
+                'type': 'IcdarDataset',
+                'ann_file':
+                'tests/test_mmocr/data/icdar2015/instances_test.json',
+                'img_prefix': 'tests/test_mmocr/data/icdar2015/imgs',
+                'pipeline': [
+                    {
+                        'type': 'LoadImageFromFile'
+                    },
+                ]
+            })
+    elif task == Task.TEXT_RECOGNITION:
+        data = dict(
+            test={
+                'type': 'OCRDataset',
+                'ann_file': ann_path,
+                'img_prefix': '',
+                'loader': {
+                    'type': 'HardDiskLoader',
+                    'repeat': 1,
+                    'parser': {
+                        'type': 'LineStrParser',
+                        'keys': ['filename', 'text'],
+                        'keys_idx': [0, 1],
+                        'separator': ' '
+                    }
+                },
+                'pipeline': [],
+                'test_mode': True
+            })
+    dataset_cfg = mmcv.Config(dict(data=data))
+    dataset = build_dataset(
+        Codebase.MMOCR, dataset_cfg=dataset_cfg, dataset_type='test')
+    assert dataset is not None, 'Failed to build dataset'
+    dataloader = build_dataloader(Codebase.MMOCR, dataset, 1, 1)
+    os.close(ann_fd)
+    assert dataloader is not None, 'Failed to build dataloader'
diff --git a/tests/test_mmocr/test_mmocr_models.py b/tests/test_mmocr/test_mmocr_models.py
new file mode 100644
index 000000000..cd8718d7c
--- /dev/null
+++ b/tests/test_mmocr/test_mmocr_models.py
@@ -0,0 +1,409 @@
+import mmcv
+import numpy as np
+import pytest
+import torch
+from mmocr.models.textdet.necks import FPNC
+
+from mmdeploy.utils.test import (WrapModel, get_model_outputs,
+                                 get_rewrite_outputs)
+
+
+class FPNCNeckModel(FPNC):
+
+    def __init__(self, in_channels, init_cfg=None):
+        super().__init__(in_channels, init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.neck = FPNC(in_channels, init_cfg=init_cfg)
+
+    def forward(self, inputs):
+        neck_inputs = [
+            torch.ones(1, channel, inputs.shape[-2], inputs.shape[-1])
+            for channel in self.in_channels
+        ]
+        output = self.neck.forward(neck_inputs)
+        return output
+
+
+def get_bidirectionallstm_model():
+    from mmocr.models.textrecog.layers.lstm_layer import BidirectionalLSTM
+    model = BidirectionalLSTM(32, 16, 16)
+
+    model.requires_grad_(False)
+    return model
+
+
+def get_single_stage_text_detector_model():
+    from mmocr.models.textdet import SingleStageTextDetector
+    backbone = dict(
+        type='mmdet.ResNet',
+        depth=18,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=-1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'),
+        norm_eval=False,
+        style='caffe')
+    neck = dict(
+        type='FPNC',
+        in_channels=[64, 128, 256, 512],
+        lateral_channels=4,
+        out_channels=4)
+    bbox_head = dict(
+        type='DBHead',
+        text_repr_type='quad',
+        in_channels=16,
+        loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True))
+    model = SingleStageTextDetector(backbone, neck, bbox_head)
+
+    model.requires_grad_(False)
+    return model
+
+
+def get_encode_decode_recognizer_model():
+    from mmocr.models.textrecog import EncodeDecodeRecognizer
+
+    cfg = dict(
+        preprocessor=None,
+        backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
+        encoder=dict(type='TFEncoder'),
+        decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+        loss=dict(type='CTCLoss'),
+        label_convertor=dict(
+            type='CTCConvertor',
+            dict_type='DICT36',
+            with_unknown=False,
+            lower=True),
+        pretrained=None)
+
+    model = EncodeDecodeRecognizer(
+        backbone=cfg['backbone'],
+        encoder=cfg['encoder'],
+        decoder=cfg['decoder'],
+        loss=cfg['loss'],
+        label_convertor=cfg['label_convertor'])
+    model.requires_grad_(False)
+    return model
+
+
+def get_crnn_decoder_model(rnn_flag):
+    from mmocr.models.textrecog.decoders import CRNNDecoder
+    model = CRNNDecoder(32, 4, rnn_flag=rnn_flag)
+
+    model.requires_grad_(False)
+    return model
+
+
+def get_fpnc_neck_model():
+    model = FPNCNeckModel([2, 4, 8, 16])
+
+    model.requires_grad_(False)
+    return model
+
+
+def get_base_recognizer_model():
+    from mmocr.models.textrecog import CRNNNet
+
+    cfg = dict(
+        preprocessor=None,
+        backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
+        encoder=None,
+        decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+        loss=dict(type='CTCLoss'),
+        label_convertor=dict(
+            type='CTCConvertor',
+            dict_type='DICT36',
+            with_unknown=False,
+            lower=True),
+        pretrained=None)
+
+    model = CRNNNet(
+        backbone=cfg['backbone'],
+        decoder=cfg['decoder'],
+        loss=cfg['loss'],
+        label_convertor=cfg['label_convertor'])
+    model.requires_grad_(False)
+    return model
+
+
+@pytest.mark.parametrize('backend_type', ['ncnn'])
+def test_bidirectionallstm(backend_type):
+    """Test forward rewrite of bidirectionallstm."""
+    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
+    bilstm = get_bidirectionallstm_model()
+    bilstm.cpu().eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type),
+            onnx_config=dict(input_shape=None),
+            codebase_config=dict(
+                type='mmocr',
+                task='TextRecognition',
+            )))
+
+    input = torch.rand(1, 1, 32)
+
+    # to get outputs of pytorch model
+    model_inputs = {
+        'input': input,
+    }
+    model_outputs = get_model_outputs(bilstm, 'forward', model_inputs)
+
+    # to get outputs of onnx model after rewrite
+    wrapped_model = WrapModel(bilstm, 'forward')
+    rewrite_inputs = {'input': input}
+    rewrite_outputs = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+def test_simple_test_of_single_stage_text_detector():
+    """Test simple_test single_stage_text_detector."""
+    single_stage_text_detector = get_single_stage_text_detector_model()
+    single_stage_text_detector.eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type='default'),
+            onnx_config=dict(input_shape=None),
+            codebase_config=dict(
+                type='mmocr',
+                task='TextDetection',
+            )))
+
+    input = torch.rand(1, 3, 64, 64)
+    img_metas = [{
+        'ori_shape': [64, 64, 3],
+        'img_shape': [64, 64, 3],
+        'pad_shape': [64, 64, 3],
+        'scale_factor': [1., 1., 1., 1.],
+    }]
+
+    x = single_stage_text_detector.extract_feat(input)
+    model_outputs = single_stage_text_detector.bbox_head(x)
+
+    wrapped_model = WrapModel(single_stage_text_detector, 'simple_test')
+    rewrite_inputs = {'img': input, 'img_metas': img_metas[0]}
+    rewrite_outputs = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+@pytest.mark.parametrize('backend_type', ['ncnn'])
+@pytest.mark.parametrize('rnn_flag', [True, False])
+def test_crnndecoder(backend_type, rnn_flag):
+    """Test forward rewrite of crnndecoder."""
+    pytest.importorskip(backend_type, reason=f'requires {backend_type}')
+    crnn_decoder = get_crnn_decoder_model(rnn_flag)
+    crnn_decoder.cpu().eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type),
+            onnx_config=dict(input_shape=None),
+            codebase_config=dict(
+                type='mmocr',
+                task='TextRecognition',
+            )))
+
+    input = torch.rand(1, 32, 1, 64)
+    out_enc = None
+    targets_dict = None
+    img_metas = None
+
+    # to get outputs of pytorch model
+    model_inputs = {
+        'feat': input,
+        'out_enc': out_enc,
+        'targets_dict': targets_dict,
+        'img_metas': img_metas
+    }
+    model_outputs = get_model_outputs(crnn_decoder, 'forward_train',
+                                      model_inputs)
+
+    # to get outputs of onnx model after rewrite
+    wrapped_model = WrapModel(
+        crnn_decoder,
+        'forward_train',
+        out_enc=out_enc,
+        targets_dict=targets_dict,
+        img_metas=img_metas)
+    rewrite_inputs = {'feat': input}
+    rewrite_outputs = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+@pytest.mark.parametrize(
+    'img_metas', [[None], [{
+        'resize_shape': [32, 32],
+        'valid_ratio': 1.0
+    }]])
+@pytest.mark.parametrize('is_dynamic', [True, False])
+def test_forward_of_base_recognizer(img_metas, is_dynamic):
+    """Test forward base_recognizer."""
+    base_recognizer = get_base_recognizer_model()
+    base_recognizer.eval()
+
+    if not is_dynamic:
+        deploy_cfg = mmcv.Config(
+            dict(
+                backend_config=dict(type='ncnn'),
+                onnx_config=dict(input_shape=None),
+                codebase_config=dict(
+                    type='mmocr',
+                    task='TextRecognition',
+                )))
+    else:
+        deploy_cfg = mmcv.Config(
+            dict(
+                backend_config=dict(type='ncnn'),
+                onnx_config=dict(
+                    input_shape=None,
+                    dynamic_axes={
+                        'input': {
+                            0: 'batch',
+                            2: 'height',
+                            3: 'width'
+                        },
+                        'output': {
+                            0: 'batch',
+                            2: 'height',
+                            3: 'width'
+                        }
+                    }),
+                codebase_config=dict(
+                    type='mmocr',
+                    task='TextRecognition',
+                )))
+
+    input = torch.rand(1, 1, 32, 32)
+
+    feat = base_recognizer.extract_feat(input)
+    out_enc = None
+    if base_recognizer.encoder is not None:
+        out_enc = base_recognizer.encoder(feat, img_metas)
+    model_outputs = base_recognizer.decoder(
+        feat, out_enc, None, img_metas, train_mode=False)
+    wrapped_model = WrapModel(
+        base_recognizer, 'forward', img_metas=img_metas[0])
+    rewrite_inputs = {
+        'img': input,
+    }
+    rewrite_outputs = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+def test_simple_test_of_encode_decode_recognizer():
+    """Test simple_test encode_decode_recognizer."""
+    encode_decode_recognizer = get_encode_decode_recognizer_model()
+    encode_decode_recognizer.eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type='default'),
+            onnx_config=dict(input_shape=None),
+            codebase_config=dict(
+                type='mmocr',
+                task='TextRecognition',
+            )))
+
+    input = torch.rand(1, 1, 32, 32)
+    img_metas = [{'resize_shape': [32, 32], 'valid_ratio': 1.0}]
+
+    feat = encode_decode_recognizer.extract_feat(input)
+    out_enc = None
+    if encode_decode_recognizer.encoder is not None:
+        out_enc = encode_decode_recognizer.encoder(feat, img_metas)
+    model_outputs = encode_decode_recognizer.decoder(
+        feat, out_enc, None, img_metas, train_mode=False)
+
+    wrapped_model = WrapModel(
+        encode_decode_recognizer, 'simple_test', img_metas=img_metas)
+    rewrite_inputs = {'img': input}
+    rewrite_outputs = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        model_output = model_output.squeeze().cpu().numpy()
+        rewrite_output = rewrite_output.squeeze()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+@pytest.mark.parametrize('backend_type', ['tensorrt'])
+def test_forward_of_fpnc(backend_type):
+    """Test forward rewrite of fpnc."""
+    fpnc = get_fpnc_neck_model()
+    fpnc.eval()
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(
+                type=backend_type,
+                common_config=dict(max_workspace_size=1 << 30),
+                model_inputs=[
+                    dict(
+                        input_shapes=dict(
+                            input=dict(
+                                min_shape=[1, 3, 64, 64],
+                                opt_shape=[1, 3, 64, 64],
+                                max_shape=[1, 3, 64, 64])))
+                ]),
+            onnx_config=dict(input_shape=[64, 64], output_names=['output']),
+            codebase_config=dict(type='mmocr', task='TextDetection')))
+
+    input = torch.rand(1, 3, 64, 64).cuda()
+    model_inputs = {
+        'inputs': input,
+    }
+    model_outputs = get_model_outputs(fpnc, 'forward', model_inputs)
+    wrapped_model = WrapModel(fpnc, 'forward')
+    rewrite_inputs = {
+        'inputs': input,
+    }
+    rewrite_outputs, is_need_name = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    if is_need_name:
+        model_output = model_outputs[0].squeeze().cpu().numpy()
+        rewrite_output = rewrite_outputs['output'].squeeze().cpu().numpy()
+        assert np.allclose(
+            model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+    else:
+        for model_output, rewrite_output in zip(model_outputs,
+                                                rewrite_outputs):
+            model_output = model_output.squeeze().cpu().numpy()
+            rewrite_output = rewrite_output.squeeze().cpu().numpy()
+            assert np.allclose(
+                model_output, rewrite_output, rtol=1e-03, atol=1e-05)
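
Reviewer note: every backend-specific test above guards itself with importlib.util.find_spec(...) or torch.cuda.is_available(), so on a machine without TensorRT/ONNX Runtime/pyppl/ncnn or CUDA the new suites degrade to skips rather than failures. A minimal local runner, as a sketch (the helper file name is hypothetical and not part of this diff; pytest.main takes the usual CLI arguments and returns an exit code):

    # run_mmocr_tests.py - illustrative helper, not included in the PR
    import sys

    import pytest

    # '-rs' prints skip reasons, so missing-backend skips stay visible
    sys.exit(pytest.main(['-v', '-rs', 'tests/test_mmocr']))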