repair formula bug when export (#14442)

pull/14467/head
liuhongen1234567 2024-12-24 17:44:31 +08:00 committed by GitHub
parent d523388ed1
commit 0d41ffc91d
2 changed files with 4 additions and 6 deletions


@@ -424,6 +424,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
+        self.is_export = False if self.training else True
         output_attentions = (
             output_attentions
@@ -1370,6 +1371,7 @@ class PPFormulaNet_Head(UniMERNetHead):
     # forward for export
     def forward(self, inputs, targets=None):
+        self.is_export = False if self.training else True
         if not self.training:
             encoder_outputs = inputs
             model_kwargs = {
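
For context, here is a minimal sketch (not part of the commit) of the pattern these hunks introduce: the export flag is recomputed from the layer's training state on every forward call, so a model switched to eval mode for export takes the export branch without extra configuration. The class below and its branch bodies are hypothetical; only the is_export line mirrors the diff.

import paddle
import paddle.nn as nn

class DecoderSketch(nn.Layer):
    # Hypothetical layer; only the is_export handling mirrors the commit.
    def __init__(self, hidden=10):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        # Recompute the flag on each call instead of fixing it in __init__:
        # self.training is True after .train() and False after .eval(),
        # which is the state the model is in when it is traced for export.
        self.is_export = False if self.training else True
        logits = self.proj(x)
        if self.is_export:
            # inference / export branch
            return paddle.nn.functional.softmax(logits, axis=-1)
        # training branch
        return logits

The second changed file, whose hunks follow, applies the same pattern and also drops a torch-style __setstate__ override from MyMultiheadAttention.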


@@ -1518,11 +1518,6 @@ class MyMultiheadAttention(nn.Layer):
         if self.bias_v is not None:
             xavier_normal_(self.bias_v)

-    def __setstate__(self, state):
-        if "_qkv_same_embed_dim" not in state:
-            state["_qkv_same_embed_dim"] = True
-        super(nn.MultiheadAttention, self).__setstate__(state)

     def forward(
         self,
         query: paddle.Tensor,
@@ -1680,7 +1675,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
+        self.is_export = False if self.training else True
         output_attentions = (
             output_attentions
             if output_attentions is not None
@@ -2650,6 +2645,7 @@ class UniMERNetHead(nn.Layer):
         During inference: Returns predicted latex code.
         During training: Returns logits, predicted counts, and masked labels.
         """
+        self.is_export = False if self.training else True
         if not self.training:
             encoder_outputs = inputs
             if self.is_export:
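
As a usage note (a sketch under assumed names, not taken from this commit): because is_export now follows self.training, putting the model into eval mode before tracing is enough to send forward down the export path during the usual Paddle dynamic-to-static export. DecoderSketch is the hypothetical layer from the earlier sketch; the save path and input shape are illustrative only.

import paddle
from paddle.static import InputSpec

model = DecoderSketch()  # hypothetical layer from the earlier sketch
model.eval()             # self.training becomes False, so is_export becomes True in forward
static_model = paddle.jit.to_static(
    model, input_spec=[InputSpec(shape=[None, 10], dtype="float32")]
)
paddle.jit.save(static_model, "./inference/decoder_sketch")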