Merge pull request #4719 from tink2123/fix_ce_loss_2.3
[cherry-pick] fix attn loss for cepull/4332/head^2
commit
ccf749e3ee
|
@ -318,7 +318,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len:
|
||||
if len(text) >= self.max_text_len - 1:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||
|
|
|
@ -75,6 +75,7 @@ class AttentionHead(nn.Layer):
|
|||
probs_step, axis=1)], axis=1)
|
||||
next_input = probs_step.argmax(axis=1)
|
||||
targets = next_input
|
||||
if not self.training:
|
||||
probs = paddle.nn.functional.softmax(probs, axis=2)
|
||||
return probs
|
||||
|
||||
|
|
Loading…
Reference in New Issue