fix bugs to adapt to the new framework
parent
98d8405268
commit
d19b2712f1
|
@ -81,8 +81,9 @@ def classification_eval(engine, epoch_id=0):
|
|||
# gather Tensor when distributed
|
||||
if paddle.distributed.get_world_size() > 1:
|
||||
label_list = []
|
||||
|
||||
paddle.distributed.all_gather(label_list, batch[1])
|
||||
label = batch[1].cuda() if engine.config["Global"][
|
||||
"device"] == "gpu" else batch[1]
|
||||
paddle.distributed.all_gather(label_list, label)
|
||||
labels = paddle.concat(label_list, 0)
|
||||
|
||||
if isinstance(out, list):
|
||||
|
|
Loading…
Reference in New Issue