Merge pull request #1341 from RainFrost1/googlenet_bug

fig goooglenet distributed eval bug
pull/1354/head
Walter 2021-10-29 10:28:37 +08:00 committed by GitHub
commit a5d0e37b02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 5 deletions

View File

@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
# calc metric # calc metric
if engine.eval_metric_func is not None: if engine.eval_metric_func is not None:
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
pred_list = []
label_list = [] label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict): if isinstance(out, dict):
out = out["logits"] out = out["logits"]
paddle.distributed.all_gather(pred_list, out) if isinstance(out, list):
paddle.distributed.all_gather(label_list, batch[1]) pred = []
pred = paddle.concat(pred_list, 0) for x in out:
labels = paddle.concat(label_list, 0) pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
pred = paddle.concat(pred_list, 0)
if accum_samples > total_samples: if accum_samples > total_samples:
pred = pred[:total_samples + current_samples - pred = pred[:total_samples + current_samples -
accum_samples] accum_samples]