#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path

import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from PIL import Image

try:
    import tesserocr
except ImportError:
    tesserocr = None

from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.datasets import WildReceiptDataset
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizers import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list, stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.img_utils import crop_img
from mmocr.utils.model import revert_sync_batchnorm
from mmocr.visualization.visualize import det_recog_show_result


# Parse CLI arguments
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--output',
        type=str,
        default='',
        help='Output file/folder name for visualization')
    parser.add_argument(
        '--det',
        type=str,
        default='PANet_IC15',
        help='Pretrained text detection algorithm')
    parser.add_argument(
        '--det-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected det model. It '
        'overrides the settings in det')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected det model. '
        'It overrides the settings in det')
    parser.add_argument(
        '--recog',
        type=str,
        default='SEG',
        help='Pretrained text recognition algorithm')
    parser.add_argument(
        '--recog-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected recog model. It '
        'overrides the settings in recog')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected recog model. '
        'It overrides the settings in recog')
    parser.add_argument(
        '--kie',
        type=str,
        default='',
        help='Pretrained key information extraction algorithm')
    parser.add_argument(
        '--kie-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected kie model. It '
        'overrides the settings in kie')
    parser.add_argument(
        '--kie-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected kie model. '
        'It overrides the settings in kie')
    parser.add_argument(
        '--config-dir',
        type=str,
        default=os.path.join(str(Path.cwd()), 'configs/'),
        help='Path to the config directory where all the config files '
        'are located. Defaults to "configs/"')
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether to use batch mode for inference')
    parser.add_argument(
        '--recog-batch-size',
        type=int,
        default=0,
        help='Batch size for text recognition')
    parser.add_argument(
        '--det-batch-size',
        type=int,
        default=0,
        help='Batch size for text detection')
    parser.add_argument(
        '--single-batch-size',
        type=int,
        default=0,
        help='Batch size for separate det/recog inference')
    parser.add_argument(
        '--device', default=None, help='Device used for inference.')
    parser.add_argument(
        '--export',
        type=str,
        default='',
        help='Folder where the results of each image are exported')
    parser.add_argument(
        '--export-format',
        type=str,
        default='json',
        help='Format of the exported result file(s)')
    parser.add_argument(
        '--details',
        action='store_true',
        help='Whether to include the text box coordinates and confidence '
        'values')
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether to show the image with OpenCV.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Prints the recognised text')
    parser.add_argument(
        '--merge', action='store_true', help='Merge neighboring boxes')
    parser.add_argument(
        '--merge-xdist',
        type=float,
        default=20,
        help='The maximum x-axis distance to merge boxes')
    args = parser.parse_args()
    if args.det == 'None':
        args.det = None
    if args.recog == 'None':
        args.recog = None

    # Warnings
    if args.merge and not (args.det and args.recog):
        warnings.warn(
            'Box merging will not work if the script is not'
            ' running in detection + recognition mode.', UserWarning)
    if not os.path.samefile(args.config_dir, os.path.join(str(
            Path.cwd()))) and (args.det_config != ''
                               or args.recog_config != ''):
        warnings.warn(
            'config_dir will be overridden by det-config or recog-config.',
            UserWarning)
    return args


class MMOCR:
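    """End-to-end OCR pipeline wrapping text detection, recognition and key
    information extraction (KIE) models.

    A minimal usage sketch (model aliases come from the tables below; the
    image path is illustrative):

        >>> ocr = MMOCR(det='PANet_IC15', recog='SEG')
        >>> results = ocr.readtext('demo/demo_text_det.jpg',
        ...                        print_result=True)
    """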

    def __init__(self,
                 det='PANet_IC15',
                 det_config='',
                 det_ckpt='',
                 recog='SEG',
                 recog_config='',
                 recog_ckpt='',
                 kie='',
                 kie_config='',
                 kie_ckpt='',
                 config_dir=os.path.join(str(Path.cwd()), 'configs/'),
                 device=None,
                 **kwargs):
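        # Each alias below maps to a config file (resolved relative to
        # ``config_dir``) and a checkpoint (fetched from the OpenMMLab
        # download server unless a custom ckpt path is supplied).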
        textdet_models = {
            'DB_r18': {
                'config':
                'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
            },
            'DB_r50': {
                'config':
                'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
            },
            'DBPP_r50': {
                'config':
                'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth'
            },
            'DRRG': {
                'config':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
            },
            'FCE_IC15': {
                'config':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
                'ckpt':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
            },
            'FCE_CTW_DCNv2': {
                'config':
                'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
                'ckpt':
                'fcenet/' +
                'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
            },
            'MaskRCNN_CTW': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
            },
            'MaskRCNN_IC15': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
            },
            'MaskRCNN_IC17': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
            },
            'PANet_CTW': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
            },
            'PANet_IC15': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
            },
            'PS_CTW': {
                'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
            },
            'PS_IC15': {
                'config':
                'psenet/psenet_r50_fpnf_600e_icdar2015.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
            },
            'TextSnake': {
                'config':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
            },
            'Tesseract': {}
        }

        textrecog_models = {
            'CRNN': {
                'config': 'crnn/crnn_academic_dataset.py',
                'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
            },
            'SAR': {
                'config': 'sar/sar_r31_parallel_decoder_academic.py',
                'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
            },
            'SAR_CN': {
                'config':
                'sar/sar_r31_parallel_decoder_chinese.py',
                'ckpt':
                'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
            },
            'NRTR_1/16-1/8': {
                'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
            },
            'NRTR_1/8-1/4': {
                'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
            },
            'RobustScanner': {
                'config': 'robust_scanner/robustscanner_r31_academic.py',
                'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
            },
            'SATRN': {
                'config': 'satrn/satrn_academic.py',
                'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
            },
            'SATRN_sm': {
                'config': 'satrn/satrn_small.py',
                'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
            },
            'ABINet': {
                'config': 'abinet/abinet_academic.py',
                'ckpt': 'abinet/abinet_academic-f718abf6.pth'
            },
            'ABINet_Vision': {
                'config': 'abinet/abinet_vision_only_academic.py',
                'ckpt': 'abinet/abinet_vision_only_academic-e6b9ea89.pth'
            },
            'SEG': {
                'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
                'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
            },
            'CRNN_TPS': {
                'config': 'tps/crnn_tps_academic_dataset.py',
                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
            },
            'Tesseract': {},
            'MASTER': {
                'config': 'master/master_r31_12e_ST_MJ_SA.py',
                'ckpt': 'master/master_r31_12e_ST_MJ_SA-787edd36.pth'
            }
        }

        kie_models = {
            'SDMGR': {
                'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                'ckpt':
                'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
            }
        }

        self.td = det
        self.tr = recog
        self.kie = kie
        self.device = device
        if self.device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        # Check if the det/recog model choice is valid
        if self.td and self.td not in textdet_models:
            raise ValueError(self.td,
                             'is not a supported text detection algorithm')
        elif self.tr and self.tr not in textrecog_models:
            raise ValueError(self.tr,
                             'is not a supported text recognition algorithm')
        elif self.kie:
            if self.kie not in kie_models:
                raise ValueError(
                    self.kie, 'is not a supported key information extraction'
                    ' algorithm')
            elif not (self.td and self.tr):
                raise NotImplementedError(
                    self.kie, 'has to run together'
                    ' with text detection and recognition algorithms.')

        self.detect_model = None
        if self.td and self.td == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.detect_model = 'Tesseract_det'
        elif self.td:
            # Build detection model
            if not det_config:
                det_config = os.path.join(config_dir, 'textdet/',
                                          textdet_models[self.td]['config'])
            if not det_ckpt:
                det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
                    textdet_models[self.td]['ckpt']

            self.detect_model = init_detector(
                det_config, det_ckpt, device=self.device)
            self.detect_model = revert_sync_batchnorm(self.detect_model)

        self.recog_model = None
        if self.tr and self.tr == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.recog_model = 'Tesseract_recog'
        elif self.tr:
            # Build recognition model
            if not recog_config:
                recog_config = os.path.join(
                    config_dir, 'textrecog/',
                    textrecog_models[self.tr]['config'])
            if not recog_ckpt:
                recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'textrecog/' + textrecog_models[self.tr]['ckpt']

            self.recog_model = init_detector(
                recog_config, recog_ckpt, device=self.device)
            self.recog_model = revert_sync_batchnorm(self.recog_model)

        self.kie_model = None
        if self.kie:
            # Build key information extraction model
            if not kie_config:
                kie_config = os.path.join(config_dir, 'kie/',
                                          kie_models[self.kie]['config'])
            if not kie_ckpt:
                kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'kie/' + kie_models[self.kie]['ckpt']

            kie_cfg = Config.fromfile(kie_config)
            self.kie_model = MODELS.build(
                kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
            self.kie_model = revert_sync_batchnorm(self.kie_model)
            self.kie_model.cfg = kie_cfg
            load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)

        # Attribute check
        for model in list(
                filter(None, [self.recog_model, self.detect_model])):
            if hasattr(model, 'module'):
                model = model.module

    @staticmethod
    def get_tesserocr_api():
        """Get tesserocr api depending on different platform."""
        import subprocess
        import sys

        if sys.platform == 'linux':
            api = tesserocr.PyTessBaseAPI()
        elif sys.platform == 'win32':
            try:
                p = subprocess.Popen(
                    'where tesseract', stdout=subprocess.PIPE, shell=True)
                s = p.communicate()[0].decode('utf-8').split('\\')
                path = s[:-1] + ['tessdata']
                tessdata_path = '/'.join(path)
                api = tesserocr.PyTessBaseAPI(path=tessdata_path)
            except RuntimeError:
                raise RuntimeError(
                    'Please install tesseract first.\n Check out the'
                    ' installation guide at'
                    ' https://github.com/UB-Mannheim/tesseract/wiki')
        else:
            raise NotImplementedError
        return api

    def tesseract_det_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract detector.

        Args:
            imgs (ndarray or list[ndarray]): Image(s) to inference.

        Returns:
            dict or list[dict]: Each result dict holds a 'boundary_result'
            list of text-line boundaries in the format
            [x1, y1, x2, y2, x3, y3, x4, y4, score].
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        # Get detection result using tesseract
        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True)
            boundaries = []
            for _, box, _, _ in boxes:
                min_x = box['x']
                min_y = box['y']
                max_x = box['x'] + box['w']
                max_y = box['y'] + box['h']
                boundary = [
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0
                ]
                boundaries.append(boundary)
            results.append({'boundary_result': boundaries})

        # close tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results

    def tesseract_recog_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract recognizer.

        Args:
            imgs (ndarray or list[ndarray]): Image(s) to inference.

        Returns:
            dict or list[dict]: Each result dict holds the recognised 'text'
            and its mean confidence 'score' in [0, 1].
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            api.SetRectangle(0, 0, img.shape[1], img.shape[0])
            # Remove beginning and trailing spaces from Tesseract
            text = api.GetUTF8Text().strip()
            conf = api.MeanTextConf() / 100
            results.append({'text': text, 'score': conf})

        # close tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results

    def readtext(self,
                 img,
                 output=None,
                 details=False,
                 export=None,
                 export_format='json',
                 batch_mode=False,
                 recog_batch_size=0,
                 det_batch_size=0,
                 single_batch_size=0,
                 imshow=False,
                 print_result=False,
                 merge=False,
                 merge_xdist=20,
                 **kwargs):
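        """Run the configured det/recog/KIE pipeline on the input image(s)
        and return the post-processed results, one entry per image."""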
        args = locals().copy()
        [args.pop(x, None) for x in ['kwargs', 'self']]
        args = Namespace(**args)

        # Input and output arguments processing
        self._args_processing(args)
        self.args = args

        pp_result = None

        # Send args and models to the MMOCR model inference API
        # and call post-processing functions for the output
        if self.detect_model and self.recog_model:
            det_recog_result = self.det_recog_kie_inference(
                self.detect_model, self.recog_model, kie_model=self.kie_model)
            pp_result = self.det_recog_pp(det_recog_result)
        else:
            for model in list(
                    filter(None, [self.recog_model, self.detect_model])):
                result = self.single_inference(model, args.arrays,
                                               args.batch_mode,
                                               args.single_batch_size)
                pp_result = self.single_pp(result, model)

        return pp_result

    # Post processing function for end2end ocr
    def det_recog_pp(self, result):
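        """Post-process end-to-end results: visualize, export, simplify and
        optionally print each image's result."""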
        final_results = []
        args = self.args
        for arr, output, export, det_recog_result in zip(
                args.arrays, args.output, args.export, result):
            if output or args.imshow:
                if self.kie_model:
                    res_img = det_recog_show_result(arr, det_recog_result)
                else:
                    res_img = det_recog_show_result(
                        arr, det_recog_result, out_file=output)
                if args.imshow and not self.kie_model:
                    mmcv.imshow(res_img, 'inference results')
            if not args.details:
                simple_res = {}
                simple_res['filename'] = det_recog_result['filename']
                simple_res['text'] = [
                    x['text'] for x in det_recog_result['result']
                ]
                final_result = simple_res
            else:
                final_result = det_recog_result
            if export:
                mmcv.dump(final_result, export, indent=4)
            if args.print_result:
                print(final_result, end='\n\n')
            final_results.append(final_result)
        return final_results

    # Post processing function for separate det/recog inference
    def single_pp(self, result, model):
        for arr, output, export, res in zip(self.args.arrays, self.args.output,
                                            self.args.export, result):
            if export:
                mmcv.dump(res, export, indent=4)
            if output or self.args.imshow:
                if model == 'Tesseract_det':
                    res_img = TextDetectorMixin(show_score=False).show_result(
                        arr, res, out_file=output)
                elif model == 'Tesseract_recog':
                    res_img = BaseRecognizer.show_result(
                        arr, res, out_file=output)
                else:
                    res_img = model.show_result(arr, res, out_file=output)
                if self.args.imshow:
                    mmcv.imshow(res_img, 'inference results')
            if self.args.print_result:
                print(res, end='\n\n')
        return result

    def generate_kie_labels(self, result, boxes, class_list):
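        """Convert KIE node predictions into (label, score) pairs, one per
        box, mapping class indices to names via ``class_list`` if given."""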
        idx_to_cls = {}
        if class_list is not None:
            for line in list_from_file(class_list):
                class_idx, class_label = line.strip().split()
                idx_to_cls[class_idx] = class_label

        max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
        node_pred_label = max_idx.numpy().tolist()
        node_pred_score = max_value.numpy().tolist()
        labels = []
        for i in range(len(boxes)):
            pred_label = str(node_pred_label[i])
            if pred_label in idx_to_cls:
                pred_label = idx_to_cls[pred_label]
            pred_score = node_pred_score[i]
            labels.append((pred_label, pred_score))
        return labels

    def visualize_kie_output(self,
                             model,
                             data,
                             result,
                             out_file=None,
                             show=False):
        """Visualizes KIE output."""
        img_tensor = data['img'].data
        img_meta = data['img_metas'].data
        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
        if img_tensor.dtype == torch.uint8:
            # The img tensor is the raw input not being normalized
            # (For SDMGR non-visual)
            img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        else:
            img = tensor2imgs(
                img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
        h, w, _ = img_meta.get('img_shape', img.shape)
        img_show = img[:h, :w, :]
        model.show_result(
            img_show, result, gt_bboxes, show=show, out_file=out_file)

    # End2end ocr inference pipeline
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
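        """Detect text boxes, crop and recognize each box, then optionally
        run KIE over the recognized boxes; one result dict per image."""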
        end2end_res = []
        # Find bounding boxes in the images (text detection)
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model:
            kie_dataset = WildReceiptDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # For each bounding box, the image is cropped and
        # sent to the recognition model either one by one
        # or all together depending on the batch_mode
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {}
            img_e2e_res['filename'] = filename
            img_e2e_res['result'] = []
            box_imgs = []
            for bbox in bboxes:
                box_res = {}
                box_res['box'] = [round(x) for x in bbox[:-1]]
                box_res['box_score'] = float(bbox[-1])
                box = bbox[:8]
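                # Boundaries with more than four corner points (polygons)
                # are reduced to their axis-aligned bounding quad before
                # cropping.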
                if len(bbox) > 9:
                    min_x = min(bbox[0:-1:2])
                    min_y = min(bbox[1:-1:2])
                    max_x = max(bbox[0:-1:2])
                    max_y = max(bbox[1:-1:2])
                    box = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    box_imgs.append(box_img)
                else:
                    if recog_model == 'Tesseract_recog':
                        recog_result = self.single_inference(
                            recog_model, box_img, batch_mode=True)
                    else:
                        recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, list):
                        text_score = sum(text_score) / max(1, len(text))
                    box_res['text'] = text
                    box_res['text_score'] = text_score
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                for i, recog_result in enumerate(recog_results):
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, (list, tuple)):
                        text_score = sum(text_score) / max(1, len(text))
                    img_e2e_res['result'][i]['text'] = text
                    img_e2e_res['result'][i]['text_score'] = text_score
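
            # Optionally stitch horizontally adjacent boxes into text lines
            # via stitch_boxes_into_lines; the 0.5 below is its vertical
            # overlap threshold for grouping boxes into one line.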
            if self.args.merge:
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            if kie_model:
                annotations = copy.deepcopy(img_e2e_res['result'])
                # Customized for kie_dataset, which
                # assumes that boxes are represented by only 4 points
                for i, ann in enumerate(annotations):
                    min_x = min(ann['box'][::2])
                    min_y = min(ann['box'][1::2])
                    max_x = max(ann['box'][::2])
                    max_y = max(ann['box'][1::2])
                    annotations[i]['box'] = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                ann_info = kie_dataset._parse_anno_info(annotations)
                ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                      ann_info['bboxes'])
                ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                     ann_info['bboxes'])
                kie_result, data = model_inference(
                    kie_model,
                    arr,
                    ann=ann_info,
                    return_data=True,
                    batch_mode=self.args.batch_mode)
                # visualize KIE results
                self.visualize_kie_output(
                    kie_model,
                    data,
                    kie_result,
                    out_file=out_file,
                    show=self.args.imshow)
                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
                labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                                  kie_model.class_list)
                for i in range(len(gt_bboxes)):
                    img_e2e_res['result'][i]['label'] = labels[i][0]
                    img_e2e_res['result'][i]['label_score'] = labels[i][1]

            end2end_res.append(img_e2e_res)
        return end2end_res

    # Separate det/recog inference pipeline
    def single_inference(self, model, arrays, batch_mode, batch_size=0):
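        """Run inference with a single det/recog model, optionally splitting
        the inputs into chunks of ``batch_size`` when in batch mode."""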

        def inference(m, a, **kwargs):
            if model == 'Tesseract_det':
                return self.tesseract_det_inference(a)
            elif model == 'Tesseract_recog':
                return self.tesseract_recog_inference(a)
            else:
                return model_inference(m, a, **kwargs)

        result = []
        if batch_mode:
            if batch_size == 0:
                result = inference(model, arrays, batch_mode=True)
            else:
                n = batch_size
                arr_chunks = [
                    arrays[i:i + n] for i in range(0, len(arrays), n)
                ]
                for chunk in arr_chunks:
                    result.extend(inference(model, chunk, batch_mode=True))
        else:
            for arr in arrays:
                result.append(inference(model, arr, batch_mode=False))
        return result

    # Arguments pre-processing function
    def _args_processing(self, args):
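        """Normalize ``args.img`` into per-image lists of arrays, filenames,
        output paths and export paths."""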
        # Check if the input is a list/tuple that
        # contains only np arrays or strings
        if isinstance(args.img, (list, tuple)):
            img_list = args.img
            if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
                raise AssertionError('Images must be strings or numpy arrays')

        # Create a list of the images
        if isinstance(args.img, str):
            img_path = Path(args.img)
            if img_path.is_dir():
                img_list = [str(x) for x in img_path.glob('*')]
            else:
                img_list = [str(img_path)]
        elif isinstance(args.img, np.ndarray):
            img_list = [args.img]

        # Read all image(s) in advance to reduce wasted time
        # re-reading the images for visualization output
        args.arrays = [mmcv.imread(x) for x in img_list]

        # Create a list of filenames (used for output images and result files)
        if isinstance(img_list[0], str):
            args.filenames = [str(Path(x).stem) for x in img_list]
        else:
            args.filenames = [str(x) for x in range(len(img_list))]

        # If given an output argument, create a list of output image filenames
        num_res = len(img_list)
        if args.output:
            output_path = Path(args.output)
            if output_path.is_dir():
                args.output = [
                    str(output_path / f'out_{x}.png') for x in args.filenames
                ]
            else:
                args.output = [str(args.output)]
                if args.batch_mode:
                    raise AssertionError('Output of multiple images inference'
                                         ' must be a directory')
        else:
            args.output = [None] * num_res

        # If given an export argument, create a list of
        # result filenames for each image
        if args.export:
            export_path = Path(args.export)
            args.export = [
                str(export_path / f'out_{x}.{args.export_format}')
                for x in args.filenames
            ]
        else:
            args.export = [None] * num_res

        return args
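

# A minimal CLI sketch (script and image paths are illustrative):
#   python ocr.py demo/demo_text_det.jpg --det PANet_IC15 --recog SEG \
#       --output out/ --print-result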
# Create an inference pipeline with parsed arguments
def main():
    args = parse_args()
    ocr = MMOCR(**vars(args))
    ocr.readtext(**vars(args))


if __name__ == '__main__':
    main()