[GCU] Support inference for GCU (#14142)

pull/14302/head
EnflameGCU 2024-11-21 10:10:41 +08:00 committed by GitHub
parent fbba2178d7
commit c8874d717f
2 changed files with 44 additions and 3 deletions


@@ -41,6 +41,12 @@ def init_args():
     parser.add_argument("--use_xpu", type=str2bool, default=False)
     parser.add_argument("--use_npu", type=str2bool, default=False)
     parser.add_argument("--use_mlu", type=str2bool, default=False)
+    parser.add_argument(
+        "--use_gcu",
+        type=str2bool,
+        default=False,
+        help="Use Enflame GCU(General Compute Unit)",
+    )
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
@@ -298,6 +304,34 @@ def create_predictor(args, mode, logger):
         config.enable_custom_device("mlu")
     elif args.use_xpu:
         config.enable_xpu(10 * 1024 * 1024)
+    elif args.use_gcu:  # for Enflame GCU(General Compute Unit)
+        assert paddle.device.is_compiled_with_custom_device("gcu"), (
+            "Args use_gcu cannot be set as True while your paddle "
+            "is not compiled with gcu! \nPlease try: \n"
+            "\t1. Install paddle-custom-gcu to run model on GCU. \n"
+            "\t2. Set use_gcu as False in args to run model on CPU."
+        )
+        import paddle_custom_device.gcu.passes as gcu_passes
+
+        gcu_passes.setUp()
+        if args.precision == "fp16":
+            config.enable_custom_device(
+                "gcu", 0, paddle.inference.PrecisionType.Half
+            )
+            gcu_passes.set_exp_enable_mixed_precision_ops(config)
+        else:
+            config.enable_custom_device("gcu")
+
+        if paddle.framework.use_pir_api():
+            config.enable_new_ir(True)
+            config.enable_new_executor(True)
+            kPirGcuPasses = gcu_passes.inference_passes(
+                use_pir=True, name="PaddleOCR"
+            )
+            config.enable_custom_passes(kPirGcuPasses, True)
+        else:
+            pass_builder = config.pass_builder()
+            gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
     else:
         config.disable_gpu()
         if args.enable_mkldnn:
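The same device selection can be exercised outside of `create_predictor` with a few lines. A minimal sketch, assuming `paddle-custom-gcu` is installed and using placeholder model paths:

```python
import paddle
from paddle import inference

# Placeholder model files; substitute a real exported inference model.
config = inference.Config("inference.pdmodel", "inference.pdiparams")

if paddle.device.is_compiled_with_custom_device("gcu"):
    import paddle_custom_device.gcu.passes as gcu_passes

    gcu_passes.setUp()
    # fp32 path; for fp16, pass paddle.inference.PrecisionType.Half as above.
    config.enable_custom_device("gcu")
else:
    config.disable_gpu()  # fall back to CPU

predictor = inference.create_predictor(config)
```

Note that the branch above also registers GCU graph passes two ways: through `enable_custom_passes` when the new PIR executor is active, and through the legacy `pass_builder` otherwise.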
@@ -314,7 +348,8 @@ def create_predictor(args, mode, logger):
     # enable memory optim
     config.enable_memory_optim()
     config.disable_glog_info()
-    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+    if not args.use_gcu:  # for Enflame GCU(General Compute Unit)
+        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
     config.delete_pass("matmul_transpose_reshape_fuse_pass")
     if mode == "rec" and args.rec_algorithm == "SRN":
         config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")


@@ -115,7 +115,7 @@ def merge_config(config, opts):
     return config


-def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
+def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
@@ -154,6 +154,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
         if use_mlu and not paddle.device.is_compiled_with_mlu():
             print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
             sys.exit(1)
+        if use_gcu and not paddle.device.is_compiled_with_custom_device("gcu"):
+            print(err.format("use_gcu", "gcu", "gcu", "use_gcu"))
+            sys.exit(1)

     except Exception as e:
         pass
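A quick sketch of how the extended check behaves, with hypothetical argument values (in practice they come from the YAML config via `preprocess` below):

```python
# Prints the formatted error and exits unless paddle was built with the
# "gcu" custom device (i.e. paddle-custom-gcu is installed).
check_device(use_gpu=False, use_gcu=True)
```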
@@ -799,6 +802,7 @@ def preprocess(is_train=False):
     use_xpu = config["Global"].get("use_xpu", False)
     use_npu = config["Global"].get("use_npu", False)
     use_mlu = config["Global"].get("use_mlu", False)
+    use_gcu = config["Global"].get("use_gcu", False)

     alg = config["Architecture"]["algorithm"]
     assert alg in [
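For context, the flag is read from the Global section of the run configuration and defaults to `False` when absent. A hypothetical snippet showing the lookup (a plain dict stands in for the loaded YAML):

```python
# Stand-in for the config dict that preprocess() loads from YAML.
config = {"Global": {"use_gpu": False, "use_gcu": True}}
use_gcu = config["Global"].get("use_gcu", False)  # True here; False if absent
```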
@@ -853,9 +857,11 @@ def preprocess(is_train=False):
         device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
     elif use_mlu:
         device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
+    elif use_gcu:  # Use Enflame GCU(General Compute Unit)
+        device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
     else:
         device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
-    check_device(use_gpu, use_xpu, use_npu, use_mlu)
+    check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu)

     device = paddle.set_device(device)
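As with the other accelerators, the card index comes from an environment variable. A minimal standalone sketch of the new branch, assuming a working `paddle-custom-gcu` install:

```python
import os
import paddle

# FLAGS_selected_gcus selects the card, defaulting to 0 when unset.
device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
place = paddle.set_device(device)  # raises if the gcu plugin is missing
```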