From b031ae47f64de4583f0161f19cba7f92f377c445 Mon Sep 17 00:00:00 2001 From: zhangyubo0722 Date: Fri, 28 Mar 2025 13:11:23 +0000 Subject: [PATCH] fix fused_attn with fp32 --- ppcls/engine/engine.py | 8 ++++++-- .../PULC_label_list/table_classification_label_list.txt | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 ppcls/utils/PULC_label_list/table_classification_label_list.txt diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 6e342c342..23ad0565d 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -66,6 +66,10 @@ class Engine(object): self.is_rec = True else: self.is_rec = False + if self.config["Arch"].get("use_fused_attn", False): + if not self.config.get("AMP", {}).get("use_amp", False): + self.config["Arch"]["use_fused_attn"] = False + self.config["Arch"]["use_fused_linear"] = False # set seed seed = self.config["Global"].get("seed", False) @@ -106,8 +110,8 @@ class Engine(object): # set device assert self.config["Global"]["device"] in [ - "cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu", "mps", - "gcu" + "cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu", + "mps", "gcu" ] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( diff --git a/ppcls/utils/PULC_label_list/table_classification_label_list.txt b/ppcls/utils/PULC_label_list/table_classification_label_list.txt new file mode 100644 index 000000000..d97f80c74 --- /dev/null +++ b/ppcls/utils/PULC_label_list/table_classification_label_list.txt @@ -0,0 +1,2 @@ +0 wired_table +1 wireless_table