mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
update pgnet
This commit is contained in:
parent
929b4f4557
commit
4c0b08733d
@ -33,7 +33,7 @@ Architecture:
|
|||||||
name: PGFPN
|
name: PGFPN
|
||||||
Head:
|
Head:
|
||||||
name: PGHead
|
name: PGHead
|
||||||
tcc_channels: 37 # the length of character dict
|
character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path
|
||||||
|
|
||||||
Loss:
|
Loss:
|
||||||
name: PGLoss
|
name: PGLoss
|
||||||
@ -58,7 +58,7 @@ PostProcess:
|
|||||||
name: PGPostProcess
|
name: PGPostProcess
|
||||||
score_thresh: 0.5
|
score_thresh: 0.5
|
||||||
mode: fast # fast or slow two ways
|
mode: fast # fast or slow two ways
|
||||||
tcc_type: v3 # same as PGProcessTrain: tcc_type
|
point_gather_mode: v3 # same as PGProcessTrain: point_gather_mode
|
||||||
|
|
||||||
Metric:
|
Metric:
|
||||||
name: E2EMetric
|
name: E2EMetric
|
||||||
@ -85,7 +85,7 @@ Train:
|
|||||||
min_crop_size: 24
|
min_crop_size: 24
|
||||||
min_text_size: 4
|
min_text_size: 4
|
||||||
max_text_size: 512
|
max_text_size: 512
|
||||||
tcc_type: v3 # two ways, v2 is original code, v3 is updated code
|
point_gather_mode: v3 # two ways, v2 is original code, v3 is updated code
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||||
loader:
|
loader:
|
||||||
|
@ -33,7 +33,7 @@ class PGProcessTrain(object):
|
|||||||
min_crop_size=24,
|
min_crop_size=24,
|
||||||
min_text_size=4,
|
min_text_size=4,
|
||||||
max_text_size=512,
|
max_text_size=512,
|
||||||
tcc_type='v3',
|
point_gather_mode='v3',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.tcl_len = tcl_len
|
self.tcl_len = tcl_len
|
||||||
self.max_text_length = max_text_length
|
self.max_text_length = max_text_length
|
||||||
@ -45,7 +45,7 @@ class PGProcessTrain(object):
|
|||||||
self.min_text_size = min_text_size
|
self.min_text_size = min_text_size
|
||||||
self.max_text_size = max_text_size
|
self.max_text_size = max_text_size
|
||||||
self.use_resize = use_resize
|
self.use_resize = use_resize
|
||||||
self.tcc_type = tcc_type
|
self.point_gather_mode = point_gather_mode
|
||||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||||
self.pad_num = len(self.Lexicon_Table)
|
self.pad_num = len(self.Lexicon_Table)
|
||||||
self.img_id = 0
|
self.img_id = 0
|
||||||
@ -531,7 +531,7 @@ class PGProcessTrain(object):
|
|||||||
average_shrink_height = self.calculate_average_height(
|
average_shrink_height = self.calculate_average_height(
|
||||||
stcl_quads)
|
stcl_quads)
|
||||||
|
|
||||||
if self.tcc_type == 'v3':
|
if self.point_gather_mode == 'v3':
|
||||||
self.f_direction = direction_map[:, :, :-1].copy()
|
self.f_direction = direction_map[:, :, :-1].copy()
|
||||||
pos_res = self.fit_and_gather_tcl_points_v3(
|
pos_res = self.fit_and_gather_tcl_points_v3(
|
||||||
min_area_quad,
|
min_area_quad,
|
||||||
@ -545,7 +545,7 @@ class PGProcessTrain(object):
|
|||||||
continue
|
continue
|
||||||
pos_l, pos_m = pos_res[0], pos_res[1]
|
pos_l, pos_m = pos_res[0], pos_res[1]
|
||||||
|
|
||||||
elif self.tcc_type == 'v2':
|
elif self.point_gather_mode == 'v2':
|
||||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||||
min_area_quad,
|
min_area_quad,
|
||||||
poly,
|
poly,
|
||||||
|
@ -66,8 +66,17 @@ class PGHead(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, tcc_channels=37, **kwargs):
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
character_dict_path='ppocr/utils/ic15_dict.txt',
|
||||||
|
**kwargs):
|
||||||
super(PGHead, self).__init__()
|
super(PGHead, self).__init__()
|
||||||
|
|
||||||
|
# get character_length
|
||||||
|
with open(character_dict_path, "rb") as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
character_length = len(lines) + 1
|
||||||
|
|
||||||
self.conv_f_score1 = ConvBNLayer(
|
self.conv_f_score1 = ConvBNLayer(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=64,
|
out_channels=64,
|
||||||
@ -178,7 +187,7 @@ class PGHead(nn.Layer):
|
|||||||
name="conv_f_char{}".format(5))
|
name="conv_f_char{}".format(5))
|
||||||
self.conv3 = nn.Conv2D(
|
self.conv3 = nn.Conv2D(
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
out_channels=tcc_channels,
|
out_channels=character_length,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1,
|
padding=1,
|
||||||
|
@ -31,12 +31,12 @@ class PGPostProcess(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
|
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
|
||||||
tcc_type, **kwargs):
|
point_gather_mode, **kwargs):
|
||||||
self.character_dict_path = character_dict_path
|
self.character_dict_path = character_dict_path
|
||||||
self.valid_set = valid_set
|
self.valid_set = valid_set
|
||||||
self.score_thresh = score_thresh
|
self.score_thresh = score_thresh
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.tcc_type = tcc_type
|
self.point_gather_mode = point_gather_mode
|
||||||
|
|
||||||
# c++ la-nms is faster, but only support python 3.5
|
# c++ la-nms is faster, but only support python 3.5
|
||||||
self.is_python35 = False
|
self.is_python35 = False
|
||||||
@ -50,7 +50,7 @@ class PGPostProcess(object):
|
|||||||
self.score_thresh,
|
self.score_thresh,
|
||||||
outs_dict,
|
outs_dict,
|
||||||
shape_list,
|
shape_list,
|
||||||
tcc_type=self.tcc_type)
|
point_gather_mode=self.point_gather_mode)
|
||||||
if self.mode == 'fast':
|
if self.mode == 'fast':
|
||||||
data = post.pg_postprocess_fast()
|
data = post.pg_postprocess_fast()
|
||||||
else:
|
else:
|
||||||
|
@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
|||||||
def instance_ctc_greedy_decoder(gather_info,
|
def instance_ctc_greedy_decoder(gather_info,
|
||||||
logits_map,
|
logits_map,
|
||||||
pts_num=4,
|
pts_num=4,
|
||||||
tcc_type='v3'):
|
point_gather_mode='v3'):
|
||||||
_, _, C = logits_map.shape
|
_, _, C = logits_map.shape
|
||||||
if tcc_type == 'v3':
|
if point_gather_mode == 'v3':
|
||||||
insert_num = 0
|
insert_num = 0
|
||||||
gather_info = np.array(gather_info)
|
gather_info = np.array(gather_info)
|
||||||
length = len(gather_info) - 1
|
length = len(gather_info) - 1
|
||||||
@ -130,7 +130,7 @@ def ctc_decoder_for_image(gather_info_list,
|
|||||||
logits_map,
|
logits_map,
|
||||||
Lexicon_Table,
|
Lexicon_Table,
|
||||||
pts_num=6,
|
pts_num=6,
|
||||||
tcc_type='v3'):
|
point_gather_mode='v3'):
|
||||||
"""
|
"""
|
||||||
CTC decoder using multiple processes.
|
CTC decoder using multiple processes.
|
||||||
"""
|
"""
|
||||||
@ -140,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list,
|
|||||||
if len(gather_info) < pts_num:
|
if len(gather_info) < pts_num:
|
||||||
continue
|
continue
|
||||||
dst_str, xys_list = instance_ctc_greedy_decoder(
|
dst_str, xys_list = instance_ctc_greedy_decoder(
|
||||||
gather_info, logits_map, pts_num=pts_num, tcc_type='v3')
|
gather_info, logits_map, pts_num=pts_num, point_gather_mode='v3')
|
||||||
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
||||||
if len(dst_str_readable) < 2:
|
if len(dst_str_readable) < 2:
|
||||||
continue
|
continue
|
||||||
@ -383,7 +383,7 @@ def generate_pivot_list_fast(p_score,
|
|||||||
f_direction,
|
f_direction,
|
||||||
Lexicon_Table,
|
Lexicon_Table,
|
||||||
score_thresh=0.5,
|
score_thresh=0.5,
|
||||||
tcc_type='v3'):
|
point_gather_mode='v3'):
|
||||||
"""
|
"""
|
||||||
return center point and end point of TCL instance; filter with the char maps;
|
return center point and end point of TCL instance; filter with the char maps;
|
||||||
"""
|
"""
|
||||||
@ -414,7 +414,7 @@ def generate_pivot_list_fast(p_score,
|
|||||||
all_pos_yxs,
|
all_pos_yxs,
|
||||||
logits_map=p_char_maps,
|
logits_map=p_char_maps,
|
||||||
Lexicon_Table=Lexicon_Table,
|
Lexicon_Table=Lexicon_Table,
|
||||||
tcc_type='v3')
|
point_gather_mode='v3')
|
||||||
return keep_yxs_list, decoded_str
|
return keep_yxs_list, decoded_str
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,13 +34,13 @@ class PGNet_PostProcess(object):
|
|||||||
score_thresh,
|
score_thresh,
|
||||||
outs_dict,
|
outs_dict,
|
||||||
shape_list,
|
shape_list,
|
||||||
tcc_type='v3'):
|
point_gather_mode='v3'):
|
||||||
self.Lexicon_Table = get_dict(character_dict_path)
|
self.Lexicon_Table = get_dict(character_dict_path)
|
||||||
self.valid_set = valid_set
|
self.valid_set = valid_set
|
||||||
self.score_thresh = score_thresh
|
self.score_thresh = score_thresh
|
||||||
self.outs_dict = outs_dict
|
self.outs_dict = outs_dict
|
||||||
self.shape_list = shape_list
|
self.shape_list = shape_list
|
||||||
self.tcc_type = tcc_type
|
self.point_gather_mode = point_gather_mode
|
||||||
|
|
||||||
def pg_postprocess_fast(self):
|
def pg_postprocess_fast(self):
|
||||||
p_score = self.outs_dict['f_score']
|
p_score = self.outs_dict['f_score']
|
||||||
@ -65,7 +65,7 @@ class PGNet_PostProcess(object):
|
|||||||
p_direction,
|
p_direction,
|
||||||
self.Lexicon_Table,
|
self.Lexicon_Table,
|
||||||
score_thresh=self.score_thresh,
|
score_thresh=self.score_thresh,
|
||||||
tcc_type=self.tcc_type)
|
point_gather_mode=self.point_gather_mode)
|
||||||
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
|
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
|
||||||
p_border, ratio_w, ratio_h,
|
p_border, ratio_w, ratio_h,
|
||||||
src_w, src_h, self.valid_set)
|
src_w, src_h, self.valid_set)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user