repair formula bug when exporting (#14442)
parent d523388ed1
commit 0d41ffc91d
@@ -424,6 +424,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
+        self.is_export = False if self.training else True

         output_attentions = (
             output_attentions
@@ -1370,6 +1371,7 @@ class PPFormulaNet_Head(UniMERNetHead):

     # forward for export
     def forward(self, inputs, targets=None):
+        self.is_export = False if self.training else True
         if not self.training:
             encoder_outputs = inputs
             model_kwargs = {
@@ -1518,11 +1518,6 @@ class MyMultiheadAttention(nn.Layer):
         if self.bias_v is not None:
             xavier_normal_(self.bias_v)

-    def __setstate__(self, state):
-        if "_qkv_same_embed_dim" not in state:
-            state["_qkv_same_embed_dim"] = True
-        super(nn.MultiheadAttention, self).__setstate__(state)
-
     def forward(
         self,
         query: paddle.Tensor,
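A note on the removal above, offered as a reading rather than a confirmed rationale: the hunk header shows MyMultiheadAttention subclassing nn.Layer, so super(nn.MultiheadAttention, self).__setstate__(state) names a class that is not a base of this class and could not work as written. The snippet below, with hypothetical classes Unrelated and Base, shows the TypeError Python raises when super() is given a class outside the instance's MRO.

class Unrelated:
    pass


class Base:
    def load_state(self):
        # super(Unrelated, self) is only valid when Unrelated appears in
        # type(self).__mro__; here it does not, so Python raises TypeError.
        try:
            super(Unrelated, self).__init__()
        except TypeError as err:
            print("TypeError:", err)


Base().load_state()
# prints: TypeError: super(type, obj): obj must be an instance or subtype of type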
@@ -1680,7 +1675,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
-
+        self.is_export = False if self.training else True
         output_attentions = (
             output_attentions
             if output_attentions is not None
@@ -2650,6 +2645,7 @@ class UniMERNetHead(nn.Layer):
             During inference: Returns predicted latex code.
             During training: Returns logits, predicted counts, and masked labels.
         """
+        self.is_export = False if self.training else True
         if not self.training:
             encoder_outputs = inputs
             if self.is_export:
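The change repeated across these hunks is the assignment self.is_export = False if self.training else True at the top of each forward, so the export flag follows the layer's current train/eval state on every call. The sketch below is a minimal stand-alone illustration of that pattern; TinyHead and its sizes are hypothetical and not part of PaddleOCR.

import paddle
import paddle.nn.functional as F
from paddle import nn


class TinyHead(nn.Layer):
    # Illustrative layer, not a PaddleOCR class: demonstrates the
    # export-flag pattern added in this commit.
    def __init__(self, hidden=16):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)

    def forward(self, x):
        # Recompute the flag on every call, so switching the layer to
        # eval() (as done before export) is always picked up here.
        self.is_export = False if self.training else True
        feats = self.fc(x)
        if self.is_export:
            # inference/export branch
            return F.softmax(feats, axis=-1)
        # training branch
        return feats


if __name__ == "__main__":
    head = TinyHead()
    head.eval()                         # inference/export mode
    out = head(paddle.randn([2, 16]))
    print(out.shape, head.is_export)    # [2, 16] True

In the heads touched here, the same flag then gates export-time decoding (the `if self.is_export:` branch visible in the last hunk), so recomputing it from self.training keeps eval-mode inference and the exported graph on the same path.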