diff --git a/tools/program.py b/tools/program.py index afb8a4725..a8373435c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -134,9 +134,18 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False): if use_xpu and not paddle.device.is_compiled_with_xpu(): print(err.format("use_xpu", "xpu", "xpu", "use_xpu")) sys.exit(1) - if use_npu and not paddle.device.is_compiled_with_npu(): - print(err.format("use_npu", "npu", "npu", "use_npu")) - sys.exit(1) + if use_npu: + if int(paddle.version.major) != 0 and int( + paddle.version.major) <= 2 and int( + paddle.version.minor) <= 4: + if not paddle.device.is_compiled_with_npu(): + print(err.format("use_npu", "npu", "npu", "use_npu")) + sys.exit(1) + # is_compiled_with_npu() has been updated after paddle-2.4 + else: + if not paddle.device.is_compiled_with_custom_device("npu"): + print(err.format("use_npu", "npu", "npu", "use_npu")) + sys.exit(1) if use_mlu and not paddle.device.is_compiled_with_mlu(): print(err.format("use_mlu", "mlu", "mlu", "use_mlu")) sys.exit(1)