* add multi xpu support for PaddleClas about dygraph * add dygraph multi xpu supportpull/711/head
parent
82142837f9
commit
2a41727dfd
|
@ -56,7 +56,16 @@ def main(args):
|
||||||
config = get_config(args.config, overrides=args.override, show=True)
|
config = get_config(args.config, overrides=args.override, show=True)
|
||||||
# assign the place
|
# assign the place
|
||||||
use_gpu = config.get("use_gpu", True)
|
use_gpu = config.get("use_gpu", True)
|
||||||
place = paddle.set_device('gpu' if use_gpu else 'cpu')
|
use_xpu = config.get("use_xpu", False)
|
||||||
|
assert (
|
||||||
|
use_gpu and use_xpu
|
||||||
|
) is not True, "gpu and xpu can not be true in the same time in static mode!"
|
||||||
|
if use_gpu:
|
||||||
|
place = paddle.set_device('gpu')
|
||||||
|
elif use_xpu:
|
||||||
|
place = paddle.set_device('xpu')
|
||||||
|
else:
|
||||||
|
place = paddle.set_device('cpu')
|
||||||
|
|
||||||
trainer_num = paddle.distributed.get_world_size()
|
trainer_num = paddle.distributed.get_world_size()
|
||||||
use_data_parallel = trainer_num != 1
|
use_data_parallel = trainer_num != 1
|
||||||
|
|
Loading…
Reference in New Issue