all ready
parent
297871d4be
commit
93670ab5a2
|
@ -3,7 +3,7 @@ Global:
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 5
|
print_batch_step: 5
|
||||||
save_model_dir: ./output/rec/srn
|
save_model_dir: ./output/rec/srn_new
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
eval_batch_step: [0, 5000]
|
eval_batch_step: [0, 5000]
|
||||||
|
@ -25,8 +25,10 @@ Global:
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
name: Adam
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
clip_norm: 10.0
|
||||||
lr:
|
lr:
|
||||||
name: Cosine
|
|
||||||
learning_rate: 0.0001
|
learning_rate: 0.0001
|
||||||
|
|
||||||
Architecture:
|
Architecture:
|
||||||
|
@ -58,7 +60,6 @@ Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDataSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/srn_train_data_duiqi
|
data_dir: ./train_data/srn_train_data_duiqi
|
||||||
#label_file_list: ["./train_data/ic15_data/1.txt"]
|
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
img_mode: BGR
|
img_mode: BGR
|
||||||
|
@ -77,7 +78,7 @@ Train:
|
||||||
loader:
|
loader:
|
||||||
shuffle: False
|
shuffle: False
|
||||||
batch_size_per_card: 64
|
batch_size_per_card: 64
|
||||||
drop_last: True
|
drop_last: False
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
|
|
|
@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer):
|
||||||
self.emb0 = paddle.nn.Embedding(
|
self.emb0 = paddle.nn.Embedding(
|
||||||
num_embeddings=src_vocab_size,
|
num_embeddings=src_vocab_size,
|
||||||
embedding_dim=self.src_emb_dim,
|
embedding_dim=self.src_emb_dim,
|
||||||
|
padding_idx=bos_idx,
|
||||||
weight_attr=paddle.ParamAttr(
|
weight_attr=paddle.ParamAttr(
|
||||||
name=word_emb_param_name,
|
name=word_emb_param_name,
|
||||||
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
||||||
|
|
|
@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
|
|
||||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||||
|
|
||||||
text = self.decode(preds_idx, preds_prob)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||||
|
|
||||||
if label is None:
|
if label is None:
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
return text
|
return text
|
||||||
label = self.decode(label, is_remove_duplicate=False)
|
label = self.decode(label, is_remove_duplicate=True)
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=True):
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
""" convert text-index into text-label. """
|
""" convert text-index into text-label. """
|
||||||
result_list = []
|
result_list = []
|
||||||
ignored_tokens = self.get_ignored_tokens()
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
|
|
@ -242,6 +242,12 @@ def train(config,
|
||||||
# eval
|
# eval
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
|
model_average = paddle.optimizer.ModelAverage(
|
||||||
|
0.15,
|
||||||
|
parameters=model.parameters(),
|
||||||
|
min_average_window=10000,
|
||||||
|
max_average_window=15625)
|
||||||
|
model_average.apply()
|
||||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
||||||
|
@ -277,6 +283,7 @@ def train(config,
|
||||||
best_model_dict[main_indicator],
|
best_model_dict[main_indicator],
|
||||||
global_step)
|
global_step)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
optimizer.clear_grad()
|
||||||
batch_start = time.time()
|
batch_start = time.time()
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
save_model(
|
save_model(
|
||||||
|
|
Loading…
Reference in New Issue