fix clas distributed eval bug

pull/1320/head
dongshuilong 2021-10-21 03:47:03 +00:00
parent c93d638f4c
commit fd6f1ad2ca
1 changed files with 1 additions and 7 deletions
ppcls/engine/evaluation

View File

@ -88,12 +88,6 @@ def classification_eval(engine, epoch_id=0):
accum_samples]
current_samples = total_samples + current_samples - accum_samples
metric_dict = engine.eval_metric_func(pred, labels)
for key in metric_dict:
paddle.distributed.all_reduce(
metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
metric_dict[key] = metric_dict[
key] / paddle.distributed.get_world_size()
else:
metric_dict = engine.eval_metric_func(out, batch[1])
for key in metric_dict:
@ -103,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
batch_size)
current_samples)
time_info["batch_cost"].update(time.time() - tic)