fix ips info and reduce interval of metric calc
parent cffb031658
commit 229acda856
@@ -31,7 +31,8 @@ class CTCLoss(nn.Layer):
             predicts = predicts[-1]
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
-        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
+        preds_lengths = paddle.to_tensor(
+            [N] * B, dtype='int64', place=paddle.CPUPlace())
         labels = batch[1].astype("int32")
         label_lengths = batch[2].astype('int64')
         loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
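The only functional change in this hunk is that the CTC sequence-length tensor is now created directly on the CPU. A minimal sketch of the effect, with example values for N and B (assumptions, not from the diff): the tensor is tiny and rebuilt from a Python list on every forward pass, so keeping it in host memory presumably saves a per-step host-to-device copy.

import paddle

N, B = 80, 32  # assumed example: max time steps, batch size

# Before: the lengths tensor lands on the default device (GPU when available).
lengths_default = paddle.to_tensor([N] * B, dtype='int64')

# After: the lengths tensor stays in host memory.
lengths_cpu = paddle.to_tensor(
    [N] * B, dtype='int64', place=paddle.CPUPlace())
print(lengths_cpu.place)  # CPUPlace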
@@ -146,6 +146,7 @@ def train(config,
           scaler=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
+    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
     log_smooth_window = config['Global']['log_smooth_window']
     epoch_num = config['Global']['epoch_num']
     print_batch_step = config['Global']['print_batch_step']
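The new calc_epoch_interval knob is read with a default of 1, so configs that do not define it keep the old every-epoch cadence. A short sketch of how it would be set and read; the YAML layout in the comment is my assumption based on PaddleOCR's usual Global section:

# Hypothetical Global section of a training config:
#
#   Global:
#     cal_metric_during_train: True
#     calc_epoch_interval: 5    # compute training metrics every 5th epoch
#
config = {'Global': {'cal_metric_during_train': True}}  # key left unset
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
assert calc_epoch_interval == 1  # missing key falls back to every epoch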
@@ -244,6 +245,16 @@ def train(config,
             optimizer.step()
             optimizer.clear_grad()

+            if cal_metric_during_train and epoch % calc_epoch_interval == 0:  # only rec and cls need
+                batch = [item.numpy() for item in batch]
+                if model_type in ['table', 'kie']:
+                    eval_class(preds, batch)
+                else:
+                    post_result = post_process_class(preds, batch[1])
+                    eval_class(post_result, batch)
+                metric = eval_class.get_metric()
+                train_stats.update(metric)
+
             train_batch_time = time.time() - reader_start
             train_batch_cost += train_batch_time
             eta_meter.update(train_batch_time)
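Two things change in this hunk: the in-training metric block is now gated on epoch % calc_epoch_interval == 0, and it sits before train_batch_time is taken, so its cost is counted in the reported batch time and ips, which appears to be the "fix ips info" part of the commit title. A toy illustration of the gate, with hypothetical values:

cal_metric_during_train = True
calc_epoch_interval = 3

for epoch in range(1, 8):
    if cal_metric_during_train and epoch % calc_epoch_interval == 0:
        # runs on epochs 3 and 6 only; every epoch when the interval is 1
        print(f"epoch {epoch}: computing training metrics")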
@@ -258,16 +269,6 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)

-            if cal_metric_during_train:  # only rec and cls need
-                batch = [item.numpy() for item in batch]
-                if model_type in ['table', 'kie']:
-                    eval_class(preds, batch)
-                else:
-                    post_result = post_process_class(preds, batch[1])
-                    eval_class(post_result, batch)
-                metric = eval_class.get_metric()
-                train_stats.update(metric)
-
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)