support eval pre epoch (#11003)

This commit is contained in:
zhangyubo0722 2023-09-26 18:50:42 +08:00 committed by GitHub
parent e49e491417
commit 4ba32bc91c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions

View File

@ -185,6 +185,7 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
step_pre_epoch,
log_writer=None,
scaler=None,
amp_level='O2',
@ -198,6 +199,7 @@ def train(config,
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
eval_batch_epoch = config['Global'].get('eval_batch_epoch', None)
profiler_options = config['profiler_options']
global_step = 0
@ -205,8 +207,9 @@ def train(config,
global_step = pre_best_model_dict['global_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
eval_batch_step = eval_batch_step[
1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch
if len(valid_dataloader) == 0:
logger.info(
'No Images in eval dataset, evaluation during training ' \

View File

@ -61,9 +61,11 @@ def main(config, device, logger, vdl_writer, seed):
return
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed)
valid_dataloader = build_dataloader(config, 'Eval', device, logger,
seed)
else:
valid_dataloader = None
step_pre_epoch = len(train_dataloader)
# build post process
post_process_class = build_post_process(config['PostProcess'],
@ -93,7 +95,8 @@ def main(config, device, logger, vdl_writer, seed):
'DistillationSARLoss'][
'ignore_index'] = char_num + 1
out_channels_list['SARLabelDecode'] = char_num + 2
elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']):
elif any('DistillationNRTRLoss' in d
for d in config['Loss']['loss_config_list']):
out_channels_list['NRTRLabelDecode'] = char_num + 3
config['Architecture']['Models'][key]['Head'][
@ -196,9 +199,9 @@ def main(config, device, logger, vdl_writer, seed):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
amp_level, amp_custom_black_list, amp_custom_white_list,
amp_dtype)
eval_class, pre_best_model_dict, logger, step_pre_epoch,
vdl_writer, scaler, amp_level, amp_custom_black_list,
amp_custom_white_list, amp_dtype)
def test_reader(config, device, logger):