Add npu supporting (#1324)
parent
cc00a51af7
commit
a0eb34a642
|
@ -91,7 +91,7 @@ class Engine(object):
|
|||
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
|
||||
|
||||
# set device
|
||||
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
|
||||
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"]
|
||||
self.device = paddle.set_device(self.config["Global"]["device"])
|
||||
logger.info('train with paddle {} and device {}'.format(
|
||||
paddle.__version__, self.device))
|
||||
|
|
|
@ -91,14 +91,17 @@ def main(args):
|
|||
os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]
|
||||
|
||||
use_xpu = global_config.get("use_xpu", False)
|
||||
use_npu = global_config.get("use_npu", False)
|
||||
assert (
|
||||
use_gpu and use_xpu
|
||||
) is not True, "gpu and xpu can not be true in the same time in static mode!"
|
||||
use_gpu and use_xpu and use_npu
|
||||
) is not True, "gpu, xpu and npu can not be true in the same time in static mode!"
|
||||
|
||||
if use_gpu:
|
||||
device = paddle.set_device('gpu')
|
||||
elif use_xpu:
|
||||
device = paddle.set_device('xpu')
|
||||
elif use_npu:
|
||||
device = paddle.set_device('npu')
|
||||
else:
|
||||
device = paddle.set_device('cpu')
|
||||
|
||||
|
|
Loading…
Reference in New Issue