update pgnet
parent
4c0b08733d
commit
0fd122b674
|
@ -58,7 +58,7 @@ PostProcess:
|
|||
name: PGPostProcess
|
||||
score_thresh: 0.5
|
||||
mode: fast # fast or slow two ways
|
||||
point_gather_mode: v3 # same as PGProcessTrain: point_gather_mode
|
||||
point_gather_mode: align # same as PGProcessTrain: point_gather_mode
|
||||
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
|
@ -85,7 +85,7 @@ Train:
|
|||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
point_gather_mode: v3 # two ways, v2 is original code, v3 is updated code
|
||||
point_gather_mode: align # two mode: align and none, align mode is better than none mode
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||
loader:
|
||||
|
|
|
@ -33,7 +33,7 @@ class PGProcessTrain(object):
|
|||
min_crop_size=24,
|
||||
min_text_size=4,
|
||||
max_text_size=512,
|
||||
point_gather_mode='v3',
|
||||
point_gather_mode=None,
|
||||
**kwargs):
|
||||
self.tcl_len = tcl_len
|
||||
self.max_text_length = max_text_length
|
||||
|
@ -531,7 +531,7 @@ class PGProcessTrain(object):
|
|||
average_shrink_height = self.calculate_average_height(
|
||||
stcl_quads)
|
||||
|
||||
if self.point_gather_mode == 'v3':
|
||||
if self.point_gather_mode == 'align':
|
||||
self.f_direction = direction_map[:, :, :-1].copy()
|
||||
pos_res = self.fit_and_gather_tcl_points_v3(
|
||||
min_area_quad,
|
||||
|
@ -545,7 +545,7 @@ class PGProcessTrain(object):
|
|||
continue
|
||||
pos_l, pos_m = pos_res[0], pos_res[1]
|
||||
|
||||
elif self.point_gather_mode == 'v2':
|
||||
else:
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
|
|
|
@ -30,8 +30,13 @@ class PGPostProcess(object):
|
|||
The post process for PGNet.
|
||||
"""
|
||||
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
|
||||
point_gather_mode, **kwargs):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
valid_set,
|
||||
score_thresh,
|
||||
mode,
|
||||
point_gather_mode=None,
|
||||
**kwargs):
|
||||
self.character_dict_path = character_dict_path
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
|
|
|
@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
|||
def instance_ctc_greedy_decoder(gather_info,
|
||||
logits_map,
|
||||
pts_num=4,
|
||||
point_gather_mode='v3'):
|
||||
point_gather_mode=None):
|
||||
_, _, C = logits_map.shape
|
||||
if point_gather_mode == 'v3':
|
||||
if point_gather_mode == 'align':
|
||||
insert_num = 0
|
||||
gather_info = np.array(gather_info)
|
||||
length = len(gather_info) - 1
|
||||
|
@ -115,6 +115,8 @@ def instance_ctc_greedy_decoder(gather_info,
|
|||
gather_info, insert_index, insert_value, axis=0)
|
||||
insert_num += insert_num_temp
|
||||
gather_info = gather_info.tolist()
|
||||
else:
|
||||
pass
|
||||
ys, xs = zip(*gather_info)
|
||||
logits_seq = logits_map[list(ys), list(xs)]
|
||||
probs_seq = logits_seq
|
||||
|
@ -130,7 +132,7 @@ def ctc_decoder_for_image(gather_info_list,
|
|||
logits_map,
|
||||
Lexicon_Table,
|
||||
pts_num=6,
|
||||
point_gather_mode='v3'):
|
||||
point_gather_mode=None):
|
||||
"""
|
||||
CTC decoder using multiple processes.
|
||||
"""
|
||||
|
@ -140,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list,
|
|||
if len(gather_info) < pts_num:
|
||||
continue
|
||||
dst_str, xys_list = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, pts_num=pts_num, point_gather_mode='v3')
|
||||
gather_info,
|
||||
logits_map,
|
||||
pts_num=pts_num,
|
||||
point_gather_mode=point_gather_mode)
|
||||
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
||||
if len(dst_str_readable) < 2:
|
||||
continue
|
||||
|
@ -383,7 +388,7 @@ def generate_pivot_list_fast(p_score,
|
|||
f_direction,
|
||||
Lexicon_Table,
|
||||
score_thresh=0.5,
|
||||
point_gather_mode='v3'):
|
||||
point_gather_mode=None):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
|
@ -414,7 +419,7 @@ def generate_pivot_list_fast(p_score,
|
|||
all_pos_yxs,
|
||||
logits_map=p_char_maps,
|
||||
Lexicon_Table=Lexicon_Table,
|
||||
point_gather_mode='v3')
|
||||
point_gather_mode=point_gather_mode)
|
||||
return keep_yxs_list, decoded_str
|
||||
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class PGNet_PostProcess(object):
|
|||
score_thresh,
|
||||
outs_dict,
|
||||
shape_list,
|
||||
point_gather_mode='v3'):
|
||||
point_gather_mode=None):
|
||||
self.Lexicon_Table = get_dict(character_dict_path)
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
|
|
Loading…
Reference in New Issue