fix: eval in fp32 by default when AMP is enabled

To evaluate in fp16 when AMP is enabled, set AMP.use_fp16_test=True; it defaults to False.
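
A minimal sketch of the relevant config keys (shown as the Python dict the code reads; the project's actual config is a file, and only the two keys below are taken from this diff):

config = {
    "AMP": {
        "level": "O1",           # AMP level, also read by the eval path
        "use_fp16_test": False,  # default: evaluate in fp32 even with AMP on
    },
}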
gaotingquan 2022-04-02 07:35:01 +00:00 committed by Tingquan Gao
parent f3af58198d
commit b761325faa
2 changed files with 5 additions and 4 deletions


@@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         # image input
-        if engine.amp:
+        if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
             amp_level = engine.config['AMP'].get("level", "O1").upper()
             with paddle.amp.auto_cast(
                 custom_black_list={

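For context, a minimal sketch of the eval-time gating this hunk introduces (eval_forward and the engine/model wiring are hypothetical; paddle.amp.auto_cast is the API the diff uses):

import paddle

def eval_forward(engine, batch):
    # Enter autocast only when AMP is enabled AND the user opted into
    # fp16 evaluation; otherwise the forward pass stays in fp32.
    if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
        amp_level = engine.config["AMP"].get("level", "O1").upper()
        with paddle.amp.auto_cast(level=amp_level):
            return engine.model(batch[0])
    return engine.model(batch[0])
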

@@ -161,12 +161,13 @@ def main(args):
     # load pretrained models or checkpoints
     init_model(global_config, train_prog, exe)
-    if 'AMP' in config and config.AMP.get("level", "O1") == "O2":
+    if 'AMP' in config:
         optimizer.amp_init(
             device,
             scope=paddle.static.global_scope(),
             test_program=eval_prog
-            if global_config["eval_during_train"] else None)
+            if global_config["eval_during_train"] else None,
+            use_fp16_test=config["AMP"].get("use_fp16_test", False))
     if not global_config.get("is_distributed", True):
         compiled_train_prog = program.compile(
@@ -182,7 +183,7 @@ def main(args):
         program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                     train_fetchs, epoch_id, 'train', config, vdl_writer,
                     lr_scheduler, args.profiler_options)
-        # 2. evaate with eval dataset
+        # 2. evaluate with eval dataset
         if global_config["eval_during_train"] and epoch_id % global_config[
                 "eval_interval"] == 0:
             top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
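
A standalone check of what the toggle changes at eval time (assumes a device where float16 AMP is supported; on unsupported hardware both dtypes stay float32):

import paddle

x = paddle.rand([2, 3])
linear = paddle.nn.Linear(3, 4)

with paddle.amp.auto_cast(level="O1"):
    y_fp16 = linear(x)  # under O1 autocast, the matmul computes in float16

y_fp32 = linear(x)      # outside autocast: plain fp32, the new eval default

print(y_fp16.dtype, y_fp32.dtype)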