mirror of https://github.com/open-mmlab/mmocr.git
421 lines
16 KiB
Python
421 lines
16 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import io
|
|
import json
|
|
import os
|
|
import platform
|
|
import random
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmocr.apis import init_detector
|
|
from mmocr.datasets.kie_dataset import KIEDataset
|
|
from mmocr.utils.ocr import MMOCR
|
|
|
|
|
|
def test_ocr_init_errors():
|
|
# Test assertions
|
|
with pytest.raises(ValueError):
|
|
_ = MMOCR(det='test')
|
|
with pytest.raises(ValueError):
|
|
_ = MMOCR(recog='test')
|
|
with pytest.raises(ValueError):
|
|
_ = MMOCR(kie='test')
|
|
with pytest.raises(NotImplementedError):
|
|
_ = MMOCR(det=None, recog=None, kie='SDMGR')
|
|
with pytest.raises(NotImplementedError):
|
|
_ = MMOCR(det='DB_r18', recog=None, kie='SDMGR')
|
|
|
|
|
|
cfg_default_prefix = os.path.join(str(Path.cwd()), 'configs/')
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
'det, recog, kie, config_dir, gt_cfg, gt_ckpt',
|
|
[('DB_r18', None, '', '',
|
|
cfg_default_prefix + 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
|
'https://download.openmmlab.com/mmocr/textdet/'
|
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'),
|
|
(None, 'CRNN', '', '',
|
|
cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
|
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
|
'crnn/crnn_academic-a723a1c5.pth'),
|
|
('DB_r18', 'CRNN', 'SDMGR', '', [
|
|
cfg_default_prefix +
|
|
'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
|
cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
|
|
cfg_default_prefix + 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
|
], [
|
|
'https://download.openmmlab.com/mmocr/textdet/'
|
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
|
'crnn/crnn_academic-a723a1c5.pth',
|
|
'https://download.openmmlab.com/mmocr/kie/'
|
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
|
]),
|
|
('DB_r18', 'CRNN', 'SDMGR', 'test/', [
|
|
'test/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
|
'test/textrecog/crnn/crnn_academic_dataset.py',
|
|
'test/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
|
], [
|
|
'https://download.openmmlab.com/mmocr/textdet/'
|
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
|
'crnn/crnn_academic-a723a1c5.pth',
|
|
'https://download.openmmlab.com/mmocr/kie/'
|
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
|
])],
|
|
)
|
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
|
def test_ocr_init(mock_loading, mock_config, mock_build_detector,
|
|
mock_init_detector, det, recog, kie, config_dir, gt_cfg,
|
|
gt_ckpt):
|
|
|
|
def loadcheckpoint_assert(*args, **kwargs):
|
|
assert args[1] == gt_ckpt[-1]
|
|
assert kwargs['map_location'] == torch.device(
|
|
'cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
mock_loading.side_effect = loadcheckpoint_assert
|
|
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
|
|
if kie == '':
|
|
if config_dir == '':
|
|
_ = MMOCR(det=det, recog=recog)
|
|
else:
|
|
_ = MMOCR(det=det, recog=recog, config_dir=config_dir)
|
|
else:
|
|
if config_dir == '':
|
|
_ = MMOCR(det=det, recog=recog, kie=kie)
|
|
else:
|
|
_ = MMOCR(det=det, recog=recog, kie=kie, config_dir=config_dir)
|
|
if isinstance(gt_cfg, str):
|
|
gt_cfg = [gt_cfg]
|
|
if isinstance(gt_ckpt, str):
|
|
gt_ckpt = [gt_ckpt]
|
|
|
|
i_range = range(len(gt_cfg))
|
|
if kie:
|
|
i_range = i_range[:-1]
|
|
mock_config.assert_called_with(gt_cfg[-1])
|
|
mock_build_detector.assert_called_once()
|
|
mock_loading.assert_called_once()
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
calls = [
|
|
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
|
|
]
|
|
mock_init_detector.assert_has_calls(calls)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
'det, det_config, det_ckpt, recog, recog_config, recog_ckpt,'
|
|
'kie, kie_config, kie_ckpt, config_dir, gt_cfg, gt_ckpt',
|
|
[('DB_r18', 'test.py', '', 'CRNN', 'test.py', '', 'SDMGR', 'test.py', '',
|
|
'configs/', ['test.py', 'test.py', 'test.py'], [
|
|
'https://download.openmmlab.com/mmocr/textdet/'
|
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
|
'crnn/crnn_academic-a723a1c5.pth',
|
|
'https://download.openmmlab.com/mmocr/kie/'
|
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
|
]),
|
|
('DB_r18', '', 'test.ckpt', 'CRNN', '', 'test.ckpt', 'SDMGR', '',
|
|
'test.ckpt', '', [
|
|
'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
|
'textrecog/crnn/crnn_academic_dataset.py',
|
|
'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
|
], ['test.ckpt', 'test.ckpt', 'test.ckpt']),
|
|
('DB_r18', 'test.py', 'test.ckpt', 'CRNN', 'test.py', 'test.ckpt',
|
|
'SDMGR', 'test.py', 'test.ckpt', '', ['test.py', 'test.py', 'test.py'],
|
|
['test.ckpt', 'test.ckpt', 'test.ckpt'])])
|
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
|
def test_ocr_init_customize_config(mock_loading, mock_config,
|
|
mock_build_detector, mock_init_detector,
|
|
det, det_config, det_ckpt, recog,
|
|
recog_config, recog_ckpt, kie, kie_config,
|
|
kie_ckpt, config_dir, gt_cfg, gt_ckpt):
|
|
|
|
def loadcheckpoint_assert(*args, **kwargs):
|
|
assert args[1] == gt_ckpt[-1]
|
|
|
|
mock_loading.side_effect = loadcheckpoint_assert
|
|
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
|
|
_ = MMOCR(
|
|
det=det,
|
|
det_config=det_config,
|
|
det_ckpt=det_ckpt,
|
|
recog=recog,
|
|
recog_config=recog_config,
|
|
recog_ckpt=recog_ckpt,
|
|
kie=kie,
|
|
kie_config=kie_config,
|
|
kie_ckpt=kie_ckpt,
|
|
config_dir=config_dir)
|
|
|
|
i_range = range(len(gt_cfg))
|
|
if kie:
|
|
i_range = i_range[:-1]
|
|
mock_config.assert_called_with(gt_cfg[-1])
|
|
mock_build_detector.assert_called_once()
|
|
mock_loading.assert_called_once()
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
calls = [
|
|
mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range
|
|
]
|
|
mock_init_detector.assert_has_calls(calls)
|
|
|
|
|
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
|
@mock.patch('mmocr.utils.ocr.model_inference')
|
|
def test_single_inference(mock_model_inference, mock_loading, mock_config,
|
|
mock_build_detector, mock_init_detector):
|
|
|
|
def dummy_inference(model, arr, batch_mode):
|
|
return arr
|
|
|
|
mock_model_inference.side_effect = dummy_inference
|
|
mmocr = MMOCR()
|
|
|
|
data = list(range(20))
|
|
model = 'dummy'
|
|
res = mmocr.single_inference(model, data, batch_mode=False)
|
|
assert (data == res)
|
|
mock_model_inference.reset_mock()
|
|
|
|
res = mmocr.single_inference(model, data, batch_mode=True)
|
|
assert (data == res)
|
|
mock_model_inference.assert_called_once()
|
|
mock_model_inference.reset_mock()
|
|
|
|
res = mmocr.single_inference(model, data, batch_mode=True, batch_size=100)
|
|
assert (data == res)
|
|
mock_model_inference.assert_called_once()
|
|
mock_model_inference.reset_mock()
|
|
|
|
res = mmocr.single_inference(model, data, batch_mode=True, batch_size=3)
|
|
assert (data == res)
|
|
|
|
|
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
|
def MMOCR_testobj(mock_loading, mock_init_detector, **kwargs):
|
|
# returns an MMOCR object bypassing the
|
|
# checkpoint initialization step
|
|
def init_detector_skip_ckpt(config, ckpt, device):
|
|
return init_detector(config, device=device)
|
|
|
|
def modify_kie_class(model, ckpt, map_location):
|
|
model.class_list = 'tests/data/kie_toy_dataset/class_list.txt'
|
|
|
|
mock_init_detector.side_effect = init_detector_skip_ckpt
|
|
mock_loading.side_effect = modify_kie_class
|
|
kwargs['det'] = kwargs.get('det', 'DB_r18')
|
|
kwargs['recog'] = kwargs.get('recog', 'CRNN')
|
|
kwargs['kie'] = kwargs.get('kie', 'SDMGR')
|
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
|
return MMOCR(**kwargs, device=device)
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
platform.system() == 'Windows',
|
|
reason='Win container on Github Action does not have enough RAM to run')
|
|
@mock.patch('mmocr.utils.ocr.KIEDataset')
|
|
def test_readtext(mock_kiedataset):
|
|
# Fixing the weights of models to prevent them from
|
|
# generating invalid results and triggering other assertion errors
|
|
torch.manual_seed(4)
|
|
random.seed(4)
|
|
mmocr = MMOCR_testobj()
|
|
mmocr_det = MMOCR_testobj(kie='', recog='')
|
|
mmocr_recog = MMOCR_testobj(kie='', det='', recog='CRNN_TPS')
|
|
mmocr_det_recog = MMOCR_testobj(kie='')
|
|
|
|
def readtext(imgs, ocr_obj=mmocr, **kwargs):
|
|
# filename can be different depends on how
|
|
# the the image was loaded
|
|
e2e_res = ocr_obj.readtext(imgs, **kwargs)
|
|
for res in e2e_res:
|
|
res.pop('filename')
|
|
return e2e_res
|
|
|
|
def kiedataset_with_test_dict(**kwargs):
|
|
kwargs['dict_file'] = 'tests/data/kie_toy_dataset/dict.txt'
|
|
return KIEDataset(**kwargs)
|
|
|
|
mock_kiedataset.side_effect = kiedataset_with_test_dict
|
|
|
|
# Single image
|
|
toy_dir = 'tests/data/toy_dataset/imgs/test/'
|
|
toy_img1_path = toy_dir + 'img_1.jpg'
|
|
str_e2e_res = readtext(toy_img1_path)
|
|
toy_img1 = mmcv.imread(toy_img1_path)
|
|
np_e2e_res = readtext(toy_img1)
|
|
assert str_e2e_res == np_e2e_res
|
|
|
|
# Multiple images
|
|
toy_img2_path = toy_dir + 'img_2.jpg'
|
|
toy_img2 = mmcv.imread(toy_img2_path)
|
|
toy_imgs = [toy_img1, toy_img2]
|
|
toy_img_paths = [toy_img1_path, toy_img2_path]
|
|
np_e2e_results = readtext(toy_imgs)
|
|
str_e2e_results = readtext(toy_img_paths)
|
|
str_tuple_e2e_results = readtext(tuple(toy_img_paths))
|
|
assert np_e2e_results == str_e2e_results
|
|
assert str_e2e_results == str_tuple_e2e_results
|
|
|
|
# Batch mode test
|
|
toy_imgs.append(toy_dir + 'img_3.jpg')
|
|
e2e_res = readtext(toy_imgs)
|
|
full_batch_e2e_res = readtext(toy_imgs, batch_mode=True)
|
|
assert full_batch_e2e_res == e2e_res
|
|
batch_e2e_res = readtext(
|
|
toy_imgs, batch_mode=True, recog_batch_size=2, det_batch_size=2)
|
|
assert batch_e2e_res == full_batch_e2e_res
|
|
|
|
# Batch mode test with DBNet only
|
|
full_batch_det_res = mmocr_det.readtext(toy_imgs, batch_mode=True)
|
|
det_res = mmocr_det.readtext(toy_imgs)
|
|
batch_det_res = mmocr_det.readtext(
|
|
toy_imgs, batch_mode=True, single_batch_size=2)
|
|
assert len(full_batch_det_res) == len(det_res)
|
|
assert len(batch_det_res) == len(det_res)
|
|
assert all([
|
|
np.allclose(full_batch_det_res[i]['boundary_result'],
|
|
det_res[i]['boundary_result'])
|
|
for i in range(len(full_batch_det_res))
|
|
])
|
|
assert all([
|
|
np.allclose(batch_det_res[i]['boundary_result'],
|
|
det_res[i]['boundary_result'])
|
|
for i in range(len(batch_det_res))
|
|
])
|
|
|
|
# Batch mode test with CRNN_TPS only (CRNN doesn't support batch inference)
|
|
full_batch_recog_res = mmocr_recog.readtext(toy_imgs, batch_mode=True)
|
|
recog_res = mmocr_recog.readtext(toy_imgs)
|
|
batch_recog_res = mmocr_recog.readtext(
|
|
toy_imgs, batch_mode=True, single_batch_size=2)
|
|
full_batch_recog_res.sort(key=lambda x: x['text'])
|
|
batch_recog_res.sort(key=lambda x: x['text'])
|
|
recog_res.sort(key=lambda x: x['text'])
|
|
assert np.all([
|
|
np.allclose(full_batch_recog_res[i]['score'], recog_res[i]['score'])
|
|
for i in range(len(full_batch_recog_res))
|
|
])
|
|
assert np.all([
|
|
np.allclose(batch_recog_res[i]['score'], recog_res[i]['score'])
|
|
for i in range(len(full_batch_recog_res))
|
|
])
|
|
|
|
# Test export
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
mmocr.readtext(toy_imgs, export=tmpdirname)
|
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
mmocr_det.readtext(toy_imgs, export=tmpdirname)
|
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
mmocr_recog.readtext(toy_imgs, export=tmpdirname)
|
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
|
|
|
# Test output
|
|
# Single image
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
tmp_output = os.path.join(tmpdirname, '1.jpg')
|
|
mmocr.readtext(toy_imgs[0], output=tmp_output)
|
|
assert os.path.exists(tmp_output)
|
|
# Multiple images
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
mmocr.readtext(toy_imgs, output=tmpdirname)
|
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
|
|
|
# Test imshow
|
|
with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
|
|
mmocr.readtext(toy_img1_path, imshow=True)
|
|
mock_imshow.assert_called_once()
|
|
mock_imshow.reset_mock()
|
|
mmocr.readtext(toy_imgs, imshow=True)
|
|
assert mock_imshow.call_count == len(toy_imgs)
|
|
|
|
# Test print_result
|
|
with io.StringIO() as capturedOutput:
|
|
sys.stdout = capturedOutput
|
|
res = mmocr.readtext(toy_imgs, print_result=True)
|
|
assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace(
|
|
'\n\n', ',').replace("'", '"')) == res
|
|
sys.stdout = sys.__stdout__
|
|
with io.StringIO() as capturedOutput:
|
|
sys.stdout = capturedOutput
|
|
res = mmocr.readtext(toy_imgs, details=True, print_result=True)
|
|
assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace(
|
|
'\n\n', ',').replace("'", '"')) == res
|
|
sys.stdout = sys.__stdout__
|
|
|
|
# Test merge
|
|
with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge:
|
|
mmocr_det_recog.readtext(toy_imgs, merge=True)
|
|
assert mock_merge.call_count == len(toy_imgs)
|
|
|
|
|
|
@pytest.mark.parametrize('det, recog, target',
|
|
[('Tesseract', None, {
|
|
'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 1.0]]
|
|
}),
|
|
('Tesseract', 'Tesseract', {
|
|
'result': [{
|
|
'box': [0, 0, 1, 0, 1, 1, 0, 1],
|
|
'box_score': 1.0,
|
|
'text': 'text',
|
|
'text_score': 0.5
|
|
}]
|
|
}),
|
|
(None, 'Tesseract', {
|
|
'text': 'text',
|
|
'score': 0.5
|
|
})])
|
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
|
@mock.patch('mmocr.utils.ocr.tesserocr')
|
|
def test_tesseract_wrapper(mock_tesserocr, mock_init_detector, det, recog,
|
|
target):
|
|
|
|
def init_detector_skip_ckpt(config, ckpt, device):
|
|
return init_detector(config, device=device)
|
|
|
|
mock_init_detector.side_effect = init_detector_skip_ckpt
|
|
mock_tesseract = mock.Mock()
|
|
mock_tesseract.GetComponentImages.return_value = [(None, {
|
|
'x': 0,
|
|
'y': 0,
|
|
'w': 1,
|
|
'h': 1
|
|
}, 0, None)]
|
|
mock_tesseract.GetUTF8Text.return_value = 'text'
|
|
mock_tesseract.MeanTextConf.return_value = 50
|
|
mock_tesserocr.PyTessBaseAPI.return_value = mock_tesseract
|
|
|
|
mmocr = MMOCR(det=det, recog=recog, device='cpu')
|
|
|
|
img_path = 'demo/demo_kie.jpeg'
|
|
|
|
# Test imshow
|
|
with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
|
|
result = mmocr.readtext(img_path, imshow=True, details=True)
|
|
for k, v in target.items():
|
|
assert result[0][k] == v
|
|
mock_imshow.assert_called_once()
|
|
mock_imshow.reset_mock()
|