fix fused_attn with fp32 and add table cls label list

pull/3377/head
zhangyubo0722 2025-04-21 14:03:32 +00:00
parent 19e9cfeaf4
commit fe0e1234b8
2 changed files with 8 additions and 1 deletions

View File

@ -65,6 +65,10 @@ class Engine(object):
self.is_rec = True self.is_rec = True
else: else:
self.is_rec = False self.is_rec = False
if self.config["Arch"].get("use_fused_attn", False):
if not self.config.get("AMP", {}).get("use_amp", False):
self.config["Arch"]["use_fused_attn"] = False
self.config["Arch"]["use_fused_linear"] = False
# set seed # set seed
seed = self.config["Global"].get("seed", False) seed = self.config["Global"].get("seed", False)
@ -105,7 +109,8 @@ class Engine(object):
# set device # set device
assert self.config["Global"]["device"] in [ assert self.config["Global"]["device"] in [
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu" "cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu",
"mps", "gcu"
] ]
self.device = paddle.set_device(self.config["Global"]["device"]) self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format( logger.info('train with paddle {} and device {}'.format(

View File

@ -0,0 +1,2 @@
0 wired_table
1 wireless_table