fix eval bug

pull/1632/head
zhangbo9674 2022-01-07 06:41:46 +00:00 committed by Tingquan Gao
parent b2956c1b41
commit bb19c1f7a6
1 changed files with 23 additions and 8 deletions

View File

@ -56,15 +56,30 @@ def classification_eval(engine, epoch_id=0):
batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[0] = paddle.to_tensor(batch[0]).astype("float32")
if not engine.config["Global"].get("use_multilabel", False): if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
out = engine.model(batch[0]) if engine.amp:
# calc loss amp_level = 'O1'
if engine.eval_loss_func is not None: if engine.config['AMP']['use_pure_fp16'] is True:
loss_dict = engine.eval_loss_func(out, batch[1]) amp_level = 'O2'
for key in loss_dict: with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
if key not in output_info: out = engine.model(batch[0])
output_info[key] = AverageMeter(key, '7.5f') # calc loss
output_info[key].update(loss_dict[key].numpy()[0], batch_size) if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
else:
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# just for DistributedBatchSampler issue: repeat sampling # just for DistributedBatchSampler issue: repeat sampling
current_samples = batch_size * paddle.distributed.get_world_size() current_samples = batch_size * paddle.distributed.get_world_size()