reset latex ocr (#14047)

pull/14061/head v2.9.0
zhangyubo0722 2024-10-18 22:28:43 +08:00 committed by GitHub
parent cb36de1294
commit ee1aa57e52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -948,15 +948,15 @@ class LaTeXOCRHead(nn.Layer):
b, t = start_tokens.shape b, t = start_tokens.shape
self.net.eval() self.net.eval()
out_tmp = start_tokens out = start_tokens
mask = kwargs.pop("mask", None) mask = kwargs.pop("mask", None)
if mask is None: if mask is None:
mask = paddle.full_like(out_tmp, True, dtype=paddle.bool) mask = paddle.full_like(out, True, dtype=paddle.bool)
i_idx = paddle.full([], 0) i_idx = paddle.full([], 0)
while i_idx < paddle.to_tensor(seq_len): while i_idx < paddle.to_tensor(seq_len):
x = out_tmp[:, -self.max_seq_len :] x = out[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(x, [-1, -1]) paddle.jit.api.set_dynamic_shape(x, [-1, -1])
mask = mask[:, -self.max_seq_len :] mask = mask[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(mask, [-1, -1]) paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
@ -969,7 +969,7 @@ class LaTeXOCRHead(nn.Layer):
probs = F.softmax(filtered_logits / temperature, axis=-1) probs = F.softmax(filtered_logits / temperature, axis=-1)
sample = paddle.multinomial(probs, 1) sample = paddle.multinomial(probs, 1)
out = paddle.concat((out_tmp, sample), axis=-1) out = paddle.concat((out, sample), axis=-1)
pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool") pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
mask = paddle.concat((mask, pad_mask), axis=1) mask = paddle.concat((mask, pad_mask), axis=1)