From 2e25edc5ea97c08250d1a510b2494ebe11f9d3a5 Mon Sep 17 00:00:00 2001 From: zhangyubo0722 <94225063+zhangyubo0722@users.noreply.github.com> Date: Tue, 22 Apr 2025 15:04:57 +0800 Subject: [PATCH] fix fused_attn with fp32 and add table cls label list (#3377) --- ppcls/engine/engine.py | 7 ++++++- .../PULC_label_list/table_classification_label_list.txt | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 ppcls/utils/PULC_label_list/table_classification_label_list.txt diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 1bc54249b..202b418f4 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -65,6 +65,10 @@ class Engine(object): self.is_rec = True else: self.is_rec = False + if self.config["Arch"].get("use_fused_attn", False): + if not self.config.get("AMP", {}).get("use_amp", False): + self.config["Arch"]["use_fused_attn"] = False + self.config["Arch"]["use_fused_linear"] = False # set seed seed = self.config["Global"].get("seed", False) @@ -105,7 +109,8 @@ class Engine(object): # set device assert self.config["Global"]["device"] in [ - "cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu" + "cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu", + "mps", "gcu" ] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( diff --git a/ppcls/utils/PULC_label_list/table_classification_label_list.txt b/ppcls/utils/PULC_label_list/table_classification_label_list.txt new file mode 100644 index 000000000..d97f80c74 --- /dev/null +++ b/ppcls/utils/PULC_label_list/table_classification_label_list.txt @@ -0,0 +1,2 @@ +0 wired_table +1 wireless_table