# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from functools import partial

import numpy as np
import pytest
import torch
from mmdet.core import BitmapMasks

from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
                                               SegRecognizer)
def _create_dummy_dict_file(dict_file):
chars = list('helowrd')
with open(dict_file, 'w') as fw:
for char in chars:
fw.write(char + '\n')
def test_base_recognizer():
    """Smoke-test EncodeDecodeRecognizer (CRNN-style CTC model).

    Covers constructor argument validation, feature extraction,
    forward_train, simple_test, ONNX export and aug_test.
    """
    # Context manager guarantees cleanup even if an assertion below fails
    # (the original explicit .cleanup() call was skipped on failure).
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='CTCConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(type='VeryDeepVgg', leaky_relu=False)
        encoder = None
        decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
        loss = dict(type='CTCLoss')

        # every mandatory component must be supplied
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(decoder=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(loss=None)
        with pytest.raises(AssertionError):
            EncodeDecodeRecognizer(label_convertor=None)

        recognizer = EncodeDecodeRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            label_convertor=label_convertor)
        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 32, 160)

        # test extract feat
        feat = recognizer.extract_feat(imgs)
        assert feat.shape == torch.Size([1, 512, 1, 41])

        # test forward train
        img_metas = [{
            'text': 'hello',
            'resize_shape': (32, 120, 3),
            'valid_ratio': 1.0
        }]
        losses = recognizer.forward_train(imgs, img_metas)
        assert isinstance(losses, dict)
        assert 'loss_ctc' in losses

        # test simple test
        results = recognizer.simple_test(imgs, img_metas)
        assert isinstance(results, list)
        assert isinstance(results[0], dict)
        assert 'text' in results[0]
        assert 'score' in results[0]

        # test onnx export: bind simple_test as the forward entry point so
        # torch.onnx.export can trace the inference path.
        recognizer.forward = partial(
            recognizer.simple_test,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
        with tempfile.TemporaryDirectory() as tmpdirname:
            onnx_path = f'{tmpdirname}/tmp.onnx'
            torch.onnx.export(
                recognizer, (imgs, ),
                onnx_path,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False)

        # test aug_test
        aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
        assert isinstance(aug_results, list)
        assert isinstance(aug_results[0], dict)
        assert 'text' in aug_results[0]
        assert 'score' in aug_results[0]
def test_seg_recognizer():
    """Smoke-test SegRecognizer (segmentation-based recognition).

    Covers constructor argument validation, multi-scale feature
    extraction, forward_train with gt kernels, simple_test and aug_test.
    """
    # Context manager guarantees cleanup even if an assertion below fails
    # (the original explicit .cleanup() call was skipped on failure).
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        label_convertor = dict(
            type='SegConvertor', dict_file=dict_file, with_unknown=False)

        preprocessor = None
        backbone = dict(
            type='ResNet31OCR',
            layers=[1, 2, 5, 3],
            channels=[32, 64, 128, 256, 512, 512],
            out_indices=[0, 1, 2, 3],
            stage4_pool_cfg=dict(kernel_size=2, stride=2),
            last_stage_pool=True)
        neck = dict(
            type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256)
        head = dict(
            type='SegHead',
            in_channels=256,
            upsample_param=dict(scale_factor=2.0, mode='nearest'))
        loss = dict(type='SegLoss', seg_downsample_ratio=1.0)

        # every mandatory component must be supplied
        with pytest.raises(AssertionError):
            SegRecognizer(backbone=None)
        with pytest.raises(AssertionError):
            SegRecognizer(neck=None)
        with pytest.raises(AssertionError):
            SegRecognizer(head=None)
        with pytest.raises(AssertionError):
            SegRecognizer(loss=None)
        with pytest.raises(AssertionError):
            SegRecognizer(label_convertor=None)

        recognizer = SegRecognizer(
            preprocessor=preprocessor,
            backbone=backbone,
            neck=neck,
            head=head,
            loss=loss,
            label_convertor=label_convertor)
        recognizer.init_weights()
        recognizer.train()

        imgs = torch.rand(1, 3, 64, 256)

        # test extract feat: four pyramid levels from out_indices=[0,1,2,3]
        feats = recognizer.extract_feat(imgs)
        assert len(feats) == 4
        assert feats[0].shape == torch.Size([1, 128, 32, 128])
        assert feats[1].shape == torch.Size([1, 256, 16, 64])
        assert feats[2].shape == torch.Size([1, 512, 8, 32])
        assert feats[3].shape == torch.Size([1, 512, 4, 16])

        # dummy ground-truth kernels: attention target, segmentation
        # target and valid-pixel mask, all at full image resolution.
        attn_tgt = np.zeros((64, 256), dtype=np.float32)
        segm_tgt = np.zeros((64, 256), dtype=np.float32)
        mask = np.zeros((64, 256), dtype=np.float32)
        gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], 64, 256)

        # test forward train
        img_metas = [{
            'text': 'hello',
            'resize_shape': (64, 256, 3),
            'valid_ratio': 1.0
        }]
        losses = recognizer.forward_train(
            imgs, img_metas, gt_kernels=[gt_kernels])
        assert isinstance(losses, dict)

        # test simple test
        results = recognizer.simple_test(imgs, img_metas)
        assert isinstance(results, list)
        assert isinstance(results[0], dict)
        assert 'text' in results[0]
        assert 'score' in results[0]

        # test aug_test
        aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
        assert isinstance(aug_results, list)
        assert isinstance(aug_results[0], dict)
        assert 'text' in aug_results[0]
        assert 'score' in aug_results[0]