From 0d41ffc91d218d0abd6ed7e27c23983740ca3cb1 Mon Sep 17 00:00:00 2001
From: liuhongen1234567 <65936492+liuhongen1234567@users.noreply.github.com>
Date: Tue, 24 Dec 2024 17:44:31 +0800
Subject: [PATCH] repair formula bug when export (#14442)

---
 ppocr/modeling/heads/rec_ppformulanet_head.py | 2 ++
 ppocr/modeling/heads/rec_unimernet_head.py    | 8 ++------
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/ppocr/modeling/heads/rec_ppformulanet_head.py b/ppocr/modeling/heads/rec_ppformulanet_head.py
index 286cc45305..ba46c29dc6 100644
--- a/ppocr/modeling/heads/rec_ppformulanet_head.py
+++ b/ppocr/modeling/heads/rec_ppformulanet_head.py
@@ -424,6 +424,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
+        self.is_export = False if self.training else True
 
         output_attentions = (
             output_attentions
@@ -1370,6 +1371,7 @@ class PPFormulaNet_Head(UniMERNetHead):
 
     # forward for export
     def forward(self, inputs, targets=None):
+        self.is_export = False if self.training else True
         if not self.training:
             encoder_outputs = inputs
             model_kwargs = {
diff --git a/ppocr/modeling/heads/rec_unimernet_head.py b/ppocr/modeling/heads/rec_unimernet_head.py
index 95916c9dbc..4123fc355d 100644
--- a/ppocr/modeling/heads/rec_unimernet_head.py
+++ b/ppocr/modeling/heads/rec_unimernet_head.py
@@ -1518,11 +1518,6 @@ class MyMultiheadAttention(nn.Layer):
         if self.bias_v is not None:
             xavier_normal_(self.bias_v)
 
-    def __setstate__(self, state):
-        if "_qkv_same_embed_dim" not in state:
-            state["_qkv_same_embed_dim"] = True
-        super(nn.MultiheadAttention, self).__setstate__(state)
-
     def forward(
         self,
         query: paddle.Tensor,
@@ -1680,7 +1675,7 @@ class CustomMBartDecoder(MBartDecoder):
         output_hidden_states=None,
         return_dict=None,
     ):
-
+        self.is_export = False if self.training else True
         output_attentions = (
             output_attentions
             if output_attentions is not None
@@ -2650,6 +2645,7 @@ class UniMERNetHead(nn.Layer):
            During inference: Returns predicted latex code.
            During training: Returns logits, predicted counts, and masked labels.
        """
+        self.is_export = False if self.training else True
        if not self.training:
            encoder_outputs = inputs
            if self.is_export:
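
Note on the pattern: every hunk that adds code applies the same fix, recomputing is_export from self.training at the top of forward() instead of relying on a value set once at construction, so that export tracing (which runs in eval mode) reliably takes the export branch. Below is a minimal, self-contained sketch of that pattern under stated assumptions: DemoHead and its fields are hypothetical stand-ins, not PaddleOCR's actual classes.

import paddle
from paddle import nn


class DemoHead(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        # Set once at construction; goes stale if the layer is later
        # switched between .train() and .eval(), e.g. during export.
        self.is_export = False

    def forward(self, x):
        # The fix: refresh the flag on every call so it always tracks
        # the current train/eval mode (equivalent to: not self.training).
        self.is_export = False if self.training else True
        if self.is_export:
            # export / inference path
            return self.fc(x)
        # training path
        return self.fc(x)


if __name__ == "__main__":
    head = DemoHead()
    head.eval()  # export tracing runs the layer in eval mode
    _ = head(paddle.randn([1, 4]))
    print(head.is_export)  # True: the flag now follows the mode

The expression `False if self.training else True` is what the patch uses verbatim; a plain `not self.training` would behave identically.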