fix tablerec-rare train error

parent 48a785f9cd
commit 3e8c78b8c1
@@ -82,7 +82,8 @@ class TableAttentionHead(nn.Layer):
         batch_size = fea.shape[0]

         hidden = paddle.zeros((batch_size, self.hidden_size))
-        output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size))
+        output_hiddens = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.hidden_size))
         if self.training and targets is not None:
             structure = targets[0]
             for i in range(self.max_text_length + 1):
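The behavioral fix lands in the next hunk: with output_hiddens preallocated as a tensor, the old list-style paddle.concat(output_hiddens, axis=1) no longer applies, since writing each step into the buffer already yields the same (B, T, H) result. A minimal sketch of that equivalence, with invented sizes:

    import paddle

    B, T, H = 2, 4, 8
    step_outputs = [paddle.rand((B, H)) for _ in range(T)]

    # Old flow: collect (B, 1, H) slices in a list, then concat along axis=1.
    as_list = [paddle.unsqueeze(o, axis=1) for o in step_outputs]
    old_result = paddle.concat(as_list, axis=1)  # (B, T, H)

    # New flow: write each step into a preallocated (B, T, H) tensor.
    new_result = paddle.zeros((B, T, H))
    for i, o in enumerate(step_outputs):
        new_result[:, i, :] = o

    # Both flows produce the same tensor, so the extra concat can be dropped.
    assert paddle.allclose(old_result, new_result).item()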
@@ -91,19 +92,13 @@ class TableAttentionHead(nn.Layer):
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens[:, i, :] = outputs
-                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
-            output = paddle.concat(output_hiddens, axis=1)
-            structure_probs = self.structure_generator(output)
-            if self.loc_type == 1:
-                loc_preds = self.loc_generator(output)
-                loc_preds = F.sigmoid(loc_preds)
-            else:
-                loc_fea = fea.transpose([0, 2, 1])
-                loc_fea = self.loc_fea_trans(loc_fea)
-                loc_fea = loc_fea.transpose([0, 2, 1])
-                loc_concat = paddle.concat([output, loc_fea], axis=2)
-                loc_preds = self.loc_generator(loc_concat)
-                loc_preds = F.sigmoid(loc_preds)
+            structure_probs = self.structure_generator(output_hiddens)
+            loc_fea = fea.transpose([0, 2, 1])
+            loc_fea = self.loc_fea_trans(loc_fea)
+            loc_fea = loc_fea.transpose([0, 2, 1])
+            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
+            loc_preds = self.loc_generator(loc_concat)
+            loc_preds = F.sigmoid(loc_preds)
         else:
             temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
             structure_probs = None
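For context, the surviving location branch concatenates the decoder hidden states with a width-projected copy of the visual feature; for the axis=2 concat to line up, loc_fea_trans has to map the width axis to the same step count as output_hiddens. A rough sketch under that assumption (all sizes invented, loc_fea_trans stands in for the real layer):

    import paddle
    import paddle.nn as nn

    B, W, C, T, H = 2, 16, 96, 5, 8
    fea = paddle.rand((B, W, C))
    loc_fea_trans = nn.Linear(W, T)  # assumed: width -> max_text_length + 1

    loc_fea = fea.transpose([0, 2, 1])      # (B, C, W)
    loc_fea = loc_fea_trans(loc_fea)        # (B, C, T)
    loc_fea = loc_fea.transpose([0, 2, 1])  # (B, T, C)

    output_hiddens = paddle.rand((B, T, H))
    loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)  # (B, T, H + C)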
@@ -118,17 +113,15 @@ class TableAttentionHead(nn.Layer):
                 (outputs, hidden), alpha = self.structure_attention_cell(
                     hidden, fea, elem_onehots)
                 output_hiddens[:, i, :] = outputs
-                # output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")

-            output = output_hiddens
-            structure_probs = self.structure_generator(output)
+            structure_probs = self.structure_generator(output_hiddens)
             structure_probs = F.softmax(structure_probs)
             loc_fea = fea.transpose([0, 2, 1])
             loc_fea = self.loc_fea_trans(loc_fea)
             loc_fea = loc_fea.transpose([0, 2, 1])
-            loc_concat = paddle.concat([output, loc_fea], axis=2)
+            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
             loc_preds = self.loc_generator(loc_concat)
             loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
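The inference branch above decodes greedily: each step's logits pick the next structure token, which seeds the following iteration's input. A small stand-alone sketch of that one step (generator is a stand-in for structure_generator, sizes invented):

    import paddle
    import paddle.nn as nn

    B, H, num_embeddings = 2, 8, 30
    generator = nn.Linear(H, num_embeddings)  # stand-in for structure_generator

    outputs = paddle.rand((B, H))             # one attention-cell step
    structure_probs_step = generator(outputs)
    temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")  # (B,) token ids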
@@ -203,8 +196,10 @@ class SLAHead(nn.Layer):
         fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)

         hidden = paddle.zeros((batch_size, self.hidden_size))
-        structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings))
-        loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num))
+        structure_preds = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.num_embeddings))
+        loc_preds = paddle.zeros(
+            (batch_size, self.max_text_length + 1, self.loc_reg_num))
         structure_preds.stop_gradient = True
         loc_preds.stop_gradient = True
         if self.training and targets is not None:
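The SLAHead hunk is pure line wrapping, but it touches the same preallocate-then-write scheme: zero buffers for the per-step predictions, marked stop_gradient before the decode loop fills them. A compact sketch with made-up sizes and random stand-ins for the per-step head outputs:

    import paddle

    B, T, E, L = 2, 5, 30, 4
    structure_preds = paddle.zeros((B, T, E))
    loc_preds = paddle.zeros((B, T, L))
    structure_preds.stop_gradient = True  # mirrors the buffers above
    loc_preds.stop_gradient = True

    for i in range(T):
        structure_preds[:, i, :] = paddle.rand((B, E))  # stand-in step logits
        loc_preds[:, i, :] = paddle.rand((B, L))        # stand-in box regression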