From 3e8c78b8c1210cd07f0ea1ab83dc521937814fd4 Mon Sep 17 00:00:00 2001
From: WenmuZhou <572459439@qq.com>
Date: Fri, 30 Sep 2022 07:01:43 +0000
Subject: [PATCH] fix tablerec-rare train error

---
 ppocr/modeling/heads/table_att_head.py | 35 +++++++++++---------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py
index 50910c5b7..e3fc8436e 100644
--- a/ppocr/modeling/heads/table_att_head.py
+++ b/ppocr/modeling/heads/table_att_head.py
@@ -82,7 +82,8 @@ class TableAttentionHead(nn.Layer):
 
         batch_size = fea.shape[0]
         hidden = paddle.zeros((batch_size, self.hidden_size))
-        output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
+        output_hiddens = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.hidden_size))
         if self.training and targets is not None:
             structure = targets[0]
             for i in range(self.max_text_length + 1):
@@ -91,19 +92,13 @@ class TableAttentionHead(nn.Layer):
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens[:, i, :] = outputs
-                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
-            output = paddle.concat(output_hiddens, axis=1)
-            structure_probs = self.structure_generator(output)
-            if self.loc_type == 1:
-                loc_preds = self.loc_generator(output)
-                loc_preds = F.sigmoid(loc_preds)
-            else:
-                loc_fea = fea.transpose([0, 2, 1])
-                loc_fea = self.loc_fea_trans(loc_fea)
-                loc_fea = loc_fea.transpose([0, 2, 1])
-                loc_concat = paddle.concat([output, loc_fea], axis=2)
-                loc_preds = self.loc_generator(loc_concat)
-                loc_preds = F.sigmoid(loc_preds)
+            structure_probs = self.structure_generator(output_hiddens)
+            loc_fea = fea.transpose([0, 2, 1])
+            loc_fea = self.loc_fea_trans(loc_fea)
+            loc_fea = loc_fea.transpose([0, 2, 1])
+            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
+            loc_preds = self.loc_generator(loc_concat)
+            loc_preds = F.sigmoid(loc_preds)
         else:
             temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
             structure_probs = None
@@ -118,17 +113,15 @@ class TableAttentionHead(nn.Layer):
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens[:, i, :] = outputs
-                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
 
-            output = output_hiddens
-            structure_probs = self.structure_generator(output)
+            structure_probs = self.structure_generator(output_hiddens)
             structure_probs = F.softmax(structure_probs)
             loc_fea = fea.transpose([0, 2, 1])
             loc_fea = self.loc_fea_trans(loc_fea)
             loc_fea = loc_fea.transpose([0, 2, 1])
-            loc_concat = paddle.concat([output, loc_fea], axis=2)
+            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
             loc_preds = self.loc_generator(loc_concat)
             loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
@@ -203,8 +196,10 @@ class SLAHead(nn.Layer):
         fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
 
         hidden = paddle.zeros((batch_size, self.hidden_size))
-        structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings))
-        loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num))
+        structure_preds = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.num_embeddings))
+        loc_preds = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.loc_reg_num))
         structure_preds.stop_gradient = True
         loc_preds.stop_gradient = True
         if self.training and targets is not None:
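
Note: the core change in this patch swaps the old list-append / paddle.concat decoding buffer (and the loc_type branch built on it) for a preallocated tensor that each decode step writes into with slice assignment. The snippet below is only a minimal sketch of that pattern; the batch/step/hidden sizes and the random per-step output are made-up illustration values, not the model's real configuration.

    import paddle

    # Assumed toy sizes, for illustration only.
    batch_size, max_text_length, hidden_size = 2, 4, 8

    # Preallocate the per-step output buffer once ...
    output_hiddens = paddle.zeros(
        (batch_size, max_text_length + 1, hidden_size))

    for i in range(max_text_length + 1):
        # ... and write each decode step's output in place
        # (paddle.rand stands in for the attention cell output).
        outputs = paddle.rand((batch_size, hidden_size))
        output_hiddens[:, i, :] = outputs

    # The buffer already has shape (batch, steps, hidden), so it can be fed
    # to the structure/location generators directly, with no paddle.concat
    # over a list of unsqueezed step tensors.
    print(output_hiddens.shape)  # [2, 5, 8]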