Merge pull request #2326 from PaddlePaddle/fix_bug

fix bugs to adapt to the new framework
pull/2331/head
cuicheng01 2022-09-20 19:05:42 +08:00 committed by GitHub
commit 8f7e260218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -81,7 +81,8 @@ def classification_eval(engine, epoch_id=0):
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
label = batch[1].cuda() if engine.config["Global"][
device_id = paddle.distributed.ParallelEnv().device_id
label = batch[1].cuda(device_id) if engine.config["Global"][
"device"] == "gpu" else batch[1]
paddle.distributed.all_gather(label_list, label)
labels = paddle.concat(label_list, 0)