fix eval bug
parent b2956c1b41
commit bb19c1f7a6
@@ -56,15 +56,30 @@ def classification_eval(engine, epoch_id=0):
         batch[0] = paddle.to_tensor(batch[0]).astype("float32")
         if not engine.config["Global"].get("use_multilabel", False):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         # image input
-        out = engine.model(batch[0])
-        # calc loss
-        if engine.eval_loss_func is not None:
-            loss_dict = engine.eval_loss_func(out, batch[1])
-            for key in loss_dict:
-                if key not in output_info:
-                    output_info[key] = AverageMeter(key, '7.5f')
-                output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+        if engine.amp:
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
+                out = engine.model(batch[0])
+                # calc loss
+                if engine.eval_loss_func is not None:
+                    loss_dict = engine.eval_loss_func(out, batch[1])
+                    for key in loss_dict:
+                        if key not in output_info:
+                            output_info[key] = AverageMeter(key, '7.5f')
+                        output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+        else:
+            out = engine.model(batch[0])
+            # calc loss
+            if engine.eval_loss_func is not None:
+                loss_dict = engine.eval_loss_func(out, batch[1])
+                for key in loss_dict:
+                    if key not in output_info:
+                        output_info[key] = AverageMeter(key, '7.5f')
+                    output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
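For context, a minimal, self-contained sketch of the eval-time AMP pattern this commit introduces, assuming Paddle 2.x. The tiny paddle.nn.Linear model, the loss layer and the hand-rolled config dict below are placeholders standing in for engine.model, engine.eval_loss_func and engine.config, not PaddleClas objects.

# Sketch of the eval-time AMP pattern from this commit (Paddle 2.x assumed).
# The model, loss layer and config dict are placeholders, not PaddleClas code.
import numpy as np
import paddle

model = paddle.nn.Linear(8, 4)              # stands in for engine.model
loss_func = paddle.nn.CrossEntropyLoss()    # stands in for engine.eval_loss_func
config = {"AMP": {"use_pure_fp16": False}}  # stands in for engine.config

x = paddle.to_tensor(np.random.rand(2, 8).astype("float32"))
label = paddle.to_tensor(np.array([[1], [3]]).astype("int64"))

# Pick the AMP level the same way the fix does: O1 by default, O2 for pure fp16.
amp_level = 'O2' if config['AMP']['use_pure_fp16'] is True else 'O1'

# Run the forward pass and the loss under auto_cast, black-listing the same ops
# so they keep running in float32.
with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level=amp_level):
    out = model(x)
    loss = loss_func(out, label)

print(float(loss))

The point of the fix is visible in the sketch: when engine.amp is enabled, both the forward pass and the loss computation must run inside the same auto_cast context, otherwise evaluation mixes float16 outputs with float32 loss ops.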