mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
support eval pre epoch (#11003)
This commit is contained in:
parent
e49e491417
commit
4ba32bc91c
@ -185,6 +185,7 @@ def train(config,
|
||||
eval_class,
|
||||
pre_best_model_dict,
|
||||
logger,
|
||||
step_pre_epoch,
|
||||
log_writer=None,
|
||||
scaler=None,
|
||||
amp_level='O2',
|
||||
@ -198,6 +199,7 @@ def train(config,
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
print_batch_step = config['Global']['print_batch_step']
|
||||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
eval_batch_epoch = config['Global'].get('eval_batch_epoch', None)
|
||||
profiler_options = config['profiler_options']
|
||||
|
||||
global_step = 0
|
||||
@ -205,8 +207,9 @@ def train(config,
|
||||
global_step = pre_best_model_dict['global_step']
|
||||
start_eval_step = 0
|
||||
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||
start_eval_step = eval_batch_step[0]
|
||||
eval_batch_step = eval_batch_step[1]
|
||||
start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
|
||||
eval_batch_step = eval_batch_step[
|
||||
1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch
|
||||
if len(valid_dataloader) == 0:
|
||||
logger.info(
|
||||
'No Images in eval dataset, evaluation during training ' \
|
||||
|
@ -61,9 +61,11 @@ def main(config, device, logger, vdl_writer, seed):
|
||||
return
|
||||
|
||||
if config['Eval']:
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed)
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger,
|
||||
seed)
|
||||
else:
|
||||
valid_dataloader = None
|
||||
step_pre_epoch = len(train_dataloader)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
@ -93,7 +95,8 @@ def main(config, device, logger, vdl_writer, seed):
|
||||
'DistillationSARLoss'][
|
||||
'ignore_index'] = char_num + 1
|
||||
out_channels_list['SARLabelDecode'] = char_num + 2
|
||||
elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']):
|
||||
elif any('DistillationNRTRLoss' in d
|
||||
for d in config['Loss']['loss_config_list']):
|
||||
out_channels_list['NRTRLabelDecode'] = char_num + 3
|
||||
|
||||
config['Architecture']['Models'][key]['Head'][
|
||||
@ -196,9 +199,9 @@ def main(config, device, logger, vdl_writer, seed):
|
||||
# start train
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
|
||||
amp_level, amp_custom_black_list, amp_custom_white_list,
|
||||
amp_dtype)
|
||||
eval_class, pre_best_model_dict, logger, step_pre_epoch,
|
||||
vdl_writer, scaler, amp_level, amp_custom_black_list,
|
||||
amp_custom_white_list, amp_dtype)
|
||||
|
||||
|
||||
def test_reader(config, device, logger):
|
||||
|
Loading…
x
Reference in New Issue
Block a user