fix fused_attn with fp32 and add table cls label list
parent
19e9cfeaf4
commit
fe0e1234b8
|
@ -65,6 +65,10 @@ class Engine(object):
|
||||||
self.is_rec = True
|
self.is_rec = True
|
||||||
else:
|
else:
|
||||||
self.is_rec = False
|
self.is_rec = False
|
||||||
|
if self.config["Arch"].get("use_fused_attn", False):
|
||||||
|
if not self.config.get("AMP", {}).get("use_amp", False):
|
||||||
|
self.config["Arch"]["use_fused_attn"] = False
|
||||||
|
self.config["Arch"]["use_fused_linear"] = False
|
||||||
|
|
||||||
# set seed
|
# set seed
|
||||||
seed = self.config["Global"].get("seed", False)
|
seed = self.config["Global"].get("seed", False)
|
||||||
|
@ -105,7 +109,8 @@ class Engine(object):
|
||||||
|
|
||||||
# set device
|
# set device
|
||||||
assert self.config["Global"]["device"] in [
|
assert self.config["Global"]["device"] in [
|
||||||
"cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps", "gcu"
|
"cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu",
|
||||||
|
"mps", "gcu"
|
||||||
]
|
]
|
||||||
self.device = paddle.set_device(self.config["Global"]["device"])
|
self.device = paddle.set_device(self.config["Global"]["device"])
|
||||||
logger.info('train with paddle {} and device {}'.format(
|
logger.info('train with paddle {} and device {}'.format(
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
0 wired_table
|
||||||
|
1 wireless_table
|
Loading…
Reference in New Issue