fix bugs to adapt to the new framework

pull/2240/merge
cuicheng01 2022-09-19 02:01:31 +00:00
parent 98d8405268
commit d19b2712f1
1 changed files with 3 additions and 2 deletions

View File

@ -81,8 +81,9 @@ def classification_eval(engine, epoch_id=0):
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
label = batch[1].cuda() if engine.config["Global"][
"device"] == "gpu" else batch[1]
paddle.distributed.all_gather(label_list, label)
labels = paddle.concat(label_list, 0)
if isinstance(out, list):