Merge pull request #1341 from RainFrost1/googlenet_bug
fig goooglenet distributed eval bugpull/1354/head
commit
a5d0e37b02
|
@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
|
||||||
# calc metric
|
# calc metric
|
||||||
if engine.eval_metric_func is not None:
|
if engine.eval_metric_func is not None:
|
||||||
if paddle.distributed.get_world_size() > 1:
|
if paddle.distributed.get_world_size() > 1:
|
||||||
pred_list = []
|
|
||||||
label_list = []
|
label_list = []
|
||||||
|
paddle.distributed.all_gather(label_list, batch[1])
|
||||||
|
labels = paddle.concat(label_list, 0)
|
||||||
|
|
||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
out = out["logits"]
|
out = out["logits"]
|
||||||
paddle.distributed.all_gather(pred_list, out)
|
if isinstance(out, list):
|
||||||
paddle.distributed.all_gather(label_list, batch[1])
|
pred = []
|
||||||
pred = paddle.concat(pred_list, 0)
|
for x in out:
|
||||||
labels = paddle.concat(label_list, 0)
|
pred_list = []
|
||||||
|
paddle.distributed.all_gather(pred_list, x)
|
||||||
|
pred_x = paddle.concat(pred_list, 0)
|
||||||
|
pred.append(pred_x)
|
||||||
|
else:
|
||||||
|
pred_list = []
|
||||||
|
paddle.distributed.all_gather(pred_list, out)
|
||||||
|
pred = paddle.concat(pred_list, 0)
|
||||||
|
|
||||||
if accum_samples > total_samples:
|
if accum_samples > total_samples:
|
||||||
pred = pred[:total_samples + current_samples -
|
pred = pred[:total_samples + current_samples -
|
||||||
accum_samples]
|
accum_samples]
|
||||||
|
|
Loading…
Reference in New Issue