update pgnet
parent
04aaaa748f
commit
929b4f4557
|
@ -13,6 +13,7 @@ Global:
|
|||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
infer_visual_type: EN # two mode: EN is for english datasets, CN is for chinese datasets
|
||||
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
|
||||
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
|
@ -32,6 +33,7 @@ Architecture:
|
|||
name: PGFPN
|
||||
Head:
|
||||
name: PGHead
|
||||
tcc_channels: 37 # the length of character dict
|
||||
|
||||
Loss:
|
||||
name: PGLoss
|
||||
|
@ -45,16 +47,18 @@ Optimizer:
|
|||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 50
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
factor: 0.0001
|
||||
|
||||
PostProcess:
|
||||
name: PGPostProcess
|
||||
score_thresh: 0.5
|
||||
mode: fast # fast or slow two ways
|
||||
tcc_type: v3 # same as PGProcessTrain: tcc_type
|
||||
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
|
@ -76,9 +80,12 @@ Train:
|
|||
- E2ELabelEncodeTrain:
|
||||
- PGProcessTrain:
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
use_resize: True
|
||||
use_random_crop: False
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
tcc_type: v3 # two ways, v2 is original code, v3 is updated code
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||
loader:
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage.morphology._skeletonize import thin
|
||||
from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
|
||||
|
||||
__all__ = ['PGProcessTrain']
|
||||
|
||||
|
@ -26,17 +28,24 @@ class PGProcessTrain(object):
|
|||
max_text_nums,
|
||||
tcl_len,
|
||||
batch_size=14,
|
||||
use_resize=True,
|
||||
use_random_crop=False,
|
||||
min_crop_size=24,
|
||||
min_text_size=4,
|
||||
max_text_size=512,
|
||||
tcc_type='v3',
|
||||
**kwargs):
|
||||
self.tcl_len = tcl_len
|
||||
self.max_text_length = max_text_length
|
||||
self.max_text_nums = max_text_nums
|
||||
self.batch_size = batch_size
|
||||
self.min_crop_size = min_crop_size
|
||||
if use_random_crop is True:
|
||||
self.min_crop_size = min_crop_size
|
||||
self.use_random_crop = use_random_crop
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
self.use_resize = use_resize
|
||||
self.tcc_type = tcc_type
|
||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||
self.pad_num = len(self.Lexicon_Table)
|
||||
self.img_id = 0
|
||||
|
@ -282,6 +291,95 @@ class PGProcessTrain(object):
|
|||
pos_m[:keep] = 1.0
|
||||
return pos_l, pos_m
|
||||
|
||||
def fit_and_gather_tcl_points_v3(self,
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h,
|
||||
max_w,
|
||||
fixed_point_num=64,
|
||||
img_id=0,
|
||||
reference_height=3):
|
||||
"""
|
||||
Find the center point of poly as key_points, then fit and gather.
|
||||
"""
|
||||
det_mask = np.zeros((int(max_h / self.ds_ratio),
|
||||
int(max_w / self.ds_ratio))).astype(np.float32)
|
||||
|
||||
# score_big_map
|
||||
cv2.fillPoly(det_mask,
|
||||
np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
|
||||
det_mask = cv2.resize(
|
||||
det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
|
||||
det_mask = np.array(det_mask > 1e-3, dtype='float32')
|
||||
|
||||
f_direction = self.f_direction
|
||||
skeleton_map = thin(det_mask.astype(np.uint8))
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
ys, xs = np.where(instance_label_map == 1)
|
||||
pos_list = list(zip(ys, xs))
|
||||
if len(pos_list) < 3:
|
||||
return None
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, det_mask)
|
||||
|
||||
pos_list_sorted = np.array(pos_list_sorted)
|
||||
length = len(pos_list_sorted) - 1
|
||||
insert_num = 0
|
||||
for index in range(length):
|
||||
stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
|
||||
pos_list_sorted[index + 1 + insert_num][0])
|
||||
stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
|
||||
pos_list_sorted[index + 1 + insert_num][1])
|
||||
max_points = int(max(stride_x, stride_y))
|
||||
|
||||
stride = (pos_list_sorted[index + insert_num] -
|
||||
pos_list_sorted[index + 1 + insert_num]) / (max_points)
|
||||
insert_num_temp = max_points - 1
|
||||
|
||||
for i in range(int(insert_num_temp)):
|
||||
insert_value = pos_list_sorted[index + insert_num] - (i + 1
|
||||
) * stride
|
||||
insert_index = index + i + 1 + insert_num
|
||||
pos_list_sorted = np.insert(
|
||||
pos_list_sorted, insert_index, insert_value, axis=0)
|
||||
insert_num += insert_num_temp
|
||||
|
||||
pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
|
||||
np.float32) # xy-> yx
|
||||
|
||||
point_num = len(pos_info)
|
||||
if point_num > fixed_point_num:
|
||||
keep_ids = [
|
||||
int((point_num * 1.0 / fixed_point_num) * x)
|
||||
for x in range(fixed_point_num)
|
||||
]
|
||||
pos_info = pos_info[keep_ids, :]
|
||||
|
||||
keep = int(min(len(pos_info), fixed_point_num))
|
||||
reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
|
||||
np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
|
||||
if np.random.rand() < 1:
|
||||
dh = (np.random.rand(keep) - 0.5) * reference_height
|
||||
offset = np.random.rand() - 0.5
|
||||
dw = np.array([[0, offset * reference_width * 0.2]])
|
||||
random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
|
||||
[keep, 1])
|
||||
random_float_w = dw.repeat(keep, axis=0)
|
||||
pos_info += random_float_h
|
||||
pos_info += random_float_w
|
||||
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
||||
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
||||
|
||||
# padding to fixed length
|
||||
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
||||
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
||||
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
||||
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
||||
pos_m[:keep] = 1.0
|
||||
return pos_l, pos_m
|
||||
|
||||
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
||||
"""
|
||||
"""
|
||||
|
@ -334,6 +432,7 @@ class PGProcessTrain(object):
|
|||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
self.ds_ratio = ds_ratio
|
||||
score_map_big = np.zeros(
|
||||
(
|
||||
h,
|
||||
|
@ -384,7 +483,6 @@ class PGProcessTrain(object):
|
|||
text_label = text_strs[poly_idx]
|
||||
text_label = self.prepare_text_label(text_label,
|
||||
self.Lexicon_Table)
|
||||
|
||||
text_label_index_list = [[self.Lexicon_Table.index(c_)]
|
||||
for c_ in text_label
|
||||
if c_ in self.Lexicon_Table]
|
||||
|
@ -432,14 +530,30 @@ class PGProcessTrain(object):
|
|||
# pos info
|
||||
average_shrink_height = self.calculate_average_height(
|
||||
stcl_quads)
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
|
||||
if self.tcc_type == 'v3':
|
||||
self.f_direction = direction_map[:, :, :-1].copy()
|
||||
pos_res = self.fit_and_gather_tcl_points_v3(
|
||||
min_area_quad,
|
||||
stcl_quads,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
if pos_res is None:
|
||||
continue
|
||||
pos_l, pos_m = pos_res[0], pos_res[1]
|
||||
|
||||
elif self.tcc_type == 'v2':
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
|
||||
label_l = text_label_index_list
|
||||
if len(text_label_index_list) < 2:
|
||||
|
@ -770,27 +884,41 @@ class PGProcessTrain(object):
|
|||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
if self.use_resize is True:
|
||||
ori_h, ori_w, _ = im.shape
|
||||
if max(ori_h, ori_w) < 200:
|
||||
ratio = 200 / max(ori_h, ori_w)
|
||||
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
|
||||
text_polys[:, :, 0] *= ratio
|
||||
text_polys[:, :, 1] *= ratio
|
||||
|
||||
# no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
||||
im,
|
||||
text_polys,
|
||||
text_tags,
|
||||
hv_tags,
|
||||
text_strs,
|
||||
crop_background=False)
|
||||
if max(ori_h, ori_w) > 512:
|
||||
ratio = 512 / max(ori_h, ori_w)
|
||||
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
|
||||
text_polys[:, :, 0] *= ratio
|
||||
text_polys[:, :, 1] *= ratio
|
||||
elif self.use_random_crop is True:
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
# no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
||||
im,
|
||||
text_polys,
|
||||
text_tags,
|
||||
hv_tags,
|
||||
text_strs,
|
||||
crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
# # continue for all ignore case
|
||||
# continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
|
|
|
@ -89,12 +89,13 @@ class PGLoss(nn.Layer):
|
|||
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
|
||||
tcl_pos = paddle.cast(tcl_pos, dtype=int)
|
||||
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
|
||||
f_tcl_char = paddle.reshape(f_tcl_char,
|
||||
[-1, 64, 37]) # len(Lexicon_Table)+1
|
||||
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
|
||||
f_tcl_char = paddle.reshape(
|
||||
f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1
|
||||
f_tcl_char_fg, f_tcl_char_bg = paddle.split(
|
||||
f_tcl_char, [self.pad_num, 1], axis=2)
|
||||
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
|
||||
b, c, l = tcl_mask.shape
|
||||
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
|
||||
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l])
|
||||
tcl_mask_fg.stop_gradient = True
|
||||
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
|
||||
-20.0)
|
||||
|
|
|
@ -66,7 +66,7 @@ class PGHead(nn.Layer):
|
|||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
def __init__(self, in_channels, tcc_channels=37, **kwargs):
|
||||
super(PGHead, self).__init__()
|
||||
self.conv_f_score1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
|
@ -178,7 +178,7 @@ class PGHead(nn.Layer):
|
|||
name="conv_f_char{}".format(5))
|
||||
self.conv3 = nn.Conv2D(
|
||||
in_channels=256,
|
||||
out_channels=37,
|
||||
out_channels=tcc_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
|
|
|
@ -31,11 +31,12 @@ class PGPostProcess(object):
|
|||
"""
|
||||
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
|
||||
**kwargs):
|
||||
tcc_type, **kwargs):
|
||||
self.character_dict_path = character_dict_path
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
self.mode = mode
|
||||
self.tcc_type = tcc_type
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
|
@ -43,8 +44,13 @@ class PGPostProcess(object):
|
|||
self.is_python35 = True
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
|
||||
self.score_thresh, outs_dict, shape_list)
|
||||
post = PGNet_PostProcess(
|
||||
self.character_dict_path,
|
||||
self.valid_set,
|
||||
self.score_thresh,
|
||||
outs_dict,
|
||||
shape_list,
|
||||
tcc_type=self.tcc_type)
|
||||
if self.mode == 'fast':
|
||||
data = post.pg_postprocess_fast()
|
||||
else:
|
||||
|
|
|
@ -88,8 +88,33 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
|||
return dst_str, keep_idx_list
|
||||
|
||||
|
||||
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
|
||||
def instance_ctc_greedy_decoder(gather_info,
|
||||
logits_map,
|
||||
pts_num=4,
|
||||
tcc_type='v3'):
|
||||
_, _, C = logits_map.shape
|
||||
if tcc_type == 'v3':
|
||||
insert_num = 0
|
||||
gather_info = np.array(gather_info)
|
||||
length = len(gather_info) - 1
|
||||
for index in range(length):
|
||||
stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
|
||||
index + 1 + insert_num][0])
|
||||
stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
|
||||
index + 1 + insert_num][1])
|
||||
max_points = int(max(stride_x, stride_y))
|
||||
stride = (gather_info[index + insert_num] -
|
||||
gather_info[index + 1 + insert_num]) / (max_points)
|
||||
insert_num_temp = max_points - 1
|
||||
|
||||
for i in range(int(insert_num_temp)):
|
||||
insert_value = gather_info[index + insert_num] - (i + 1
|
||||
) * stride
|
||||
insert_index = index + i + 1 + insert_num
|
||||
gather_info = np.insert(
|
||||
gather_info, insert_index, insert_value, axis=0)
|
||||
insert_num += insert_num_temp
|
||||
gather_info = gather_info.tolist()
|
||||
ys, xs = zip(*gather_info)
|
||||
logits_seq = logits_map[list(ys), list(xs)]
|
||||
probs_seq = logits_seq
|
||||
|
@ -104,7 +129,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
|
|||
def ctc_decoder_for_image(gather_info_list,
|
||||
logits_map,
|
||||
Lexicon_Table,
|
||||
pts_num=6):
|
||||
pts_num=6,
|
||||
tcc_type='v3'):
|
||||
"""
|
||||
CTC decoder using multiple processes.
|
||||
"""
|
||||
|
@ -114,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list,
|
|||
if len(gather_info) < pts_num:
|
||||
continue
|
||||
dst_str, xys_list = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, pts_num=pts_num)
|
||||
gather_info, logits_map, pts_num=pts_num, tcc_type='v3')
|
||||
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
||||
if len(dst_str_readable) < 2:
|
||||
continue
|
||||
|
@ -356,7 +382,8 @@ def generate_pivot_list_fast(p_score,
|
|||
p_char_maps,
|
||||
f_direction,
|
||||
Lexicon_Table,
|
||||
score_thresh=0.5):
|
||||
score_thresh=0.5,
|
||||
tcc_type='v3'):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
|
@ -384,7 +411,10 @@ def generate_pivot_list_fast(p_score,
|
|||
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decoded_str, keep_yxs_list = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
|
||||
all_pos_yxs,
|
||||
logits_map=p_char_maps,
|
||||
Lexicon_Table=Lexicon_Table,
|
||||
tcc_type='v3')
|
||||
return keep_yxs_list, decoded_str
|
||||
|
||||
|
||||
|
|
|
@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
|
|||
|
||||
class PGNet_PostProcess(object):
|
||||
# two different post-process
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
|
||||
shape_list):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
valid_set,
|
||||
score_thresh,
|
||||
outs_dict,
|
||||
shape_list,
|
||||
tcc_type='v3'):
|
||||
self.Lexicon_Table = get_dict(character_dict_path)
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
self.outs_dict = outs_dict
|
||||
self.shape_list = shape_list
|
||||
self.tcc_type = tcc_type
|
||||
|
||||
def pg_postprocess_fast(self):
|
||||
p_score = self.outs_dict['f_score']
|
||||
|
@ -58,7 +64,8 @@ class PGNet_PostProcess(object):
|
|||
p_char,
|
||||
p_direction,
|
||||
self.Lexicon_Table,
|
||||
score_thresh=self.score_thresh)
|
||||
score_thresh=self.score_thresh,
|
||||
tcc_type=self.tcc_type)
|
||||
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
|
||||
p_border, ratio_w, ratio_h,
|
||||
src_w, src_h, self.valid_set)
|
||||
|
|
|
@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process
|
|||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import math
|
||||
|
||||
|
||||
def draw_e2e_res_for_chinese(image,
|
||||
boxes,
|
||||
txts,
|
||||
config,
|
||||
img_name,
|
||||
font_path="./doc/simfang.ttf"):
|
||||
h, w = image.height, image.width
|
||||
img_left = image.copy()
|
||||
img_right = Image.new('RGB', (w, h), (255, 255, 255))
|
||||
|
||||
import random
|
||||
|
||||
random.seed(0)
|
||||
draw_left = ImageDraw.Draw(img_left)
|
||||
draw_right = ImageDraw.Draw(img_right)
|
||||
for idx, (box, txt) in enumerate(zip(boxes, txts)):
|
||||
box = np.array(box)
|
||||
box = [tuple(x) for x in box]
|
||||
color = (random.randint(0, 255), random.randint(0, 255),
|
||||
random.randint(0, 255))
|
||||
draw_left.polygon(box, fill=color)
|
||||
draw_right.polygon(box, outline=color)
|
||||
font = ImageFont.truetype(font_path, 15, encoding="utf-8")
|
||||
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
||||
img_left = Image.blend(image, img_left, 0.5)
|
||||
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
|
||||
img_show.paste(img_left, (0, 0, w, h))
|
||||
img_show.paste(img_right, (w, 0, w * 2, h))
|
||||
|
||||
save_e2e_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/e2e_results/"
|
||||
if not os.path.exists(save_e2e_path):
|
||||
os.makedirs(save_e2e_path)
|
||||
save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
|
||||
cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
|
||||
logger.info("The e2e Image saved in {}".format(save_path))
|
||||
|
||||
|
||||
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
|
||||
|
@ -113,7 +153,19 @@ def main():
|
|||
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||
fout.write(otstr.encode())
|
||||
src_img = cv2.imread(file)
|
||||
draw_e2e_res(points, strs, config, src_img, file)
|
||||
if global_config['infer_visual_type'] == 'EN':
|
||||
draw_e2e_res(points, strs, config, src_img, file)
|
||||
elif global_config['infer_visual_type'] == 'CN':
|
||||
src_img = Image.fromarray(
|
||||
cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
|
||||
draw_e2e_res_for_chinese(
|
||||
src_img,
|
||||
points,
|
||||
strs,
|
||||
config,
|
||||
file,
|
||||
font_path="./doc/fonts/simfang.ttf")
|
||||
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue