Merge pull request #2326 from PaddlePaddle/fix_bug
fix bugs to adapt to the new frameworkpull/2331/head
commit
8f7e260218
|
@ -81,7 +81,8 @@ def classification_eval(engine, epoch_id=0):
|
||||||
# gather Tensor when distributed
|
# gather Tensor when distributed
|
||||||
if paddle.distributed.get_world_size() > 1:
|
if paddle.distributed.get_world_size() > 1:
|
||||||
label_list = []
|
label_list = []
|
||||||
label = batch[1].cuda() if engine.config["Global"][
|
device_id = paddle.distributed.ParallelEnv().device_id
|
||||||
|
label = batch[1].cuda(device_id) if engine.config["Global"][
|
||||||
"device"] == "gpu" else batch[1]
|
"device"] == "gpu" else batch[1]
|
||||||
paddle.distributed.all_gather(label_list, label)
|
paddle.distributed.all_gather(label_list, label)
|
||||||
labels = paddle.concat(label_list, 0)
|
labels = paddle.concat(label_list, 0)
|
||||||
|
|
Loading…
Reference in New Issue