[npu] Update npu api (#9686)

* update npu api

* add version check
pull/9746/head
duanyanhui 2023-04-10 16:30:54 +08:00 committed by GitHub
parent 7bae3db2ec
commit 4b8e333f10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 3 deletions

View File

@ -134,9 +134,18 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
if use_xpu and not paddle.device.is_compiled_with_xpu():
print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
sys.exit(1)
if use_npu and not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
if use_npu:
if int(paddle.version.major) != 0 and int(
paddle.version.major) <= 2 and int(
paddle.version.minor) <= 4:
if not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
# is_compiled_with_npu() has been updated after paddle-2.4
else:
if not paddle.device.is_compiled_with_custom_device("npu"):
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
if use_mlu and not paddle.device.is_compiled_with_mlu():
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
sys.exit(1)