Support inference on PDF files
parent
2f312ae06d
commit
e61f40ef41
|
@ -7,6 +7,7 @@
|
|||
| 参数名称 | 类型 | 默认值 | 含义 |
|
||||
| :--: | :--: | :--: | :--: |
|
||||
| image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 |
|
||||
| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 |
|
||||
| vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 |
|
||||
| drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 |
|
||||
| use_pdserving | bool | False | 是否使用Paddle Serving进行预测 |
|
||||
|
|
|
@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par
|
|||
| parameters | type | default | implication |
|
||||
| :--: | :--: | :--: | :--: |
|
||||
| image_dir | str | None, must be specified explicitly | Image or folder path |
|
||||
| page_num | int | 0 | Valid when the input is a PDF file; only the first page_num pages are predicted. With the default value 0, all pages are predicted |
|
||||
| vis_font_path | str | "./doc/fonts/simfang.ttf" | Font path used for visualization |
|
||||
| drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results |
|
||||
| use_pdserving | bool | False | Whether to use Paddle Serving for prediction |
|
||||
|
|
|
@ -282,44 +282,69 @@ if __name__ == "__main__":
|
|||
args = utility.parse_args()
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_detector = TextDetector(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
draw_img_save = "./inference_results"
|
||||
draw_img_save_dir = args.draw_img_save_dir
|
||||
os.makedirs(draw_img_save_dir, exist_ok=True)
|
||||
|
||||
if args.warmup:
|
||||
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
|
||||
for i in range(2):
|
||||
res = text_detector(img)
|
||||
|
||||
if not os.path.exists(draw_img_save):
|
||||
os.makedirs(draw_img_save)
|
||||
save_results = []
|
||||
for image_file in image_file_list:
|
||||
img, flag, _ = check_and_read(image_file)
|
||||
if not flag:
|
||||
for idx, image_file in enumerate(image_file_list):
|
||||
img, flag_gif, flag_pdf = check_and_read(image_file)
|
||||
if not flag_gif and not flag_pdf:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
st = time.time()
|
||||
dt_boxes, _ = text_detector(img)
|
||||
elapse = time.time() - st
|
||||
if count > 0:
|
||||
if not flag_pdf:
|
||||
if img is None:
|
||||
logger.debug("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
imgs = [img]
|
||||
else:
|
||||
page_num = args.page_num
|
||||
if page_num > len(img) or page_num == 0:
|
||||
page_num = len(img)
|
||||
imgs = img[:page_num]
|
||||
for index, img in enumerate(imgs):
|
||||
st = time.time()
|
||||
dt_boxes, _ = text_detector(img)
|
||||
elapse = time.time() - st
|
||||
total_time += elapse
|
||||
count += 1
|
||||
save_pred = os.path.basename(image_file) + "\t" + str(
|
||||
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
|
||||
save_results.append(save_pred)
|
||||
logger.info(save_pred)
|
||||
logger.info("The predict time of {}: {}".format(image_file, elapse))
|
||||
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
||||
img_name_pure = os.path.split(image_file)[-1]
|
||||
img_path = os.path.join(draw_img_save,
|
||||
"det_res_{}".format(img_name_pure))
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
if len(imgs) > 1:
|
||||
save_pred = os.path.basename(image_file) + '_' + str(
|
||||
index) + "\t" + str(
|
||||
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
|
||||
else:
|
||||
save_pred = os.path.basename(image_file) + "\t" + str(
|
||||
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
|
||||
save_results.append(save_pred)
|
||||
logger.info(save_pred)
|
||||
if len(imgs) > 1:
|
||||
logger.info("{}_{} The predict time of {}: {}".format(
|
||||
idx, index, image_file, elapse))
|
||||
else:
|
||||
logger.info("{} The predict time of {}: {}".format(
|
||||
idx, image_file, elapse))
|
||||
if flag_pdf:
|
||||
src_im = utility.draw_text_det_res(dt_boxes, img, flag_pdf)
|
||||
else:
|
||||
src_im = utility.draw_text_det_res(dt_boxes, image_file,
|
||||
flag_pdf)
|
||||
if flag_gif:
|
||||
save_file = image_file[:-3] + "png"
|
||||
elif flag_pdf:
|
||||
save_file = image_file.replace('.pdf',
|
||||
'_' + str(index) + '.png')
|
||||
else:
|
||||
save_file = image_file
|
||||
img_path = os.path.join(
|
||||
draw_img_save_dir,
|
||||
"det_res_{}".format(os.path.basename(save_file)))
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
|
||||
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
|
||||
with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f:
|
||||
f.writelines(save_results)
|
||||
f.close()
|
||||
if args.benchmark:
|
||||
|
|
|
@ -159,50 +159,75 @@ def main(args):
|
|||
count = 0
|
||||
for idx, image_file in enumerate(image_file_list):
|
||||
|
||||
img, flag, _ = check_and_read(image_file)
|
||||
if not flag:
|
||||
img, flag_gif, flag_pdf = check_and_read(image_file)
|
||||
if not flag_gif and not flag_pdf:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.debug("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
dt_boxes, rec_res, time_dict = text_sys(img)
|
||||
elapse = time.time() - starttime
|
||||
total_time += elapse
|
||||
if not flag_pdf:
|
||||
if img is None:
|
||||
logger.debug("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
imgs = [img]
|
||||
else:
|
||||
page_num = args.page_num
|
||||
if page_num > len(img) or page_num == 0:
|
||||
page_num = len(img)
|
||||
imgs = img[:page_num]
|
||||
for index, img in enumerate(imgs):
|
||||
starttime = time.time()
|
||||
dt_boxes, rec_res, time_dict = text_sys(img)
|
||||
elapse = time.time() - starttime
|
||||
total_time += elapse
|
||||
if len(imgs) > 1:
|
||||
logger.debug(
|
||||
str(idx) + '_' + str(index) + " Predict time of %s: %.3fs"
|
||||
% (image_file, elapse))
|
||||
else:
|
||||
logger.debug(
|
||||
str(idx) + " Predict time of %s: %.3fs" % (image_file,
|
||||
elapse))
|
||||
for text, score in rec_res:
|
||||
logger.debug("{}, {:.3f}".format(text, score))
|
||||
|
||||
logger.debug(
|
||||
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
|
||||
for text, score in rec_res:
|
||||
logger.debug("{}, {:.3f}".format(text, score))
|
||||
res = [{
|
||||
"transcription": rec_res[i][0],
|
||||
"points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
|
||||
} for i in range(len(dt_boxes))]
|
||||
if len(imgs) > 1:
|
||||
save_pred = os.path.basename(image_file) + '_' + str(
|
||||
index) + "\t" + json.dumps(
|
||||
res, ensure_ascii=False) + "\n"
|
||||
else:
|
||||
save_pred = os.path.basename(image_file) + "\t" + json.dumps(
|
||||
res, ensure_ascii=False) + "\n"
|
||||
save_results.append(save_pred)
|
||||
|
||||
res = [{
|
||||
"transcription": rec_res[idx][0],
|
||||
"points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
|
||||
} for idx in range(len(dt_boxes))]
|
||||
save_pred = os.path.basename(image_file) + "\t" + json.dumps(
|
||||
res, ensure_ascii=False) + "\n"
|
||||
save_results.append(save_pred)
|
||||
if is_visualize:
|
||||
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
boxes = dt_boxes
|
||||
txts = [rec_res[i][0] for i in range(len(rec_res))]
|
||||
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
||||
|
||||
if is_visualize:
|
||||
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
boxes = dt_boxes
|
||||
txts = [rec_res[i][0] for i in range(len(rec_res))]
|
||||
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
||||
|
||||
draw_img = draw_ocr_box_txt(
|
||||
image,
|
||||
boxes,
|
||||
txts,
|
||||
scores,
|
||||
drop_score=drop_score,
|
||||
font_path=font_path)
|
||||
if flag:
|
||||
image_file = image_file[:-3] + "png"
|
||||
cv2.imwrite(
|
||||
os.path.join(draw_img_save_dir, os.path.basename(image_file)),
|
||||
draw_img[:, :, ::-1])
|
||||
logger.debug("The visualized image saved in {}".format(
|
||||
os.path.join(draw_img_save_dir, os.path.basename(image_file))))
|
||||
draw_img = draw_ocr_box_txt(
|
||||
image,
|
||||
boxes,
|
||||
txts,
|
||||
scores,
|
||||
drop_score=drop_score,
|
||||
font_path=font_path)
|
||||
if flag_gif:
|
||||
save_file = image_file[:-3] + "png"
|
||||
elif flag_pdf:
|
||||
save_file = image_file.replace('.pdf',
|
||||
'_' + str(index) + '.png')
|
||||
else:
|
||||
save_file = image_file
|
||||
cv2.imwrite(
|
||||
os.path.join(draw_img_save_dir,
|
||||
os.path.basename(save_file)),
|
||||
draw_img[:, :, ::-1])
|
||||
logger.debug("The visualized image saved in {}".format(
|
||||
os.path.join(draw_img_save_dir, os.path.basename(
|
||||
save_file))))
|
||||
|
||||
logger.info("The predict total time is {}".format(time.time() - _st))
|
||||
if args.benchmark:
|
||||
|
|
|
@ -45,6 +45,7 @@ def init_args():
|
|||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--page_num", type=int, default=0)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_dir", type=str)
|
||||
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||
|
@ -337,8 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path):
|
|||
return src_im
|
||||
|
||||
|
||||
def draw_text_det_res(dt_boxes, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
def draw_text_det_res(dt_boxes, img_path, flag_pdf=False):
|
||||
if not flag_pdf:
|
||||
src_im = cv2.imread(img_path)
|
||||
else:
|
||||
src_im = img_path
|
||||
for box in dt_boxes:
|
||||
box = np.array(box).astype(np.int32).reshape(-1, 2)
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
|
|
Loading…
Reference in New Issue