Fix GoogLeNet distributed eval bug

pull/1341/head
dongshuilong 2021-10-26 11:56:30 +00:00
parent 69d9a477e0
commit 278f6d8050
1 changed file with 15 additions and 5 deletions

View File

@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
# calc metric
if engine.eval_metric_func is not None:
if paddle.distributed.get_world_size() > 1:
pred_list = []
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
out = out["logits"]
paddle.distributed.all_gather(pred_list, out)
paddle.distributed.all_gather(label_list, batch[1])
pred = paddle.concat(pred_list, 0)
labels = paddle.concat(label_list, 0)
if isinstance(out, list):
pred = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
pred = paddle.concat(pred_list, 0)
if accum_samples > total_samples:
pred = pred[:total_samples + current_samples -
accum_samples]