Merge pull request from RainFrost1/googlenet_bug

fig goooglenet distributed eval bug
pull/1354/head
Walter 2021-10-29 10:28:37 +08:00 committed by GitHub
commit a5d0e37b02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 5 deletions
ppcls/engine/evaluation

View File

@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
# calc metric
if engine.eval_metric_func is not None:
if paddle.distributed.get_world_size() > 1:
pred_list = []
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
out = out["logits"]
paddle.distributed.all_gather(pred_list, out)
paddle.distributed.all_gather(label_list, batch[1])
pred = paddle.concat(pred_list, 0)
labels = paddle.concat(label_list, 0)
if isinstance(out, list):
pred = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
pred = paddle.concat(pred_list, 0)
if accum_samples > total_samples:
pred = pred[:total_samples + current_samples -
accum_samples]