fig goooglenet distributed eval bug
parent
69d9a477e0
commit
278f6d8050
|
@ -73,14 +73,24 @@ def classification_eval(engine, epoch_id=0):
|
|||
# calc metric
|
||||
if engine.eval_metric_func is not None:
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
pred_list = []
|
||||
label_list = []
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
labels = paddle.concat(label_list, 0)
|
||||
|
||||
if isinstance(out, dict):
|
||||
out = out["logits"]
|
||||
paddle.distributed.all_gather(pred_list, out)
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
pred = paddle.concat(pred_list, 0)
|
||||
labels = paddle.concat(label_list, 0)
|
||||
if isinstance(out, list):
|
||||
pred = []
|
||||
for x in out:
|
||||
pred_list = []
|
||||
paddle.distributed.all_gather(pred_list, x)
|
||||
pred_x = paddle.concat(pred_list, 0)
|
||||
pred.append(pred_x)
|
||||
else:
|
||||
pred_list = []
|
||||
paddle.distributed.all_gather(pred_list, out)
|
||||
pred = paddle.concat(pred_list, 0)
|
||||
|
||||
if accum_samples > total_samples:
|
||||
pred = pred[:total_samples + current_samples -
|
||||
accum_samples]
|
||||
|
|
Loading…
Reference in New Issue