diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 994eeb5ee..e9836fcbb 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         # image input
-        if engine.amp:
+        if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
             amp_level = engine.config['AMP'].get("level", "O1").upper()
             with paddle.amp.auto_cast(
                     custom_black_list={
diff --git a/ppcls/static/train.py b/ppcls/static/train.py
index dd16cdb4c..45c937625 100644
--- a/ppcls/static/train.py
+++ b/ppcls/static/train.py
@@ -161,12 +161,13 @@ def main(args):
     # load pretrained models or checkpoints
     init_model(global_config, train_prog, exe)
 
-    if 'AMP' in config and config.AMP.get("level", "O1") == "O2":
+    if 'AMP' in config:
         optimizer.amp_init(
             device,
             scope=paddle.static.global_scope(),
             test_program=eval_prog
-            if global_config["eval_during_train"] else None)
+            if global_config["eval_during_train"] else None,
+            use_fp16_test=config["AMP"].get("use_fp16_test", False))
 
     if not global_config.get("is_distributed", True):
         compiled_train_prog = program.compile(
@@ -182,7 +183,7 @@ def main(args):
         program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                     train_fetchs, epoch_id, 'train', config, vdl_writer,
                     lr_scheduler, args.profiler_options)
-        # 2. evaate with eval dataset
+        # 2. evaluate with eval dataset
         if global_config["eval_during_train"] and epoch_id % global_config[
                 "eval_interval"] == 0:
             top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
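
Note on usage: the patch gates FP16 evaluation behind a new AMP config key, use_fp16_test (default False, so evaluation stays FP32 even when training uses AMP), and calls optimizer.amp_init for any AMP level rather than only "O2". A minimal sketch of the corresponding config section, assuming the usual PaddleClas YAML layout; the scale_loss and use_dynamic_loss_scaling keys are illustrative and not introduced by this patch:

    AMP:
      scale_loss: 128.0               # illustrative loss-scaling settings, not part of this patch
      use_dynamic_loss_scaling: True  # illustrative, not part of this patch
      level: O1                       # read via config["AMP"].get("level", "O1")
      use_fp16_test: True             # new key: run evaluation under paddle.amp.auto_cast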