fix clas distributed eval bug
parent
c93d638f4c
commit
fd6f1ad2ca
ppcls/engine/evaluation
|
@ -88,12 +88,6 @@ def classification_eval(engine, epoch_id=0):
|
|||
accum_samples]
|
||||
current_samples = total_samples + current_samples - accum_samples
|
||||
metric_dict = engine.eval_metric_func(pred, labels)
|
||||
|
||||
for key in metric_dict:
|
||||
paddle.distributed.all_reduce(
|
||||
metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
|
||||
metric_dict[key] = metric_dict[
|
||||
key] / paddle.distributed.get_world_size()
|
||||
else:
|
||||
metric_dict = engine.eval_metric_func(out, batch[1])
|
||||
for key in metric_dict:
|
||||
|
@ -103,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
|
|||
output_info[key] = AverageMeter(key, '7.5f')
|
||||
|
||||
output_info[key].update(metric_dict[key].numpy()[0],
|
||||
batch_size)
|
||||
current_samples)
|
||||
|
||||
time_info["batch_cost"].update(time.time() - tic)
|
||||
|
||||
|
|
Loading…
Reference in New Issue