reset latex ocr (#14047)

pull/14061/head v2.9.0
zhangyubo0722 2024-10-18 22:28:43 +08:00 committed by GitHub
parent cb36de1294
commit ee1aa57e52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -948,15 +948,15 @@ class LaTeXOCRHead(nn.Layer):
b, t = start_tokens.shape
self.net.eval()
out_tmp = start_tokens
out = start_tokens
mask = kwargs.pop("mask", None)
if mask is None:
mask = paddle.full_like(out_tmp, True, dtype=paddle.bool)
mask = paddle.full_like(out, True, dtype=paddle.bool)
i_idx = paddle.full([], 0)
while i_idx < paddle.to_tensor(seq_len):
x = out_tmp[:, -self.max_seq_len :]
x = out[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(x, [-1, -1])
mask = mask[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
@ -969,7 +969,7 @@ class LaTeXOCRHead(nn.Layer):
probs = F.softmax(filtered_logits / temperature, axis=-1)
sample = paddle.multinomial(probs, 1)
out = paddle.concat((out_tmp, sample), axis=-1)
out = paddle.concat((out, sample), axis=-1)
pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
mask = paddle.concat((mask, pad_mask), axis=1)