"""Unit tests for the EncodeDecodeRecognizer and SegRecognizer recognizers."""
import os.path as osp
import tempfile
from functools import partial

import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks

from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
                                               SegRecognizer)


def _create_dummy_dict_file(dict_file):
    """Write a tiny character dictionary, one character per line.

    Args:
        dict_file (str): Destination path; the file is created or
            overwritten with the characters of ``'helowrd'``.
    """
    chars = list('helowrd')
    with open(dict_file, 'w', encoding='utf-8') as fw:
        for char in chars:
            fw.write(char + '\n')


def _assert_text_results(results):
    """Check recognizer output is a list whose first entry has text/score.

    Shared by ``simple_test`` and ``aug_test`` checks in both tests below.
    Only the first element is inspected.
    """
    assert isinstance(results, list)
    assert isinstance(results[0], dict)
    assert 'text' in results[0]
    assert 'score' in results[0]


def test_base_recognizer():
    """Smoke-test EncodeDecodeRecognizer with a CTC-based configuration.

    Covers constructor validation, ``extract_feat``, ``forward_train``,
    ``simple_test``, ONNX export and ``aug_test``.
    """
    # The context manager removes the temporary directory even if an
    # assertion below fails; a trailing ``cleanup()`` would be skipped then.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='CTCConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(type='VeryDeepVgg', leaky_relu=False)
        encoder = None
        decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
        loss = dict(type='CTCLoss')

        # Constructing without a required component must raise.
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(decoder=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(loss=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(label_convertor=None)

        recognizer = EncodeDecodeRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            label_convertor=label_convertor)
        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 32, 160)

        # test extract feat
        feat = recognizer.extract_feat(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 41])

        # test forward train
        img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
        losses = recognizer.forward_train(imgs, img_metas)
        assert isinstance(losses, dict)
        assert 'loss_ctc' in losses

        # test simple test
        results = recognizer.simple_test(imgs, img_metas)
        _assert_text_results(results)

        # test onnx export
        # torch.onnx.export invokes ``recognizer(imgs)``, so ``forward`` is
        # rebound to ``simple_test`` with the remaining arguments fixed.
        # NOTE: ``forward`` stays patched for the rest of this test.
        recognizer.forward = partial(
            recognizer.simple_test,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
        with tempfile.TemporaryDirectory() as tmpdirname:
            onnx_path = f'{tmpdirname}/tmp.onnx'
            torch.onnx.export(
                recognizer, (imgs, ),
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False)

        # test aug_test
        aug_results = recognizer.aug_test([imgs, imgs],
                                          [img_metas, img_metas])
        _assert_text_results(aug_results)


def test_seg_recognizer():
    """Smoke-test SegRecognizer with a ResNet31OCR + FPNOCR + SegHead model.

    Covers constructor validation, multi-level ``extract_feat`` shapes,
    ``forward_train`` with kernel targets, ``simple_test`` and ``aug_test``.
    """
    # The context manager removes the temporary directory even if an
    # assertion below fails; a trailing ``cleanup()`` would be skipped then.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='SegConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(
            type='ResNet31OCR',
            layers=[1, 2, 5, 3],
            channels=[32, 64, 128, 256, 512, 512],
            out_indices=[0, 1, 2, 3],
            stage4_pool_cfg=dict(kernel_size=2, stride=2),
            last_stage_pool=True)
        neck = dict(
            type='FPNOCR',
            in_channels=[128, 256, 512, 512],
            out_channels=256)
        head = dict(
            type='SegHead',
            in_channels=256,
            upsample_param=dict(scale_factor=2.0, mode='nearest'))
        loss = dict(type='SegLoss', seg_downsample_ratio=1.0)

        # Constructing without a required component must raise.
        with pytest.raises(AssertionError):
            SegRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            SegRecognizer(neck=None)
        with pytest.raises(AssertionError):
            SegRecognizer(head=None)
        with pytest.raises(AssertionError):
            SegRecognizer(loss=None)
        with pytest.raises(AssertionError):
            SegRecognizer(label_convertor=None)

        recognizer = SegRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            neck=neck,
            head=head,
            loss=loss,
            label_convertor=label_convertor)
        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 64, 256)

        # test extract feat: one feature map per backbone out_index
        feats = recognizer.extract_feat(imgs)
        assert len(feats) == 4

        assert feats[0].shape == torch.Size([1, 128, 32, 128])
        assert feats[1].shape == torch.Size([1, 256, 16, 64])
        assert feats[2].shape == torch.Size([1, 512, 8, 32])
        assert feats[3].shape == torch.Size([1, 512, 4, 16])

        # All-zero 64x256 targets matching the input image size. The
        # [attn, segm, mask] ordering presumably matches what SegLoss
        # expects -- TODO confirm against SegLoss.
        attn_tgt = np.zeros((64, 256), dtype=np.float32)
        segm_tgt = np.zeros((64, 256), dtype=np.float32)
        mask = np.zeros((64, 256), dtype=np.float32)
        gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], 64, 256)

        # test forward train (one BitmapMasks per image; batch size is 1)
        img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
        losses = recognizer.forward_train(
            imgs, img_metas, gt_kernels=[gt_kernels])
        assert isinstance(losses, dict)

        # test simple test
        results = recognizer.simple_test(imgs, img_metas)
        _assert_text_results(results)

        # test aug_test
        aug_results = recognizer.aug_test([imgs, imgs],
                                          [img_metas, img_metas])
        _assert_text_results(aug_results)