Mirror of https://github.com/PaddlePaddle/PaddleClas.git (synced 2025-06-03 21:55:06 +08:00)
fix: fp32 eval by default when AMP is enabled

To evaluate in fp16 when AMP is enabled, set AMP.use_fp16_test=True; it defaults to False.
commit b761325faa (parent f3af58198d)
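The switch lives in the AMP section of the training config. Below is a minimal sketch of how the code reads it, with the parsed config as a plain dict; keys other than use_fp16_test are only illustrative of a typical AMP section.

# Parsed AMP section of a PaddleClas-style config; only use_fp16_test is
# new in this commit, the other keys are illustrative.
config = {
    "AMP": {
        "scale_loss": 128.0,
        "use_dynamic_loss_scaling": True,
        "level": "O1",
        "use_fp16_test": False,  # the new default: evaluate in fp32
    }
}

# The flag is read defensively, so configs written before this commit
# keep working and simply fall back to fp32 evaluation.
use_fp16_test = config["AMP"].get("use_fp16_test", False)
print(use_fp16_test)  # False unless explicitly enabled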
@@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0):
         batch[1] = batch[1].reshape([-1, 1]).astype("int64")

         # image input
-        if engine.amp:
+        if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
             amp_level = engine.config['AMP'].get("level", "O1").upper()
             with paddle.amp.auto_cast(
                     custom_black_list={
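In the dynamic-graph evaluator, the effect of the hunk above is that AMP alone no longer puts evaluation under auto_cast; use_fp16_test must also be set. A minimal sketch of the guarded forward pass follows, with model and data as hypothetical stand-ins for engine.model and a batch, and with a placeholder black-list (the diff truncates the real custom_black_list contents).

import paddle

# Hypothetical stand-ins for engine.model and an input batch.
model = paddle.nn.Linear(8, 2)
data = paddle.rand([4, 8])
amp_enabled = True
amp_config = {"level": "O1"}  # AMP config section; use_fp16_test left unset

# After this commit, eval enters auto_cast only when use_fp16_test is True.
if amp_enabled and amp_config.get("use_fp16_test", False):
    with paddle.amp.auto_cast(
            custom_black_list={"flatten_contiguous_range"},  # placeholder set
            level=amp_config.get("level", "O1").upper()):
        out = model(data)
else:
    out = model(data)  # the new default: a plain fp32 forward pass
print(out.dtype)  # paddle.float32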
@@ -161,12 +161,13 @@ def main(args):
     # load pretrained models or checkpoints
     init_model(global_config, train_prog, exe)

-    if 'AMP' in config and config.AMP.get("level", "O1") == "O2":
+    if 'AMP' in config:
         optimizer.amp_init(
             device,
             scope=paddle.static.global_scope(),
             test_program=eval_prog
-            if global_config["eval_during_train"] else None)
+            if global_config["eval_during_train"] else None,
+            use_fp16_test=config["AMP"].get("use_fp16_test", False))

     if not global_config.get("is_distributed", True):
         compiled_train_prog = program.compile(
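On the static-graph side, the hunk above drops the O2-only restriction and forwards the flag to optimizer.amp_init, which casts fp32 parameters to fp16 and optionally rewrites the test program. A minimal sketch of that call pattern, with a toy network standing in for PaddleClas' real programs; it assumes a CUDA build of Paddle, since pure-fp16 initialization needs a GPU.

import paddle

paddle.enable_static()

# Toy program standing in for PaddleClas' train_prog/eval_prog.
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data("x", shape=[None, 8], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    eval_prog = main_prog.clone(for_test=True)
    optimizer = paddle.static.amp.decorate(
        paddle.optimizer.SGD(learning_rate=0.01), use_pure_fp16=True)
    optimizer.minimize(loss, startup_prog)

place = paddle.CUDAPlace(0)  # pure-fp16 initialization needs a GPU
exe = paddle.static.Executor(place)
exe.run(startup_prog)

# amp_init casts the fp32 parameters to fp16; with use_fp16_test=False
# (the default this commit wires through) eval_prog stays in fp32.
optimizer.amp_init(
    place,
    scope=paddle.static.global_scope(),
    test_program=eval_prog,
    use_fp16_test=False)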
@@ -182,7 +183,7 @@ def main(args):
             program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                         train_fetchs, epoch_id, 'train', config, vdl_writer,
                         lr_scheduler, args.profiler_options)
-            # 2. evaate with eval dataset
+            # 2. evaluate with eval dataset
             if global_config["eval_during_train"] and epoch_id % global_config[
                     "eval_interval"] == 0:
                 top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
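The context lines in this last hunk show how evaluation is scheduled during training: it runs only when eval_during_train is set and the epoch index hits the eval_interval modulus. A minimal sketch of that gating, with hypothetical global_config values:

# Hypothetical values standing in for the real global_config entries.
global_config = {"epochs": 6, "eval_during_train": True, "eval_interval": 2}

for epoch_id in range(1, global_config["epochs"] + 1):
    # ... train one epoch via program.run(..., 'train', ...) ...
    if global_config["eval_during_train"] and epoch_id % global_config[
            "eval_interval"] == 0:
        print(f"evaluating after epoch {epoch_id}")  # epochs 2, 4 and 6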