fix googlenet distributed eval bug

This commit is contained in:
dongshuilong 2021-10-26 11:56:30 +00:00
parent c30b72c867
commit 2062c20cd8

View File

@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
# calc metric # calc metric
if engine.eval_metric_func is not None: if engine.eval_metric_func is not None:
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
pred_list = []
label_list = [] label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
if isinstance(out, dict): if isinstance(out, dict):
out = out["logits"] out = out["logits"]
paddle.distributed.all_gather(pred_list, out) if isinstance(out, list):
paddle.distributed.all_gather(label_list, batch[1]) pred = []
pred = paddle.concat(pred_list, 0) for x in out:
labels = paddle.concat(label_list, 0) pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
pred = paddle.concat(pred_list, 0)
if accum_samples > total_samples: if accum_samples > total_samples:
pred = pred[:total_samples + current_samples - pred = pred[:total_samples + current_samples -
accum_samples] accum_samples]