2020-05-10 16:26:57 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
2020-10-13 17:13:33 +08:00
import os
2020-10-22 18:24:42 +08:00
import sys
2021-12-27 15:10:32 +08:00
import platform
2020-05-10 16:26:57 +08:00
import cv2
import numpy as np
2021-11-10 20:20:45 +08:00
import paddle
2020-05-13 20:29:45 +08:00
from PIL import Image , ImageDraw , ImageFont
2020-05-27 14:55:58 +08:00
import math
2020-12-21 17:10:00 +08:00
from paddle import inference
2021-05-26 18:40:16 +08:00
import time
2022-09-04 12:40:45 +08:00
import random
2021-05-26 18:40:16 +08:00
from ppocr . utils . logging import get_logger
2021-06-02 20:10:59 +08:00
2021-08-02 20:18:49 +08:00
2021-05-26 17:34:47 +08:00
def str2bool(v):
    """Convert a command-line style value into a boolean.

    Args:
        v (str | bool): raw value. A real ``bool`` is returned unchanged so
            the helper can also be applied to already-parsed values.

    Returns:
        bool: True when ``v`` is, case-insensitively, one of
        "true", "yes", "t", "y" or "1"; False for every other string.
    """
    if isinstance(v, bool):
        # Already parsed (e.g. a default passed through again); nothing to do.
        return v
    return v.lower() in ("true", "yes", "t", "y", "1")
2020-05-10 16:26:57 +08:00
2023-07-21 10:22:05 +08:00
2023-07-20 20:24:42 +08:00
def str2int_tuple(v):
    """Parse a comma-separated string such as "3, 48, 320" into a tuple of ints.

    Whitespace around each item is ignored; a non-integer item raises
    ``ValueError`` from ``int``.
    """
    pieces = v.split(",")
    return tuple(int(piece.strip()) for piece in pieces)
2020-05-10 16:26:57 +08:00
2023-07-21 10:22:05 +08:00
2021-06-02 20:10:59 +08:00
def init_args():
    """Build the argument parser shared by the inference tools.

    Returns:
        argparse.ArgumentParser: a parser with every prediction-engine,
        detector, recognizer, end-to-end, classifier, super-resolution,
        output and multi-process option registered. Nothing is parsed here.
    """

    def _int_list(value):
        # ``type=list`` would split "8,16,32" into single characters;
        # parse a comma-separated list of integers instead.
        return [int(item) for item in value.split(",") if item.strip()]

    def _str_list(value):
        # Same idea for string labels, e.g. "0,180" -> ['0', '180'].
        return [item.strip() for item in value.split(",") if item.strip()]

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--use_mlu", type=str2bool, default=False)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')
    parser.add_argument("--det_box_type", type=str, default='quad')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=_int_list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=_str_list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # output locations
    parser.add_argument(
        "--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)

    # extended function
    parser.add_argument(
        "--return_word_box",
        type=str2bool,
        default=False,
        help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery'
    )
    return parser
2020-11-17 12:54:24 +08:00
2021-04-09 18:19:34 +08:00
2021-05-26 17:34:47 +08:00
def parse_args():
    """Create the inference parser via ``init_args`` and parse the command line.

    Returns:
        argparse.Namespace: the parsed options.
    """
    return init_args().parse_args()
2020-10-22 18:24:42 +08:00
def create_predictor(args, mode, logger):
    """Create an inference predictor for one model of the pipeline.

    Args:
        args: parsed options; the ``<mode>_model_dir`` attribute matching
            ``mode`` is read, plus the engine switches (``use_onnx``,
            ``use_gpu``, ``use_tensorrt``, ``enable_mkldnn``, ...).
        mode (str): one of "det", "cls", "rec", "table", "ser", "re", "sr",
            "layout"; any other value selects ``args.e2e_model_dir``.
        logger: logger used for informational and warning messages.

    Returns:
        tuple: ``(predictor, input_tensor, output_tensors, config)``.
        With ``args.use_onnx`` the tuple is
        ``(onnxruntime session, first session input, None, None)``.
        For modes "ser" and "re" ``input_tensor`` is a list of all input
        handles; otherwise it is the handle of the last input name.

    Raises:
        ValueError: when the model (or parameter) file cannot be found.
        SystemExit: when the selected model directory is ``None``.
    """
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'ser':
        model_dir = args.ser_model_dir
    elif mode == 're':
        model_dir = args.re_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    elif mode == 'layout':
        model_dir = args.layout_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)

    if args.use_onnx:
        # onnxruntime is optional, so it is imported only when requested.
        import onnxruntime as ort
        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("not find model file path {}".format(
                model_file_path))
        sess = ort.InferenceSession(model_file_path)
        return sess, sess.get_inputs()[0], None, None

    # Accept either "model.*" or "inference.*" file names; if neither pair
    # exists the paths of the last candidate are kept for the error below.
    file_names = ['model', 'inference']
    for file_name in file_names:
        model_file_path = '{}/{}.pdmodel'.format(model_dir, file_name)
        params_file_path = '{}/{}.pdiparams'.format(model_dir, file_name)
        if os.path.exists(model_file_path) and os.path.exists(
                params_file_path):
            break
    if not os.path.exists(model_file_path):
        raise ValueError(
            "not find model.pdmodel or inference.pdmodel in {}".format(
                model_dir))
    if not os.path.exists(params_file_path):
        raise ValueError(
            "not find model.pdiparams or inference.pdiparams in {}".format(
                model_dir))

    config = inference.Config(model_file_path, params_file_path)

    # ``precision`` may be absent on hand-built args objects; read it once
    # with a default so that every later use is guarded.
    requested_precision = getattr(args, 'precision', "fp32")
    if requested_precision == "fp16" and args.use_tensorrt:
        precision = inference.PrecisionType.Half
    elif requested_precision == "int8":
        precision = inference.PrecisionType.Int8
    else:
        precision = inference.PrecisionType.Float32

    if args.use_gpu:
        gpu_id = get_infer_gpuid()
        if gpu_id is None:
            logger.warning(
                "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jetson."
            )
        config.enable_use_gpu(args.gpu_mem, args.gpu_id)
        if args.use_tensorrt:
            config.enable_tensorrt_engine(
                workspace_size=1 << 30,
                precision_mode=precision,
                max_batch_size=args.max_batch_size,
                min_subgraph_size=args.
                min_subgraph_size,  # skip the minimum trt subgraph
                use_calib_mode=False)

            # collect shape
            trt_shape_f = os.path.join(model_dir,
                                       f"{mode}_trt_dynamic_shape.txt")
            if not os.path.exists(trt_shape_f):
                config.collect_shape_range_info(trt_shape_f)
                logger.info(
                    f"collect dynamic shape info into : {trt_shape_f}")
            try:
                config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
            except Exception as E:
                # Deliberate best effort: log and continue.
                logger.info(E)
                logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")
    elif args.use_npu:
        config.enable_custom_device("npu")
    elif args.use_mlu:
        config.enable_custom_device("mlu")
    elif args.use_xpu:
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        if args.enable_mkldnn:
            # cache 10 different shapes for mkldnn to avoid memory leak
            config.set_mkldnn_cache_capacity(10)
            config.enable_mkldnn()
            if requested_precision == "fp16":
                config.enable_mkldnn_bfloat16()
            # NOTE(review): indentation was lost in the reviewed copy; the
            # thread setting is assumed to sit inside the MKLDNN branch.
            # Confirm against the repository before merging.
            if hasattr(args, "cpu_threads"):
                config.set_cpu_math_library_num_threads(args.cpu_threads)
            else:
                # default cpu threads as 10
                config.set_cpu_math_library_num_threads(10)

    # enable memory optim
    config.enable_memory_optim()
    config.disable_glog_info()
    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    config.delete_pass("matmul_transpose_reshape_fuse_pass")
    if mode == 're':
        config.delete_pass("simplify_with_basic_ops_pass")
    if mode == 'table':
        config.delete_pass("fc_fuse_pass")  # not supported for table
    config.switch_use_feed_fetch_ops(False)
    config.switch_ir_optim(True)

    # create predictor
    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()
    if mode in ['ser', 're']:
        input_tensor = []
        for name in input_names:
            input_tensor.append(predictor.get_input_handle(name))
    else:
        # Stays None (instead of being unbound) if no inputs are reported.
        input_tensor = None
        for name in input_names:
            input_tensor = predictor.get_input_handle(name)
    output_tensors = get_output_tensors(args, mode, predictor)
    return predictor, input_tensor, output_tensors, config
def get_output_tensors(args, mode, predictor):
    """Collect the output handles of ``predictor``.

    For mode "rec" with ``args.rec_algorithm`` being "CRNN" or "SVTR_LCNet",
    only the handle named 'softmax_0.tmp_0' is returned when the predictor
    exposes it. In every other case a handle is returned for each output
    name, in the order reported by the predictor.

    Returns:
        list: output handles.
    """
    names = predictor.get_output_names()
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
        softmax_name = 'softmax_0.tmp_0'
        if softmax_name in names:
            return [predictor.get_output_handle(softmax_name)]
    return [predictor.get_output_handle(name) for name in names]
2020-10-22 18:24:42 +08:00
2021-08-17 10:43:15 +08:00
def get_infer_gpuid():
    """Return the first GPU id made visible through the environment.

    On Windows the function always returns 0. Elsewhere it reads
    ``HIP_VISIBLE_DEVICES`` when Paddle is compiled with ROCm and
    ``CUDA_VISIBLE_DEVICES`` otherwise.

    Returns:
        int | None: 0 when the variable is not set; the first id of the
        comma-separated list when it is a non-negative integer; ``None``
        when the variable is empty or its first entry is not a usable
        integer id (for example "-1" or a device UUID).
    """
    if platform.system() == "Windows":
        return 0

    if paddle.device.is_compiled_with_rocm():
        env_name = "HIP_VISIBLE_DEVICES"
    else:
        env_name = "CUDA_VISIBLE_DEVICES"

    # Read the variable directly instead of spawning `env | grep`, which
    # also matched unrelated variables containing the same substring.
    visible = os.environ.get(env_name)
    if visible is None:
        return 0

    # Parse the whole first entry, so that "10,11" yields 10 rather than 1.
    first_entry = visible.split(",")[0].strip()
    try:
        gpu_id = int(first_entry)
    except ValueError:
        return None
    return gpu_id if gpu_id >= 0 else None
2021-03-15 13:58:53 +08:00
def draw_e2e_res(dt_boxes, strs, img_path):
    """Draw end-to-end results on the image stored at ``img_path``.

    Each box is drawn as a closed polyline and its text is written at the
    first point of the box.

    Args:
        dt_boxes: iterable of numpy arrays of polygon points.
        strs: iterable of texts, paired with ``dt_boxes`` by position.
        img_path (str): path passed to ``cv2.imread``.

    Returns:
        the image returned by ``cv2.imread`` with the drawings applied.
    """
    canvas = cv2.imread(img_path)
    for polygon, text in zip(dt_boxes, strs):
        pts = polygon.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(canvas, [pts], True, color=(255, 255, 0), thickness=2)
        anchor = (int(pts[0, 0, 0]), int(pts[0, 0, 1]))
        cv2.putText(
            canvas,
            text,
            org=anchor,
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return canvas
2022-09-29 15:51:13 +08:00
def draw_text_det_res(dt_boxes, img):
    """Outline every detected box on ``img`` and return ``img``.

    The polylines are drawn on ``img`` itself (no copy is made here).
    """
    for detected in dt_boxes:
        outline = np.array(detected).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [outline], True, color=(255, 255, 0), thickness=2)
    return img
2020-05-13 20:29:45 +08:00
2020-05-14 12:08:11 +08:00
def resize_img(img, input_size=600):
    """
    Rescale img so that its longest side (of the first two dimensions)
    becomes input_size; both axes use the same scale factor.
    """
    pixels = np.array(img)
    longest_side = np.max(pixels.shape[0:2])
    scale = float(input_size) / float(longest_side)
    return cv2.resize(pixels, None, None, fx=scale, fy=scale)
2020-05-14 12:08:11 +08:00
2020-08-22 19:42:14 +08:00
def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/fonts/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): scores corresponding to txts; every box is given
            score 1 when omitted
        drop_score(float): boxes scoring below this (or NaN) are not drawn
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img; when txts is given, the resized image with the
        text panel concatenated on its right
    """
    if scores is None:
        scores = [1] * len(boxes)

    for i, raw_box in enumerate(boxes):
        if scores[i] < drop_score or math.isnan(scores[i]):
            continue
        pts = np.reshape(np.array(raw_box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [pts], True, (255, 0, 0), 2)

    if txts is None:
        return image

    img = np.array(resize_img(image, input_size=600))
    txt_img = text_visual(
        txts,
        scores,
        img_h=img.shape[0],
        img_w=600,
        threshold=drop_score,
        font_path=font_path)
    return np.concatenate([np.array(img), np.array(txt_img)], axis=1)
2020-05-27 14:55:58 +08:00
2020-12-07 13:10:12 +08:00
def draw_ocr_box_txt(image,
                     boxes,
                     txts=None,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/fonts/simfang.ttf"):
    """Render boxes on the image and their texts on a side-by-side panel.

    Args:
        image: PIL image (``height``, ``width``, ``copy`` are used).
        boxes: iterable of boxes, each a sequence of points.
        txts: texts paired with ``boxes``; ignored (treated as all ``None``)
            when missing or when its length differs from ``boxes``.
        scores: optional scores; boxes scoring below ``drop_score`` are
            skipped.
        drop_score (float): score threshold.
        font_path (str): font used to render the texts.

    Returns:
        numpy array of a canvas twice as wide as ``image``: the blended
        boxes on the left, the rendered texts on the right.
    """
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
    # A private generator gives the same colour sequence as seeding the
    # module-level RNG with 0, without resetting global random state.
    rng = random.Random(0)
    draw_left = ImageDraw.Draw(img_left)
    if txts is None or len(txts) != len(boxes):
        txts = [None] * len(boxes)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (rng.randint(0, 255), rng.randint(0, 255),
                 rng.randint(0, 255))
        draw_left.polygon(box, fill=color)
        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
        pts = np.array(box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(img_right_text, [pts], True, color, 1)
        # Each text layer is white outside its box, so AND-ing accumulates
        # the layers onto the white panel.
        img_right = cv2.bitwise_and(img_right, img_right_text)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
    return np.array(img_show)
def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
    """Render ``txt`` on a white patch and warp it into ``box``.

    The patch is sized from the box edges box[0]-box[1] (width) and
    box[0]-box[3] (height). When the height exceeds both twice the width
    and 30, the text is drawn on a transposed patch and rotated with
    ``Image.ROTATE_270``.

    Args:
        img_size: output size passed to ``cv2.warpPerspective``.
        box: four points; passed as the destination quadrilateral.
        txt: text to draw; nothing is drawn when falsy.
        font_path (str): font file used by ``create_font``.

    Returns:
        numpy uint8 image with the text warped into ``box`` and a white
        border elsewhere.
    """
    dx_h, dy_h = box[0][0] - box[3][0], box[0][1] - box[3][1]
    dx_w, dy_w = box[0][0] - box[1][0], box[0][1] - box[1][1]
    box_height = int(math.sqrt(dx_h**2 + dy_h**2))
    box_width = int(math.sqrt(dx_w**2 + dy_w**2))

    is_vertical = box_height > 2 * box_width and box_height > 30
    if is_vertical:
        patch_size = (box_height, box_width)
    else:
        patch_size = (box_width, box_height)

    img_text = Image.new('RGB', patch_size, (255, 255, 255))
    draw_text = ImageDraw.Draw(img_text)
    if txt:
        font = create_font(txt, patch_size, font_path)
        draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
    if is_vertical:
        img_text = img_text.transpose(Image.ROTATE_270)

    pts1 = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
    pts2 = np.array(box, dtype=np.float32)
    M = cv2.getPerspectiveTransform(pts1, pts2)

    img_text = np.array(img_text, dtype=np.uint8)
    return cv2.warpPerspective(
        img_text,
        M,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255))
def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
    """Load a font sized so that ``txt`` fits into a patch of size ``sz``.

    Args:
        txt (str): text that will be rendered.
        sz: ``(width, height)`` of the target patch.
        font_path (str): TrueType font file.

    Returns:
        ImageFont.FreeTypeFont: font at 99% of the patch height, shrunk
        proportionally when the text would be wider than the patch.
    """
    # Never request a zero-sized font for very flat patches.
    font_size = max(1, int(sz[1] * 0.99))
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    if hasattr(font, "getsize"):
        # Older Pillow: keep the original measurement.
        length = font.getsize(txt)[0]
    else:
        # Pillow >= 10 removed ``getsize``; ``getlength`` replaces it.
        length = font.getlength(txt)
    if length > sz[0]:
        font_size = max(1, int(font_size * sz[0] / length))
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font
2020-05-27 14:55:58 +08:00
def str_count(s):
    """
    Count the display length of a string in full-width units, where an
    ASCII letter, a digit or a whitespace character counts as half a unit.
    args:
        s(string): the input of string
    return(int):
        len(s) minus half (rounded up) of the number of half-width
        characters
    """
    import string
    half_width = sum(
        1 for c in s
        if c in string.ascii_letters or c.isdigit() or c.isspace())
    return len(s) - math.ceil(half_width / 2)
2020-08-22 19:42:14 +08:00
def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; with None no
            text is filtered and no score is appended
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts scoring below it (or NaN) are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the rendered panel; when the texts do not fit on one panel,
        several panels are concatenated horizontally
    """
    # NOTE(review): the default font_path differs from the
    # "./doc/fonts/simfang.ttf" default of the other helpers in this file;
    # confirm which location is correct.
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # uint8 holds 255 exactly; int8 * 255 overflows (an error on NumPy 2).
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
        blank_img[:, img_w - 1:] = 0  # black separator in the last column
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # Characters per line; at least 1, otherwise the wrapping loop below
    # could never consume the text on very narrow panels.
    wrap_len = max(1, img_w // font_size - 4)
    max_lines = img_h // gap - 1
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        index += 1
        first_line = True
        while str_count(txt) >= wrap_len:
            tmp = txt
            txt = tmp[:wrap_len]
            if first_line:
                new_txt = str(index) + ' : ' + txt
                first_line = False
            else:
                new_txt = ' ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[wrap_len:]
            if count >= max_lines:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        score_txt = '' if score is None else ' ' + '%.3f' % (score)
        if first_line:
            new_txt = str(index) + ' : ' + txt + score_txt
        else:
            new_txt = " " + txt + score_txt
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= max_lines and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
2020-05-13 20:29:45 +08:00
2020-07-09 20:34:42 +08:00
def base64_to_cv2(b64str):
    """Decode a base64 string into an image via ``cv2.imdecode``.

    Args:
        b64str (str): base64 text of an encoded image file.

    Returns:
        the result of ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.
    """
    import base64
    raw_bytes = base64.b64decode(b64str.encode('utf8'))
    encoded = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """Outline every box whose score is not below ``drop_score``.

    Args:
        image: image convertible with ``np.array``.
        boxes: iterable of boxes (point sequences).
        scores: optional scores paired with ``boxes``; every box gets
            score 1 when omitted.
        drop_score (float): boxes scoring below this are skipped.

    Returns:
        the image with the polylines drawn; the input object itself when
        no box was drawn.
    """
    if scores is None:
        scores = [1] * len(boxes)
    for candidate, confidence in zip(boxes, scores):
        if confidence < drop_score:
            continue
        outline = np.reshape(np.array(candidate), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [outline], True, (255, 0, 0), 2)
    return image
2021-07-01 14:36:33 +08:00
def get_rotate_crop_image(img, points):
    """Cut the quadrilateral ``points`` out of ``img`` as an upright patch.

    The patch width is the longer of the edges points[0]-points[1] and
    points[2]-points[3]; the height is the longer of points[0]-points[3]
    and points[1]-points[2]. The region is rectified with a perspective
    warp, and the result is rotated with ``np.rot90`` when its
    height/width ratio is at least 1.5.

    Args:
        img: source image.
        points: four 2-D points as a numpy array (assumed float32, as
            ``cv2.getPerspectiveTransform`` receives them directly;
            TODO confirm at call sites).

    Returns:
        the cropped (and possibly rotated) image.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    edge_01 = np.linalg.norm(points[0] - points[1])
    edge_23 = np.linalg.norm(points[2] - points[3])
    edge_03 = np.linalg.norm(points[0] - points[3])
    edge_12 = np.linalg.norm(points[1] - points[2])
    img_crop_width = int(max(edge_01, edge_23))
    img_crop_height = int(max(edge_03, edge_12))

    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)

    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
2022-10-27 15:37:15 +08:00
def get_minarea_rect_crop(img, points):
    """Crop the minimum-area rectangle enclosing ``points`` from ``img``.

    The rectangle corners from ``cv2.boxPoints`` are sorted by x; within
    the two left-most and the two right-most corners the one with the
    smaller y comes first, giving the order (a, b, c, d) that is handed to
    ``get_rotate_crop_image``.
    """
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    corners = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    # Left pair: a = smaller y, d = larger y.
    if corners[1][1] > corners[0][1]:
        index_a, index_d = 0, 1
    else:
        index_a, index_d = 1, 0
    # Right pair: b = smaller y, c = larger y.
    if corners[3][1] > corners[2][1]:
        index_b, index_c = 2, 3
    else:
        index_b, index_c = 3, 2

    box = [
        corners[index_a], corners[index_b], corners[index_c],
        corners[index_d]
    ]
    return get_rotate_crop_image(img, np.array(box))
2021-11-10 20:20:45 +08:00
def check_gpu(use_gpu):
    """Return ``use_gpu``, downgraded to False when Paddle lacks CUDA.

    A falsy ``use_gpu`` is returned unchanged without querying Paddle.
    """
    if not use_gpu:
        return use_gpu
    if paddle.is_compiled_with_cuda():
        return use_gpu
    return False
2020-05-13 20:29:45 +08:00
# This module only provides helpers; running it directly does nothing.
if __name__ == "__main__":
    pass