mirror of https://github.com/open-mmlab/mmocr.git
parent
906faec372
commit
80a0536c7c
|
@ -308,10 +308,15 @@ class MMOCR:
|
||||||
elif self.tr and self.tr not in textrecog_models:
|
elif self.tr and self.tr not in textrecog_models:
|
||||||
raise ValueError(self.tr,
|
raise ValueError(self.tr,
|
||||||
'is not a supported text recognition algorithm')
|
'is not a supported text recognition algorithm')
|
||||||
elif self.kie and self.kie not in kie_models:
|
elif self.kie:
|
||||||
|
if self.kie not in kie_models:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
self.kie, 'is not a supported key information extraction'
|
self.kie, 'is not a supported key information extraction'
|
||||||
' algorithm')
|
' algorithm')
|
||||||
|
elif not (self.td and self.tr):
|
||||||
|
raise NotImplementedError(
|
||||||
|
self.kie, 'has to run together'
|
||||||
|
' with text detection and recognition algorithms.')
|
||||||
|
|
||||||
self.detect_model = None
|
self.detect_model = None
|
||||||
if self.td:
|
if self.td:
|
||||||
|
@ -590,7 +595,7 @@ class MMOCR:
|
||||||
return end2end_res
|
return end2end_res
|
||||||
|
|
||||||
# Separate det/recog inference pipeline
|
# Separate det/recog inference pipeline
|
||||||
def single_inference(self, model, arrays, batch_mode, batch_size):
|
def single_inference(self, model, arrays, batch_mode, batch_size=0):
|
||||||
result = []
|
result = []
|
||||||
if batch_mode:
|
if batch_mode:
|
||||||
if batch_size == 0:
|
if batch_size == 0:
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
0 Ignore
|
||||||
|
1 Store_name_value
|
||||||
|
2 Store_name_key
|
||||||
|
3 Store_addr_value
|
||||||
|
4 Store_addr_key
|
||||||
|
5 Tel_value
|
||||||
|
6 Tel_key
|
||||||
|
7 Date_value
|
||||||
|
8 Date_key
|
||||||
|
9 Time_value
|
||||||
|
10 Time_key
|
||||||
|
11 Prod_item_value
|
||||||
|
12 Prod_item_key
|
||||||
|
13 Prod_quantity_value
|
||||||
|
14 Prod_quantity_key
|
||||||
|
15 Prod_price_value
|
||||||
|
16 Prod_price_key
|
||||||
|
17 Subtotal_value
|
||||||
|
18 Subtotal_key
|
||||||
|
19 Tax_value
|
||||||
|
20 Tax_key
|
||||||
|
21 Tips_value
|
||||||
|
22 Tips_key
|
||||||
|
23 Total_value
|
||||||
|
24 Total_key
|
||||||
|
25 Others
|
|
@ -0,0 +1,91 @@
|
||||||
|
/
|
||||||
|
\
|
||||||
|
.
|
||||||
|
$
|
||||||
|
£
|
||||||
|
€
|
||||||
|
¥
|
||||||
|
:
|
||||||
|
-
|
||||||
|
,
|
||||||
|
*
|
||||||
|
#
|
||||||
|
(
|
||||||
|
)
|
||||||
|
%
|
||||||
|
@
|
||||||
|
!
|
||||||
|
'
|
||||||
|
&
|
||||||
|
=
|
||||||
|
>
|
||||||
|
+
|
||||||
|
"
|
||||||
|
×
|
||||||
|
?
|
||||||
|
<
|
||||||
|
[
|
||||||
|
]
|
||||||
|
_
|
||||||
|
0
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
h
|
||||||
|
i
|
||||||
|
j
|
||||||
|
k
|
||||||
|
l
|
||||||
|
m
|
||||||
|
n
|
||||||
|
o
|
||||||
|
p
|
||||||
|
q
|
||||||
|
r
|
||||||
|
s
|
||||||
|
t
|
||||||
|
u
|
||||||
|
v
|
||||||
|
w
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
||||||
|
A
|
||||||
|
B
|
||||||
|
C
|
||||||
|
D
|
||||||
|
E
|
||||||
|
F
|
||||||
|
G
|
||||||
|
H
|
||||||
|
I
|
||||||
|
J
|
||||||
|
K
|
||||||
|
L
|
||||||
|
M
|
||||||
|
N
|
||||||
|
O
|
||||||
|
P
|
||||||
|
Q
|
||||||
|
R
|
||||||
|
S
|
||||||
|
T
|
||||||
|
U
|
||||||
|
V
|
||||||
|
W
|
||||||
|
X
|
||||||
|
Y
|
||||||
|
Z
|
|
@ -0,0 +1,353 @@
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from mmdet.apis import init_detector
|
||||||
|
|
||||||
|
from mmocr.datasets.kie_dataset import KIEDataset
|
||||||
|
from mmocr.utils.ocr import MMOCR
|
||||||
|
|
||||||
|
|
||||||
|
def test_ocr_init_errors():
|
||||||
|
# Test assertions
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = MMOCR(det='test')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = MMOCR(recog='test')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = MMOCR(kie='test')
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
_ = MMOCR(det=None, recog=None, kie='SDMGR')
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
_ = MMOCR(det='DB_r18', recog=None, kie='SDMGR')
|
||||||
|
|
||||||
|
|
||||||
|
cfg_default_prefix = os.path.join(str(Path.cwd()), 'configs/')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'det, recog, kie, config_dir, gt_cfg, gt_ckpt',
|
||||||
|
[('DB_r18', None, '', '',
|
||||||
|
cfg_default_prefix + 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
||||||
|
'https://download.openmmlab.com/mmocr/textdet/'
|
||||||
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'),
|
||||||
|
(None, 'CRNN', '', '',
|
||||||
|
cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
|
||||||
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
||||||
|
'crnn/crnn_academic-a723a1c5.pth'),
|
||||||
|
('DB_r18', 'CRNN', 'SDMGR', '', [
|
||||||
|
cfg_default_prefix +
|
||||||
|
'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
||||||
|
cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py',
|
||||||
|
cfg_default_prefix + 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
||||||
|
], [
|
||||||
|
'https://download.openmmlab.com/mmocr/textdet/'
|
||||||
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
||||||
|
'crnn/crnn_academic-a723a1c5.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/kie/'
|
||||||
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
||||||
|
]),
|
||||||
|
('DB_r18', 'CRNN', 'SDMGR', 'test/', [
|
||||||
|
'test/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
||||||
|
'test/textrecog/crnn/crnn_academic_dataset.py',
|
||||||
|
'test/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
||||||
|
], [
|
||||||
|
'https://download.openmmlab.com/mmocr/textdet/'
|
||||||
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
||||||
|
'crnn/crnn_academic-a723a1c5.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/kie/'
|
||||||
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
||||||
|
])],
|
||||||
|
)
|
||||||
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
||||||
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
||||||
|
def test_ocr_init(mock_loading, mock_config, mock_build_detector,
|
||||||
|
mock_init_detector, det, recog, kie, config_dir, gt_cfg,
|
||||||
|
gt_ckpt):
|
||||||
|
|
||||||
|
def loadcheckpoint_assert(*args, **kwargs):
|
||||||
|
assert args[1] == gt_ckpt[-1]
|
||||||
|
|
||||||
|
mock_loading.side_effect = loadcheckpoint_assert
|
||||||
|
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
|
||||||
|
if kie == '':
|
||||||
|
if config_dir == '':
|
||||||
|
_ = MMOCR(det=det, recog=recog)
|
||||||
|
else:
|
||||||
|
_ = MMOCR(det=det, recog=recog, config_dir=config_dir)
|
||||||
|
else:
|
||||||
|
if config_dir == '':
|
||||||
|
_ = MMOCR(det=det, recog=recog, kie=kie)
|
||||||
|
else:
|
||||||
|
_ = MMOCR(det=det, recog=recog, kie=kie, config_dir=config_dir)
|
||||||
|
if isinstance(gt_cfg, str):
|
||||||
|
gt_cfg = [gt_cfg]
|
||||||
|
if isinstance(gt_ckpt, str):
|
||||||
|
gt_ckpt = [gt_ckpt]
|
||||||
|
|
||||||
|
i_range = range(len(gt_cfg))
|
||||||
|
if kie:
|
||||||
|
i_range = i_range[:-1]
|
||||||
|
mock_config.assert_called_with(gt_cfg[-1])
|
||||||
|
mock_build_detector.assert_called_once()
|
||||||
|
mock_loading.assert_called_once()
|
||||||
|
calls = [
|
||||||
|
mock.call(gt_cfg[i], gt_ckpt[i], device='cuda:0') for i in i_range
|
||||||
|
]
|
||||||
|
mock_init_detector.assert_has_calls(calls)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'det, det_config, det_ckpt, recog, recog_config, recog_ckpt,'
|
||||||
|
'kie, kie_config, kie_ckpt, config_dir, gt_cfg, gt_ckpt',
|
||||||
|
[('DB_r18', 'test.py', '', 'CRNN', 'test.py', '', 'SDMGR', 'test.py', '',
|
||||||
|
'configs/', ['test.py', 'test.py', 'test.py'], [
|
||||||
|
'https://download.openmmlab.com/mmocr/textdet/'
|
||||||
|
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/textrecog/'
|
||||||
|
'crnn/crnn_academic-a723a1c5.pth',
|
||||||
|
'https://download.openmmlab.com/mmocr/kie/'
|
||||||
|
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
|
||||||
|
]),
|
||||||
|
('DB_r18', '', 'test.ckpt', 'CRNN', '', 'test.ckpt', 'SDMGR', '',
|
||||||
|
'test.ckpt', '', [
|
||||||
|
'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
|
||||||
|
'textrecog/crnn/crnn_academic_dataset.py',
|
||||||
|
'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py'
|
||||||
|
], ['test.ckpt', 'test.ckpt', 'test.ckpt']),
|
||||||
|
('DB_r18', 'test.py', 'test.ckpt', 'CRNN', 'test.py', 'test.ckpt',
|
||||||
|
'SDMGR', 'test.py', 'test.ckpt', '', ['test.py', 'test.py', 'test.py'],
|
||||||
|
['test.ckpt', 'test.ckpt', 'test.ckpt'])])
|
||||||
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
||||||
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
||||||
|
def test_ocr_init_customize_config(mock_loading, mock_config,
|
||||||
|
mock_build_detector, mock_init_detector,
|
||||||
|
det, det_config, det_ckpt, recog,
|
||||||
|
recog_config, recog_ckpt, kie, kie_config,
|
||||||
|
kie_ckpt, config_dir, gt_cfg, gt_ckpt):
|
||||||
|
|
||||||
|
def loadcheckpoint_assert(*args, **kwargs):
|
||||||
|
assert args[1] == gt_ckpt[-1]
|
||||||
|
|
||||||
|
mock_loading.side_effect = loadcheckpoint_assert
|
||||||
|
with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
|
||||||
|
_ = MMOCR(
|
||||||
|
det=det,
|
||||||
|
det_config=det_config,
|
||||||
|
det_ckpt=det_ckpt,
|
||||||
|
recog=recog,
|
||||||
|
recog_config=recog_config,
|
||||||
|
recog_ckpt=recog_ckpt,
|
||||||
|
kie=kie,
|
||||||
|
kie_config=kie_config,
|
||||||
|
kie_ckpt=kie_ckpt,
|
||||||
|
config_dir=config_dir)
|
||||||
|
|
||||||
|
i_range = range(len(gt_cfg))
|
||||||
|
if kie:
|
||||||
|
i_range = i_range[:-1]
|
||||||
|
mock_config.assert_called_with(gt_cfg[-1])
|
||||||
|
mock_build_detector.assert_called_once()
|
||||||
|
mock_loading.assert_called_once()
|
||||||
|
calls = [
|
||||||
|
mock.call(gt_cfg[i], gt_ckpt[i], device='cuda:0') for i in i_range
|
||||||
|
]
|
||||||
|
mock_init_detector.assert_has_calls(calls)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.build_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.Config.fromfile')
|
||||||
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
||||||
|
@mock.patch('mmocr.utils.ocr.model_inference')
|
||||||
|
def test_single_inference(mock_model_inference, mock_loading, mock_config,
|
||||||
|
mock_build_detector, mock_init_detector):
|
||||||
|
|
||||||
|
def dummy_inference(model, arr, batch_mode):
|
||||||
|
return arr
|
||||||
|
|
||||||
|
mock_model_inference.side_effect = dummy_inference
|
||||||
|
mmocr = MMOCR()
|
||||||
|
|
||||||
|
data = list(range(20))
|
||||||
|
model = 'dummy'
|
||||||
|
res = mmocr.single_inference(model, data, batch_mode=False)
|
||||||
|
assert (data == res)
|
||||||
|
mock_model_inference.reset_mock()
|
||||||
|
|
||||||
|
res = mmocr.single_inference(model, data, batch_mode=True)
|
||||||
|
assert (data == res)
|
||||||
|
mock_model_inference.assert_called_once()
|
||||||
|
mock_model_inference.reset_mock()
|
||||||
|
|
||||||
|
res = mmocr.single_inference(model, data, batch_mode=True, batch_size=100)
|
||||||
|
assert (data == res)
|
||||||
|
mock_model_inference.assert_called_once()
|
||||||
|
mock_model_inference.reset_mock()
|
||||||
|
|
||||||
|
res = mmocr.single_inference(model, data, batch_mode=True, batch_size=3)
|
||||||
|
assert (data == res)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch('mmocr.utils.ocr.init_detector')
|
||||||
|
@mock.patch('mmocr.utils.ocr.load_checkpoint')
|
||||||
|
def MMOCR_testobj(mock_loading, mock_init_detector, **kwargs):
|
||||||
|
# returns an MMOCR object bypassing the
|
||||||
|
# checkpoint initialization step
|
||||||
|
def init_detector_skip_ckpt(config, ckpt, device):
|
||||||
|
return init_detector(config, device=device)
|
||||||
|
|
||||||
|
def modify_kie_class(model, ckpt, map_location):
|
||||||
|
model.class_list = 'tests/data/kie_toy_dataset/class_list.txt'
|
||||||
|
|
||||||
|
mock_init_detector.side_effect = init_detector_skip_ckpt
|
||||||
|
mock_loading.side_effect = modify_kie_class
|
||||||
|
kwargs['det'] = kwargs.get('det', 'DB_r18')
|
||||||
|
kwargs['recog'] = kwargs.get('recog', 'CRNN')
|
||||||
|
kwargs['kie'] = kwargs.get('kie', 'SDMGR')
|
||||||
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
|
return MMOCR(**kwargs, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch('mmocr.utils.ocr.KIEDataset')
|
||||||
|
def test_readtext(mock_kiedataset):
|
||||||
|
# Fixing the weights of models to prevent them from
|
||||||
|
# generating invalid results and triggering other assertion errors
|
||||||
|
torch.manual_seed(4)
|
||||||
|
random.seed(4)
|
||||||
|
mmocr = MMOCR_testobj()
|
||||||
|
mmocr_det = MMOCR_testobj(kie='', recog='')
|
||||||
|
mmocr_recog = MMOCR_testobj(kie='', det='', recog='CRNN_TPS')
|
||||||
|
mmocr_det_recog = MMOCR_testobj(kie='')
|
||||||
|
|
||||||
|
def readtext(imgs, ocr_obj=mmocr, **kwargs):
|
||||||
|
# filename can be different depends on how
|
||||||
|
# the the image was loaded
|
||||||
|
e2e_res = ocr_obj.readtext(imgs, **kwargs)
|
||||||
|
for res in e2e_res:
|
||||||
|
res.pop('filename')
|
||||||
|
return e2e_res
|
||||||
|
|
||||||
|
def kiedataset_with_test_dict(**kwargs):
|
||||||
|
kwargs['dict_file'] = 'tests/data/kie_toy_dataset/dict.txt'
|
||||||
|
return KIEDataset(**kwargs)
|
||||||
|
|
||||||
|
mock_kiedataset.side_effect = kiedataset_with_test_dict
|
||||||
|
|
||||||
|
# Single image
|
||||||
|
toy_dir = 'tests/data/toy_dataset/imgs/test/'
|
||||||
|
toy_img1_path = toy_dir + 'img_1.jpg'
|
||||||
|
str_e2e_res = readtext(toy_img1_path)
|
||||||
|
toy_img1 = mmcv.imread(toy_img1_path)
|
||||||
|
np_e2e_res = readtext(toy_img1)
|
||||||
|
assert str_e2e_res == np_e2e_res
|
||||||
|
|
||||||
|
# Multiple images
|
||||||
|
toy_img2_path = toy_dir + 'img_2.jpg'
|
||||||
|
toy_img2 = mmcv.imread(toy_img2_path)
|
||||||
|
toy_imgs = [toy_img1, toy_img2]
|
||||||
|
toy_img_paths = [toy_img1_path, toy_img2_path]
|
||||||
|
np_e2e_results = readtext(toy_imgs)
|
||||||
|
str_e2e_results = readtext(toy_img_paths)
|
||||||
|
str_tuple_e2e_results = readtext(tuple(toy_img_paths))
|
||||||
|
assert np_e2e_results == str_e2e_results
|
||||||
|
assert str_e2e_results == str_tuple_e2e_results
|
||||||
|
|
||||||
|
# Batch mode test
|
||||||
|
toy_imgs.append(toy_dir + 'img_3.jpg')
|
||||||
|
e2e_res = readtext(toy_imgs)
|
||||||
|
full_batch_e2e_res = readtext(toy_imgs, batch_mode=True)
|
||||||
|
assert full_batch_e2e_res == e2e_res
|
||||||
|
batch_e2e_res = readtext(
|
||||||
|
toy_imgs, batch_mode=True, recog_batch_size=2, det_batch_size=2)
|
||||||
|
assert batch_e2e_res == full_batch_e2e_res
|
||||||
|
|
||||||
|
# Batch mode test with DBNet only
|
||||||
|
full_batch_det_res = mmocr_det.readtext(toy_imgs, batch_mode=True)
|
||||||
|
det_res = mmocr_det.readtext(toy_imgs)
|
||||||
|
batch_det_res = mmocr_det.readtext(
|
||||||
|
toy_imgs, batch_mode=True, single_batch_size=2)
|
||||||
|
assert len(full_batch_det_res) == len(det_res)
|
||||||
|
assert len(batch_det_res) == len(det_res)
|
||||||
|
assert all([
|
||||||
|
np.allclose(full_batch_det_res[i]['boundary_result'],
|
||||||
|
det_res[i]['boundary_result'])
|
||||||
|
for i in range(len(full_batch_det_res))
|
||||||
|
])
|
||||||
|
assert all([
|
||||||
|
np.allclose(batch_det_res[i]['boundary_result'],
|
||||||
|
det_res[i]['boundary_result'])
|
||||||
|
for i in range(len(batch_det_res))
|
||||||
|
])
|
||||||
|
|
||||||
|
# Batch mode test with CRNN_TPS only (CRNN doesn't support batch inference)
|
||||||
|
full_batch_recog_res = mmocr_recog.readtext(toy_imgs, batch_mode=True)
|
||||||
|
recog_res = mmocr_recog.readtext(toy_imgs)
|
||||||
|
batch_recog_res = mmocr_recog.readtext(
|
||||||
|
toy_imgs, batch_mode=True, single_batch_size=2)
|
||||||
|
assert full_batch_recog_res == recog_res
|
||||||
|
assert batch_recog_res == recog_res
|
||||||
|
|
||||||
|
# Test export
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
mmocr.readtext(toy_imgs, export=tmpdirname)
|
||||||
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
mmocr_det.readtext(toy_imgs, export=tmpdirname)
|
||||||
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
mmocr_recog.readtext(toy_imgs, export=tmpdirname)
|
||||||
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
||||||
|
|
||||||
|
# Test output
|
||||||
|
# Single image
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
tmp_output = os.path.join(tmpdirname, '1.jpg')
|
||||||
|
mmocr.readtext(toy_imgs[0], output=tmp_output)
|
||||||
|
assert os.path.exists(tmp_output)
|
||||||
|
# Multiple images
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
mmocr.readtext(toy_imgs, output=tmpdirname)
|
||||||
|
assert len(os.listdir(tmpdirname)) == len(toy_imgs)
|
||||||
|
|
||||||
|
# Test imshow
|
||||||
|
with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
|
||||||
|
mmocr.readtext(toy_img1_path, imshow=True)
|
||||||
|
mock_imshow.assert_called_once()
|
||||||
|
mock_imshow.reset_mock()
|
||||||
|
mmocr.readtext(toy_imgs, imshow=True)
|
||||||
|
assert mock_imshow.call_count == len(toy_imgs)
|
||||||
|
|
||||||
|
# Test print_result
|
||||||
|
with io.StringIO() as capturedOutput:
|
||||||
|
sys.stdout = capturedOutput
|
||||||
|
res = mmocr.readtext(toy_imgs, print_result=True)
|
||||||
|
assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace(
|
||||||
|
'\n\n', ',').replace("'", '"')) == res
|
||||||
|
sys.stdout = sys.__stdout__
|
||||||
|
with io.StringIO() as capturedOutput:
|
||||||
|
sys.stdout = capturedOutput
|
||||||
|
res = mmocr.readtext(toy_imgs, details=True, print_result=True)
|
||||||
|
assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace(
|
||||||
|
'\n\n', ',').replace("'", '"')) == res
|
||||||
|
sys.stdout = sys.__stdout__
|
||||||
|
|
||||||
|
# Test merge
|
||||||
|
with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge:
|
||||||
|
mmocr_det_recog.readtext(toy_imgs, merge=True)
|
||||||
|
assert mock_merge.call_count == len(toy_imgs)
|
Loading…
Reference in New Issue