[GCU] Support GCU backend (#3301)

pull/3307/head
EnflameGCU 2024-11-20 11:03:15 +08:00 committed by GitHub
parent 77ddc6fd7e
commit 1850bca5d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 2 deletions

View File

@ -74,6 +74,22 @@ class Predictor(object):
config.enable_xpu()
elif args.get("use_mlu", False):
config.enable_custom_device('mlu')
elif args.get("use_gcu", False):
assert paddle.device.is_compiled_with_custom_device("gcu"), (
"Config use_gcu cannot be set as True while your paddle "
"is not compiled with gcu! \nPlease try: \n"
"\t1. Install paddle-custom-gcu to run model on GCU. \n"
"\t2. Set use_gcu as False in config file to run model on CPU."
)
import paddle_custom_device.gcu.passes as gcu_passes
gcu_passes.setUp()
config.enable_custom_device("gcu")
config.enable_new_ir(True)
config.enable_new_executor(True)
kPirGcuPasses = gcu_passes.inference_passes(
use_pir=True, name="PaddleClas"
)
config.enable_custom_passes(kPirGcuPasses, True)
else:
config.disable_gpu()
if args.enable_mkldnn:

View File

@ -105,7 +105,7 @@ class Engine(object):
# set device
assert self.config["Global"]["device"] in [
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu"
]
self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format(

View File

@ -92,7 +92,7 @@ def main(args):
# assign the device
assert global_config["device"] in [
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu"
]
device = paddle.set_device(global_config["device"])