fix eval bug

pull/1632/head
zhangbo9674 2022-01-07 06:41:46 +00:00 committed by Tingquan Gao
parent b2956c1b41
commit bb19c1f7a6
1 changed files with 23 additions and 8 deletions

View File

@ -56,7 +56,22 @@ def classification_eval(engine, epoch_id=0):
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input
if engine.amp:
amp_level = 'O1'
if engine.config['AMP']['use_pure_fp16'] is True:
amp_level = 'O2'
with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
else:
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None: