[GCU] Support inference for GCU (#14142)
parent
fbba2178d7
commit
c8874d717f
|
@ -41,6 +41,12 @@ def init_args():
|
|||
parser.add_argument("--use_xpu", type=str2bool, default=False)
|
||||
parser.add_argument("--use_npu", type=str2bool, default=False)
|
||||
parser.add_argument("--use_mlu", type=str2bool, default=False)
|
||||
parser.add_argument(
|
||||
"--use_gcu",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use Enflame GCU(General Compute Unit)",
|
||||
)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--min_subgraph_size", type=int, default=15)
|
||||
|
@ -298,6 +304,34 @@ def create_predictor(args, mode, logger):
|
|||
config.enable_custom_device("mlu")
|
||||
elif args.use_xpu:
|
||||
config.enable_xpu(10 * 1024 * 1024)
|
||||
elif args.use_gcu: # for Enflame GCU(General Compute Unit)
|
||||
assert paddle.device.is_compiled_with_custom_device("gcu"), (
|
||||
"Args use_gcu cannot be set as True while your paddle "
|
||||
"is not compiled with gcu! \nPlease try: \n"
|
||||
"\t1. Install paddle-custom-gcu to run model on GCU. \n"
|
||||
"\t2. Set use_gcu as False in args to run model on CPU."
|
||||
)
|
||||
import paddle_custom_device.gcu.passes as gcu_passes
|
||||
|
||||
gcu_passes.setUp()
|
||||
if args.precision == "fp16":
|
||||
config.enable_custom_device(
|
||||
"gcu", 0, paddle.inference.PrecisionType.Half
|
||||
)
|
||||
gcu_passes.set_exp_enable_mixed_precision_ops(config)
|
||||
else:
|
||||
config.enable_custom_device("gcu")
|
||||
|
||||
if paddle.framework.use_pir_api():
|
||||
config.enable_new_ir(True)
|
||||
config.enable_new_executor(True)
|
||||
kPirGcuPasses = gcu_passes.inference_passes(
|
||||
use_pir=True, name="PaddleOCR"
|
||||
)
|
||||
config.enable_custom_passes(kPirGcuPasses, True)
|
||||
else:
|
||||
pass_builder = config.pass_builder()
|
||||
gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
|
||||
else:
|
||||
config.disable_gpu()
|
||||
if args.enable_mkldnn:
|
||||
|
@ -314,7 +348,8 @@ def create_predictor(args, mode, logger):
|
|||
# enable memory optim
|
||||
config.enable_memory_optim()
|
||||
config.disable_glog_info()
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
if not args.use_gcu: # for Enflame GCU(General Compute Unit)
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
config.delete_pass("matmul_transpose_reshape_fuse_pass")
|
||||
if mode == "rec" and args.rec_algorithm == "SRN":
|
||||
config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")
|
||||
|
|
|
@ -115,7 +115,7 @@ def merge_config(config, opts):
|
|||
return config
|
||||
|
||||
|
||||
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
|
||||
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False, use_gcu=False):
|
||||
"""
|
||||
Log error and exit when set use_gpu=true in paddlepaddle
|
||||
cpu version.
|
||||
|
@ -154,6 +154,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
|
|||
if use_mlu and not paddle.device.is_compiled_with_mlu():
|
||||
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
|
||||
sys.exit(1)
|
||||
if use_gcu and not paddle.device.is_compiled_with_custom_device("gcu"):
|
||||
print(err.format("use_gcu", "gcu", "gcu", "use_gcu"))
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
@ -799,6 +802,7 @@ def preprocess(is_train=False):
|
|||
use_xpu = config["Global"].get("use_xpu", False)
|
||||
use_npu = config["Global"].get("use_npu", False)
|
||||
use_mlu = config["Global"].get("use_mlu", False)
|
||||
use_gcu = config["Global"].get("use_gcu", False)
|
||||
|
||||
alg = config["Architecture"]["algorithm"]
|
||||
assert alg in [
|
||||
|
@ -853,9 +857,11 @@ def preprocess(is_train=False):
|
|||
device = "npu:{0}".format(os.getenv("FLAGS_selected_npus", 0))
|
||||
elif use_mlu:
|
||||
device = "mlu:{0}".format(os.getenv("FLAGS_selected_mlus", 0))
|
||||
elif use_gcu: # Use Enflame GCU(General Compute Unit)
|
||||
device = "gcu:{0}".format(os.getenv("FLAGS_selected_gcus", 0))
|
||||
else:
|
||||
device = "gpu:{}".format(dist.ParallelEnv().dev_id) if use_gpu else "cpu"
|
||||
check_device(use_gpu, use_xpu, use_npu, use_mlu)
|
||||
check_device(use_gpu, use_xpu, use_npu, use_mlu, use_gcu)
|
||||
|
||||
device = paddle.set_device(device)
|
||||
|
||||
|
|
Loading…
Reference in New Issue