PaddleOCR/ppocr/modeling/heads/rec_latexocr_head.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/models/transformer.py
"""
import math
import paddle
from paddle import nn, einsum
import paddle.nn.functional as F
from functools import partial
from inspect import isfunction
from collections import namedtuple
from paddle.nn.initializer import (
TruncatedNormal,
Constant,
Normal,
KaimingUniform,
XavierUniform,
)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
normal_ = Normal(std=0.02)
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
LayerIntermediates = namedtuple("LayerIntermediates", ["hiddens", "attn_intermediates"])
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
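# small callable helpers: `always` returns a fixed value; `equals` / `not_equals`
# are predicates used to filter the layer-type schedule in AttentionLayers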
class always:
def __init__(self, val):
self.val = val
def __call__(self, *args, **kwargs):
return self.val
class not_equals:
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x != self.val
class equals:
def __init__(self, val):
self.val = val
def __call__(self, x, *args, **kwargs):
return x == self.val
def max_neg_value(tensor):
return -paddle.finfo(tensor.dtype).max
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, s):
    return s.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d
)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
)
return kwargs_without_prefix, kwargs
# basic building blocks and positional embeddings
class DepthWiseConv1d(nn.Layer):
    def __init__(
        self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=None
    ):
        super().__init__()
        # depthwise conv (groups defaults to dim_in) followed by a pointwise conv
        groups = default(groups, dim_in)
        self.net = nn.Sequential(
            nn.Conv1D(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                groups=groups,
                stride=stride,
                bias_attr=bias,
            ),
            nn.Conv1D(dim_in, dim_out, 1),
        )
def forward(self, x):
return self.net(x)
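# learned absolute positional embedding; added to the token embeddings in
# TransformerDecoder when no other positional scheme is active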
class AbsolutePositionalEmbedding(nn.Layer):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()
def init_(self):
normal_(self.emb.weight)
def forward(self, x):
n = paddle.arange(x.shape[1])
return self.emb(n)[None, :, :]
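# fixed sinusoidal positional embedding, used for position-infused attention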
class FixedPositionalEmbedding(nn.Layer):
def __init__(self, dim):
super().__init__()
        inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2).astype("float32") / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        t = paddle.arange(x.shape[seq_dim]).astype(self.inv_freq.dtype) + offset
sinusoid_inp = paddle.einsum("i , j -> i j", t, self.inv_freq)
emb = paddle.concat((sinusoid_inp.sin(), sinusoid_inp.cos()), axis=-1)
return emb[None, :, :]
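# wrappers around sub-layers: Scale multiplies the output by a constant,
# Rezero gates it with a learnable scalar initialized to zero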
class Scale(nn.Layer):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.value, *rest)
class Rezero(nn.Layer):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = paddle.create_parameter([1], dtype="float32")
zeros_(self.g)
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
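# LayerNorm alternatives, selected via use_scalenorm / use_rmsnorm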
class ScaleNorm(nn.Layer):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = paddle.create_parameter([1], dtype="float32")
ones_(self.g)
def forward(self, x):
norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
        return x / norm.clip(min=self.eps) * self.g
class RMSNorm(nn.Layer):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
        self.g = paddle.create_parameter([dim], dtype="float32")
ones_(self.g)
def forward(self, x):
norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
        return x / norm.clip(min=self.eps) * self.g
class Residual(nn.Layer):
def forward(self, x, residual):
return x + residual
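# GEGLU: GELU-gated input projection for the feed-forward block (ff_glu option)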
class GEGLU(nn.Layer):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, axis=-1)
return x * F.gelu(gate)
class FeedForward(nn.Layer):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
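# multi-head self-/cross-attention with optional extras carried over from the
# reference implementation: talking heads, collaborative heads, persistent
# memory key/values, top-k sparse attention and value gating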
class Attention(nn.Layer):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
collab_heads=False,
collab_compression=0.3,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.0,
on_attn=False,
gate_values=False,
is_export=False,
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
self.is_export = is_export
qk_dim = v_dim = dim_head * heads
# collaborative heads
self.collab_heads = collab_heads
if self.collab_heads:
qk_dim = int(collab_compression * qk_dim)
            # create_latex_parameter is defined later in this module
            self.collab_mixing = create_latex_parameter([heads, qk_dim])
self.to_q = nn.Linear(dim, qk_dim, bias_attr=False)
self.to_k = nn.Linear(dim, qk_dim, bias_attr=False)
self.to_v = nn.Linear(dim, v_dim, bias_attr=False)
self.dropout = nn.Dropout(dropout)
# add GLU gating for aggregated values, from alphafold2
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, v_dim)
zeros_(self.to_v_gate.weight)
ones_(self.to_v_gate.bias)
# talking heads
self.talking_heads = talking_heads
if talking_heads:
            self.pre_softmax_proj = create_latex_parameter([heads, heads])
            self.post_softmax_proj = create_latex_parameter([heads, heads])
# explicit topk sparse attention
self.sparse_topk = sparse_topk
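        # entmax is not ported; softmax is always used and use_entmax15 is ignored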
self.attn_fn = F.softmax
# add memory key / values
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
            self.mem_k = create_latex_parameter([heads, num_mem_kv, dim_head])
            self.mem_v = create_latex_parameter([heads, num_mem_kv, dim_head])
# attention on attention
self.attn_on_attn = on_attn
self.to_out = (
nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU())
if on_attn
else nn.Linear(v_dim, dim)
)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
rotary_pos_emb=None,
prev_attn=None,
mem=None,
seq_len=0,
):
if not self.training:
self.is_export = True
b, n, _, h, talking_heads, collab_heads, has_context = (
*x.shape,
self.heads,
self.talking_heads,
self.collab_heads,
exists(context),
)
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = paddle.concat((mem, k_input), axis=-2)
v_input = paddle.concat((mem, v_input), axis=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
def rearrange_q_k_v(x, h, is_export):
if is_export:
b, n, h_d = paddle.shape(x)
else:
b, n, h_d = x.shape
d = h_d // h
return x.reshape([b, n, h, d]).transpose([0, 2, 1, 3])
q, k, v = map(
lambda t: rearrange_q_k_v(t, h, is_export=self.is_export), (q, k, v)
)
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(
mask,
lambda: paddle.ones(
(b, n),
).cast(paddle.bool),
)
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(
k_mask, lambda: paddle.ones((b, k.shape[-2])).cast(paddle.bool)
)
q_mask = q_mask.reshape([q_mask.shape[0], 1, q_mask.shape[1], 1])
k_mask = k_mask.reshape([k_mask.shape[0], 1, 1, k_mask.shape[1]])
input_mask = q_mask * k_mask
if collab_heads:
            k = k.expand([-1, h, -1, -1])
dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots.clone()
if talking_heads:
dots = einsum(
"b h i j, h k -> b k i j", dots, self.pre_softmax_proj
).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
        if exists(input_mask):
            # cast only when a mask was actually built; input_mask may be None
            input_mask = input_mask.cast(paddle.bool)
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask
if self.causal:
i, j = dots.shape[-2:]
r = paddle.arange(i)
r_shape = r.shape[0]
mask = r.reshape([1, 1, r_shape, 1]) < r.reshape([1, 1, 1, r_shape])
if self.is_export:
pad_list = [
paddle.to_tensor(0, dtype="int32"),
paddle.to_tensor(0, dtype="int32"),
paddle.to_tensor(j - i, dtype="int32"),
paddle.to_tensor(0, dtype="int32"),
]
mask = F.pad(
mask.cast(paddle.int32),
paddle.to_tensor(pad_list).cast(paddle.int32),
value=False,
).cast(paddle.bool)
dots = dots.masked_fill_(mask, mask_value)
else:
                mask = F.pad(
                    mask.cast(paddle.int32), (0, 0, j - i, 0), value=False
                ).cast(paddle.bool)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, axis=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, axis=-1)
post_softmax_attn = attn.clone()
attn = self.dropout(attn)
if talking_heads:
attn = einsum(
"b h i j, h k -> b k i j", attn, self.post_softmax_proj
).contiguous()
out = einsum("b h i j, b h j d -> b h i d", attn, v)
b, h, n, d = out.shape
out = out.transpose([0, 2, 1, 3]).reshape([b, n, h * d])
if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
out = out * gates.sigmoid()
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates
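# a stack of attention / feed-forward blocks; the schedule is a tuple of layer
# types ('a' self-attention, 'c' cross-attention, 'f' feed-forward) built from
# default_block, custom_layers, par_ratio or sandwich_coef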
class AttentionLayers(nn.Layer):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_bias=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
rotary_pos_emb=False,
rotary_emb_dim=None,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
is_export=False,
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.LayerList([])
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = (
FixedPositionalEmbedding(dim) if position_infused_attn else None
)
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), "number of relative position buckets must be less than the relative position max distance"
self.pre_norm = pre_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
self.cross_attend = cross_attend
self.rel_pos = None
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ("a", "c", "f")
elif cross_attend and only_cross:
default_block = ("c", "f")
else:
default_block = ("a", "f")
if macaron:
default_block = ("f",) + default_block
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, "par ratio out of range"
default_block = tuple(filter(not_equals("f"), default_block))
par_attn = par_depth // par_ratio
depth_cut = (
par_depth * 2 // 3
) # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert (
len(default_block) <= par_width
), "default block is too large for par_ratio"
par_block = default_block + ("f",) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ("f",) * (par_depth - len(par_head))
elif exists(sandwich_coef):
            assert (
                sandwich_coef > 0 and sandwich_coef <= depth
            ), "sandwich coefficient should be positive and no greater than the depth"
layer_types = (
("a",) * sandwich_coef
+ default_block * (depth - sandwich_coef)
+ ("f",) * sandwich_coef
)
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
for layer_type in self.layer_types:
if layer_type == "a":
layer = Attention(
dim, heads=heads, causal=causal, is_export=is_export, **attn_kwargs
)
elif layer_type == "c":
layer = Attention(dim, heads=heads, is_export=is_export, **attn_kwargs)
elif layer_type == "f":
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f"invalid layer type {layer_type}")
if isinstance(layer, Attention) and exists(branch_fn):
layer = branch_fn(layer)
residual_fn = Residual()
self.layers.append(nn.LayerList([norm_fn(), layer, residual_fn]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
seq_len=0,
return_hiddens=False,
):
        assert not (
            self.cross_attend ^ exists(context)
        ), "context must be passed in if and only if cross_attend is set to True"
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
rotary_pos_emb = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)
):
is_last = ind == (len(self.layers) - 1)
if layer_type == "a":
hiddens.append(x)
layer_mem = mems.pop(0)
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == "a":
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
rotary_pos_emb=rotary_pos_emb,
prev_attn=prev_attn,
mem=layer_mem,
)
elif layer_type == "c":
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn,
)
elif layer_type == "f":
out = block(x)
x = residual_fn(out, residual)
if layer_type in ("a", "c"):
intermediates.append(inter)
if layer_type == "a" and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == "c" and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if not self.pre_norm and not is_last:
x = norm(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens, attn_intermediates=intermediates
)
return x, intermediates
return x
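# thin wrappers that fix the causality / cross-attention flags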
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert "causal" not in kwargs, "cannot set causality on encoder"
super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert "causal" not in kwargs, "cannot set causality on decoder"
super().__init__(causal=True, **kwargs)
class CrossAttender(AttentionLayers):
def __init__(self, **kwargs):
super().__init__(cross_attend=True, only_cross=True, **kwargs)
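# create a trainable parameter initialized from a standard normal distribution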
def create_latex_parameter(shape):
return paddle.create_parameter(
shape=shape,
dtype="float32",
default_initializer=paddle.nn.initializer.Assign(paddle.randn(shape)),
)
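# autoregressive transformer decoder: token + positional embeddings ->
# AttentionLayers -> LayerNorm -> vocabulary logits (optionally weight-tied to
# the token embedding)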
class TransformerDecoder(nn.Layer):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
        max_mem_len=0,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True,
is_export=False,
):
super().__init__()
assert isinstance(
attn_layers, AttentionLayers
), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.is_export = is_export
self.init_()
self.to_logits = (
nn.Linear(dim, num_tokens)
if not tie_embedding
else lambda t: t @ self.token_emb.weight.t()
)
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = create_latex_parameter([num_memory_tokens, dim])
# let funnel encoder know number of memory tokens, if specified
# TODO: think of a cleaner solution
if hasattr(attn_layers, "num_memory_tokens"):
attn_layers.num_memory_tokens = num_memory_tokens
def init_(self):
normal_(self.token_emb.weight)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
seq_len=0,
mems=None,
**kwargs,
):
b, n, num_mem = *x.shape, self.num_memory_tokens
x = self.token_emb(x)
x = x + self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, seq_len=seq_len, **kwargs
)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
hiddens = intermediates.hiddens
new_mems = (
list(map(lambda pair: paddle.concat(pair, axis=-2), zip(mems, hiddens)))
if exists(mems)
else hiddens
)
new_mems = list(
map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
)
return out, new_mems
if return_attn:
attn_maps = list(
map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
)
return out, attn_maps
return out
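# nucleus (top-p) filtering: keep the smallest set of highest-probability
# logits whose cumulative probability exceeds 1 - thres, mask out the rest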
def top_p(logits, thres=0.9):
    sorted_logits = paddle.sort(logits, descending=True)
    sorted_indices = paddle.argsort(logits, descending=True)
    cum_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1)
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False
    sorted_logits = paddle.where(
        sorted_indices_to_remove,
        paddle.full_like(sorted_logits, float("-inf")),
        sorted_logits,
    )
    # scatter the filtered logits back to their original positions
    return paddle.put_along_axis(logits, sorted_indices, sorted_logits, axis=1)
# top-k filtering: keep only the k highest logits and set the rest to -inf
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = paddle.topk(logits, k)
probs = paddle.full_like(logits, float("-inf"))
probs = paddle.put_along_axis(probs, ind, val, 1)
return probs
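# Minimal usage sketch (illustrative only; the decoder_args shown are an
# assumption, not the shipped PaddleOCR config):
#
#   head = LaTeXOCRHead(decoder_args={"cross_attend": True, "ff_glu": True})
#   # training (head.train()):  logits = head((encoded_feat, tgt_seq, mask))
#   #                           -> shape [N, L-1, 8000]
#   # inference (head.eval()):  tokens = head(encoded_feat) -> predicted ids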
class LaTeXOCRHead(nn.Layer):
"""Implementation of LaTeX OCR decoder.
Args:
encoded_feat: The encoded features with shape[N, 1, H//16, W//16]
tgt_seq: LaTeX-OCR labels with shape [N, L] , L is the max sequence length
xi: The first N-1 LaTeX-OCR sequences in tgt_seq with shape [N, L-1]
mask: The first N-1 LaTeX-OCR attention mask with shape [N, L-1] , L is the max sequence length
Returns:
The predicted LaTeX sequences with shape [N, L-1, C], C is the number of LaTeX classes
"""
def __init__(
self,
net=None,
in_channels=256,
out_channels=256,
pad_value=0,
decoder_args=None,
is_export=False,
):
super().__init__()
decoder = Decoder(
dim=256, depth=4, heads=8, is_export=is_export, **decoder_args
)
transformer_decoder = TransformerDecoder(
num_tokens=8000,
max_seq_len=512,
attn_layers=decoder,
is_export=is_export,
)
self.temperature = 0.333
self.bos_token = 1
self.eos_token = 2
self.max_length = 512
self.pad_value = pad_value
self.net = transformer_decoder
self.max_seq_len = self.net.max_seq_len
self.is_export = is_export
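    # autoregressive sampling loop used at inference time (dynamic-graph mode)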
@paddle.no_grad()
def generate(
self,
start_tokens,
seq_len,
eos_token=None,
temperature=1.0,
filter_logits_fn=top_k,
filter_thres=0.9,
**kwargs,
):
was_training = self.net.training
num_dims = len(start_tokens.shape)
if num_dims == 1:
start_tokens = start_tokens[None, :]
b, t = start_tokens.shape
self.net.eval()
out = start_tokens
mask = kwargs.pop("mask", None)
if mask is None:
mask = paddle.full_like(out, True, dtype=paddle.bool)
for _ in range(seq_len):
x = out[:, -self.max_seq_len :]
mask = mask[:, -self.max_seq_len :]
logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, axis=-1)
            else:
                raise NotImplementedError("filter_logits_fn must be top_k or top_p")
sample = paddle.multinomial(probs, 1)
out = paddle.concat((out, sample), axis=-1)
pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
mask = paddle.concat((mask, pad_mask), axis=1)
if (
eos_token is not None
and (
paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
).all()
):
break
        if was_training:
            self.net.train()
        out = out[:, t:]
if num_dims == 1:
out = out.squeeze(0)
return out
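    # export-friendly generation loop: a tensor loop counter plus
    # paddle.jit.api.set_dynamic_shape keep the graph traceable for inference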
@paddle.no_grad()
def generate_export(
self,
start_tokens,
seq_len,
eos_token=None,
context=None,
temperature=1.0,
filter_logits_fn=None,
filter_thres=0.9,
**kwargs,
):
was_training = self.net.training
num_dims = len(start_tokens.shape)
if num_dims == 1:
start_tokens = start_tokens[None, :]
b, t = start_tokens.shape
self.net.eval()
out = start_tokens
mask = kwargs.pop("mask", None)
if mask is None:
mask = paddle.full_like(out, True, dtype=paddle.bool)
i_idx = paddle.full([], 0)
while i_idx < paddle.to_tensor(seq_len):
x = out[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(x, [-1, -1])
mask = mask[:, -self.max_seq_len :]
paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
logits = self.net(x, mask=mask, context=context, seq_len=i_idx, **kwargs)[
:, -1, :
]
if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, axis=-1)
sample = paddle.multinomial(probs, 1)
out = paddle.concat((out, sample), axis=-1)
pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
mask = paddle.concat((mask, pad_mask), axis=1)
if (
eos_token is not None
and (
paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
).all()
):
break
i_idx += 1
out = out[:, t:]
if num_dims == 1:
out = out.squeeze(0)
return out
    # forward: teacher-forced decoding in training mode; autoregressive
    # generation (standard or export-friendly loop) at inference
def forward(self, inputs, targets=None):
if not self.training:
self.is_export = True
encoded_feat = inputs
batch_num = encoded_feat.shape[0]
bos_tensor = paddle.full([batch_num, 1], self.bos_token, dtype=paddle.int64)
if self.is_export:
word_pred = self.generate_export(
bos_tensor,
self.max_seq_len,
eos_token=self.eos_token,
context=encoded_feat,
temperature=self.temperature,
filter_logits_fn=top_k,
)
else:
word_pred = self.generate(
bos_tensor,
self.max_seq_len,
eos_token=self.eos_token,
context=encoded_feat,
temperature=self.temperature,
filter_logits_fn=top_k,
)
return word_pred
encoded_feat, tgt_seq, mask = inputs
kwargs = {"context": encoded_feat, "mask": mask.cast(paddle.bool)}
x = tgt_seq
xi = x[:, :-1]
mask = kwargs.get("mask", None)
if mask is not None and mask.shape[1] == x.shape[1]:
mask = mask[:, :-1]
kwargs["mask"] = mask
out = self.net(xi, **kwargs)
return out