[GCU] Support GCU backend (#3302)
parent
9a74666e93
commit
e2b8b289f5
|
@ -74,6 +74,22 @@ class Predictor(object):
|
||||||
config.enable_xpu()
|
config.enable_xpu()
|
||||||
elif args.get("use_mlu", False):
|
elif args.get("use_mlu", False):
|
||||||
config.enable_custom_device('mlu')
|
config.enable_custom_device('mlu')
|
||||||
|
elif args.get("use_gcu", False):
|
||||||
|
assert paddle.device.is_compiled_with_custom_device("gcu"), (
|
||||||
|
"Config use_gcu cannot be set as True while your paddle "
|
||||||
|
"is not compiled with gcu! \nPlease try: \n"
|
||||||
|
"\t1. Install paddle-custom-gcu to run model on GCU. \n"
|
||||||
|
"\t2. Set use_gcu as False in config file to run model on CPU."
|
||||||
|
)
|
||||||
|
import paddle_custom_device.gcu.passes as gcu_passes
|
||||||
|
gcu_passes.setUp()
|
||||||
|
config.enable_custom_device("gcu")
|
||||||
|
config.enable_new_ir(True)
|
||||||
|
config.enable_new_executor(True)
|
||||||
|
kPirGcuPasses = gcu_passes.inference_passes(
|
||||||
|
use_pir=True, name="PaddleClas"
|
||||||
|
)
|
||||||
|
config.enable_custom_passes(kPirGcuPasses, True)
|
||||||
else:
|
else:
|
||||||
config.disable_gpu()
|
config.disable_gpu()
|
||||||
if args.enable_mkldnn:
|
if args.enable_mkldnn:
|
||||||
|
|
|
@ -105,7 +105,7 @@ class Engine(object):
|
||||||
|
|
||||||
# set device
|
# set device
|
||||||
assert self.config["Global"]["device"] in [
|
assert self.config["Global"]["device"] in [
|
||||||
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"
|
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu"
|
||||||
]
|
]
|
||||||
self.device = paddle.set_device(self.config["Global"]["device"])
|
self.device = paddle.set_device(self.config["Global"]["device"])
|
||||||
logger.info('train with paddle {} and device {}'.format(
|
logger.info('train with paddle {} and device {}'.format(
|
||||||
|
|
|
@ -92,7 +92,7 @@ def main(args):
|
||||||
|
|
||||||
# assign the device
|
# assign the device
|
||||||
assert global_config["device"] in [
|
assert global_config["device"] in [
|
||||||
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"
|
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu"
|
||||||
]
|
]
|
||||||
device = paddle.set_device(global_config["device"])
|
device = paddle.set_device(global_config["device"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue