PaddleOCR/ppocr/modeling/heads/rec_unimernet_head.py


# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is referenced from:
https://github.com/opendatalab/UniMERNet/blob/main/unimernet/models/unimernet/configuration_unimernet_decoder.py
"""
import copy
import math
import re
import numpy as np
import inspect
import warnings
from collections import OrderedDict
from typing import Optional, Tuple, Union, List, Dict, Any
from dataclasses import dataclass, fields, is_dataclass
import paddle
import paddle.nn as nn
from paddle import Tensor
import paddle.nn.functional as F
from paddle.nn import CrossEntropyLoss
from paddle.nn.initializer import (
TruncatedNormal,
Constant,
Normal,
KaimingUniform,
XavierUniform,
XavierNormal,
)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
trunc_normal_ = TruncatedNormal(std=0.02)
xavier_uniform_ = XavierUniform()
xavier_normal_ = XavierNormal()
class ModelOutput(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __post_init__(self):
class_fields = fields(self)
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
raise ValueError(
f"{self.__class__.__name__} should not have more than one required field."
)
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(
getattr(self, field.name) is None for field in class_fields[1:]
)
if other_fields_are_none:
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
if first_field_iterator:
for idx, element in enumerate(iterator):
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
if idx == 0:
self[class_fields[0].name] = first_field
else:
raise ValueError(
f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
)
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def setdefault(self, *args, **kwargs):
raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def pop(self, *args, **kwargs):
raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
)
def update(self, *args, **kwargs):
raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
super().__setattr__(key, value)
def __reduce__(self):
if not is_dataclass(self):
return super().__reduce__()
callable, _args, *remaining = super().__reduce__()
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
def to_tuple(self):
return tuple(self[k] for k in self.keys())
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
last_hidden_state = None
past_key_values = None
hidden_states = None
attentions = None
cross_attentions = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@dataclass
class Seq2SeqLMOutput(ModelOutput):
loss = None
logits = None
past_key_values = None
decoder_hidden_states = None
decoder_attentions = None
cross_attentions = None
encoder_last_hidden_state = None
encoder_hidden_states = None
encoder_attentions = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class MBartConfig(object):
model_type = "mbart"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
"hidden_size": "d_model",
}
def __init__(
self,
vocab_size=50265,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
output_hidden_states=False,
use_return_dict=True,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
_attn_implementation="eager",
hidden_size=1024,
use_parallel=False,
parallel_step=2,
is_export=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.output_hidden_states = output_hidden_states
self.use_return_dict = use_return_dict
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = (
scale_embedding # scale factor will be sqrt(d_model) if True
)
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.is_encoder_decoder = is_encoder_decoder
self.forced_eos_token_id = forced_eos_token_id
self._attn_implementation = _attn_implementation
self.use_parallel = use_parallel
self.parallel_step = parallel_step
self.is_export = is_export
super().__init__()
@dataclass
class AttentionMaskConverter:
"""
A utility class for converting attention masks used in transformer models.
This class handles the conversion of attention masks based on whether the
attention mechanism is causal (i.e., preventing information flow from future
tokens to past tokens) and whether a sliding window approach is used.
Attributes:
is_causal (bool): Indicates if the attention mechanism is causal.
sliding_window (Optional[int]): Specifies the size of the sliding window
for local attention, if applicable.
Args:
is_causal (bool): Determines if the attention mask should enforce causality.
sliding_window (Optional[int], optional): The size of the sliding window
for local attention. Default is None.
"""
is_causal: bool
sliding_window: int
def __init__(self, is_causal: bool, sliding_window=None):
self.is_causal = is_causal
self.sliding_window = sliding_window
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError(
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
)
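    # Illustrative use (hypothetical shapes), consistent with the helpers below:
    #   converter = AttentionMaskConverter(is_causal=True)
    #   mask_4d = converter.to_4d(mask_2d, query_length=tgt_len,
    #                             dtype=paddle.float32, key_value_length=tgt_len)
    #   # mask_2d: [bsz, src_len] of 0/1  ->  mask_4d: [bsz, 1, tgt_len, src_len],
    #   # an additive mask with 0 for visible positions and finfo(dtype).min elsewhere.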
@staticmethod
def _make_causal_mask(
input_ids_shape,
dtype,
past_key_values_length=0,
sliding_window=None,
is_export=False,
):
bsz, tgt_len = input_ids_shape
if is_export:
mask = paddle.full(
(tgt_len, tgt_len), paddle.finfo(dtype).min, dtype="float64"
)
else:
mask = paddle.full((tgt_len, tgt_len), paddle.finfo(dtype).min)
mask_cond = paddle.arange(mask.shape[-1])
mask = mask.masked_fill_(
mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
)
return mask[None, None, :, :].expand(
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
)
def to_4d_export(
self,
attention_mask_2d,
query_length,
dtype,
key_value_length,
is_export=False,
):
input_shape = (attention_mask_2d.shape[0], query_length)
expanded_attn_mask = self._expand_mask(
attention_mask_2d, dtype, tgt_len=input_shape[-1]
)
expanded_4d_mask = expanded_attn_mask
return expanded_4d_mask
def to_4d(
self,
attention_mask_2d,
query_length,
dtype,
key_value_length,
is_export=False,
):
input_shape = (attention_mask_2d.shape[0], query_length)
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
is_export=is_export,
)
elif self.sliding_window is not None:
raise NotImplementedError(
"Sliding window is currently only implemented for causal masking"
)
expanded_attn_mask = self._expand_mask(
attention_mask_2d, dtype, tgt_len=input_shape[-1]
)
if causal_4d_mask is not None:
if is_export:
expanded_attn_mask = causal_4d_mask
return expanded_attn_mask
else:
expanded_attn_mask = causal_4d_mask.masked_fill_(
expanded_attn_mask.cast(paddle.bool), paddle.finfo(dtype).min
)
expanded_4d_mask = expanded_attn_mask
return expanded_4d_mask
    @staticmethod
    def _expand_mask(mask, dtype, tgt_len=None):
bsz, src_len = mask.shape
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill_(
inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
)
def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _prepare_4d_causal_attention_mask_export(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window=None,
is_export=False,
):
attn_mask_converter = AttentionMaskConverter(
is_causal=True, sliding_window=sliding_window
)
key_value_length = input_shape[-1] + past_key_values_length
shape = attention_mask.shape
len_shape = len(shape)
attention_mask = attn_mask_converter.to_4d_export(
attention_mask,
input_shape[-1],
key_value_length=key_value_length,
dtype=inputs_embeds.dtype,
is_export=is_export,
)
return attention_mask
def _prepare_4d_causal_attention_mask(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window=None,
is_export=False,
):
attn_mask_converter = AttentionMaskConverter(
is_causal=True, sliding_window=sliding_window
)
key_value_length = input_shape[-1] + past_key_values_length
shape = attention_mask.shape
len_shape = len(shape)
if (attention_mask is not None) and (len_shape == 2):
attention_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
key_value_length=key_value_length,
dtype=inputs_embeds.dtype,
is_export=is_export,
)
return attention_mask
elif attention_mask is not None and len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
inverted_mask = 1.0 - attention_mask
attention_mask = inverted_mask.masked_fill_(
inverted_mask.to(paddle.bool), paddle.finfo(inputs_embeds.dtype).min
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0],
input_shape[-1],
key_value_length,
dtype=inputs_embeds.dtype,
)
return attention_mask
class MBartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings, embedding_dim):
self.offset = 2
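        # Positions are offset by 2 (fairseq/mBART convention: ids 0 and 1 are
        # reserved), so the embedding table below is built with num_embeddings + 2 rows.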
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(self, input_ids, past_key_values_length=0):
"""`input_ids' shape is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
positions = paddle.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=paddle.int64
).expand([bsz, -1])
return nn.Embedding.forward(self, positions + self.offset)
class MBartPreTrainedModel(nn.Layer):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
def __init__(self, config):
super().__init__()
self.config = config
def _initialize_weights(self, module):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
def post_init(self):
self.apply(self._initialize_weights)
def _init_weights(self, module):
std = self.config.init_std
normal_ = Normal(mean=0.0, std=std)
if isinstance(module, nn.Linear):
normal_(module.weight)
if module.bias is not None:
zeros_(module.bias)
elif isinstance(module, nn.Embedding):
normal_(module.weight)
if module._padding_idx is not None:
zeros_(module.weight[module._padding_idx])
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
        input_ids = paddle.to_tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
        dummy_inputs = {
            "attention_mask": input_ids != pad_token,
"input_ids": input_ids,
}
return dummy_inputs
class MBartAttention(nn.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim,
num_heads,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config=None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
def _shape(self, tensor, seq_len, bsz):
return tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim]).transpose(
[0, 2, 1, 3]
)
def forward(
self,
hidden_states,
key_value_states=None,
past_key_value=None,
attention_mask=None,
layer_head_mask=None,
output_attentions=False,
):
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = paddle.shape(hidden_states)
query_states = self.q_proj(hidden_states) * self.scaling
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
else:
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
key_states = key_states.reshape(proj_shape)
value_states = value_states.reshape(proj_shape)
src_len = key_states.shape[1]
attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))
if attention_mask is not None:
attn_weights = (
attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
+ attention_mask
)
attn_weights = attn_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len]
)
attn_weights = nn.functional.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
if tuple(layer_head_mask.shape) != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
f" {layer_head_mask.shape}"
)
attn_weights = layer_head_mask.reshape(
[1, -1, 1, 1]
) * attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
attn_weights = attn_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len]
)
if output_attentions:
attn_weights_reshaped = attn_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len]
)
attn_weights = attn_weights_reshaped.reshape(
[bsz * self.num_heads, tgt_len, src_len]
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = paddle.bmm(attn_probs, value_states)
attn_output = attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
attn_output = attn_output.transpose([0, 2, 1, 3])
attn_output = attn_output.reshape([bsz, tgt_len, self.embed_dim])
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
MBART_ATTENTION_CLASSES = {
"eager": MBartAttention,
}
class MBartDecoderLayer(nn.Layer):
def __init__(self, config):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
config=config,
)
self.is_export = config.is_export
self.dropout = config.dropout
self.activation_fn = F.gelu
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> paddle.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
self_attn_past_key_value = (
past_key_value[:2] if past_key_value is not None else None
)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
cross_attn_past_key_value = (
past_key_value[-2:] if past_key_value is not None else None
)
hidden_states, cross_attn_weights, cross_attn_present_key_value = (
self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
present_key_value = present_key_value + cross_attn_present_key_value
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(
hidden_states, p=self.activation_dropout, training=self.training
)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if self.is_export:
outputs += (present_key_value,)
else:
if use_cache:
outputs += (present_key_value,)
return outputs
class MBartForCausalLM(MBartPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
super().__init__(config)
self.model = MBartDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
self.post_init()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
        if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
**kwargs,
):
if attention_mask is None:
            attention_mask = paddle.ones_like(input_ids)
if past_key_values:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
class myLayerNorm(nn.LayerNorm):
"""
Custom implementation of Layer Normalization, with additional options.
This class extends the standard LayerNorm to include optional features,
such as drop block regularization, which might be used for improving
model generalization.
Args:
num_channels (int): The number of features or channels in the input.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
drop_block (optional): Additional regularization technique that might be applied. Default is None.
"""
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
drop_block=None,
):
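        # Note: `super(nn.LayerNorm, self).__init__()` deliberately skips
        # nn.LayerNorm.__init__ and only runs nn.Layer.__init__; the affine
        # weight/bias parameters are created manually below.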
super(nn.LayerNorm, self).__init__()
self._epsilon = eps
self.num_channels = num_channels
        if affine:
            self.weight = paddle.create_parameter([num_channels], dtype="float32")
            self.bias = paddle.create_parameter([num_channels], dtype="float32")
            ones_(self.weight)
            zeros_(self.bias)
        else:
            # F.layer_norm accepts None weight/bias, so affine=False stays usable.
            self.weight = None
            self.bias = None
def forward(self, x):
x = F.layer_norm(
x,
self.num_channels,
weight=self.weight,
bias=self.bias,
epsilon=self._epsilon,
)
return x
class MBartDecoder(MBartPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
Args:
config
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(
config.vocab_size, config.d_model, self.padding_idx
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = MBartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.LayerList(
[MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.is_export = config.is_export
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input = input_ids
input_shape = input.shape
input_ids = input_ids.reshape([-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if self._use_flash_attention_2:
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
is_export=self.is_export,
)
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = (
encoder_attention_mask if 0 in encoder_attention_mask else None
)
else:
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
if self.gradient_checkpointing and self.training:
if use_cache:
                print(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = (
() if (output_attentions and encoder_hidden_states is not None) else None
)
next_decoder_cache = () if use_cache else None
for attn_mask, mask_name in zip(
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
):
if attn_mask is not None:
if attn_mask.shape[0] != len(self.layers):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.shape[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = paddle.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class MBartDecoderWrapper(MBartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = MBartDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
def _in_projection(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
w_q: paddle.Tensor,
w_k: paddle.Tensor,
w_v: paddle.Tensor,
b_q: Optional[paddle.Tensor] = None,
b_k: Optional[paddle.Tensor] = None,
b_v: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
assert w_q.shape == (
Eq,
Eq,
), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
assert w_k.shape == (
Eq,
Ek,
), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
assert w_v.shape == (
Eq,
Ev,
), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (
Eq,
), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (
Eq,
), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (
Eq,
), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
    return (
        linear(q, w_q.T, b_q, is_transpose=False),
        linear(k, w_k.T, b_k, is_transpose=False),
        linear(v, w_v.T, b_v, is_transpose=False),
    )
def _scaled_dot_product_attention(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
attn_mask: Optional[paddle.Tensor] = None,
dropout_p: float = 0.0,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
B, Nt, E = q.shape
q = q / math.sqrt(E)
attn = paddle.bmm(q, k.transpose([0, 2, 1]))
if attn_mask is not None:
attn += attn_mask
attn = F.softmax(attn, axis=-1)
if dropout_p > 0.0:
attn = F.dropout(attn, p=dropout_p)
output = paddle.bmm(attn, v)
return output, attn
def linear(x, w, b, is_transpose):
if b is not None:
return paddle.matmul(x, w, transpose_y=is_transpose) + b
else:
return paddle.matmul(x, w, transpose_y=is_transpose)
def _in_projection_packed(
q: Tensor,
k: Tensor,
v: Tensor,
w: Tensor,
b: Optional[Tensor] = None,
is_export=False,
) -> List[Tensor]:
E = paddle.shape(q)[-1]
if k is v:
if q is k:
proj = linear(q, w, b, is_transpose=True)
if is_export:
B, D, L = paddle.shape(proj)
proj = proj.reshape([B, D, 3, E])
proj = (
proj.unsqueeze(0)
.transpose([3, 1, 2, 0, 4])
.squeeze(-2)
.contiguous()
)
else:
proj = (
proj.unflatten(-1, (3, E))
.unsqueeze(0)
.transpose([3, 1, 2, 0, 4])
.squeeze(-2)
.contiguous()
)
return proj[0], proj[1], proj[2]
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
            return (
                linear(q, w_q, b_q, is_transpose=True),
                linear(k, w_k, b_k, is_transpose=True),
                linear(v, w_v, b_v, is_transpose=True),
            )
def multi_head_attention_forward(
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: paddle.Tensor,
in_proj_bias: Optional[paddle.Tensor],
bias_k: Optional[paddle.Tensor],
bias_v: Optional[paddle.Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: paddle.Tensor,
out_proj_bias: Optional[paddle.Tensor],
training: bool = True,
key_padding_mask: Optional[paddle.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[paddle.Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[paddle.Tensor] = None,
k_proj_weight: Optional[paddle.Tensor] = None,
v_proj_weight: Optional[paddle.Tensor] = None,
static_k: Optional[paddle.Tensor] = None,
static_v: Optional[paddle.Tensor] = None,
is_export=False,
):
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
    # integer floor division matches truncation for positive sizes and also
    # works when embed_dim comes back as a 0-D tensor during export
    head_dim = embed_dim // num_heads
q, k, v = _in_projection_packed(
query, key, value, in_proj_weight, in_proj_bias, is_export
)
if key_padding_mask is not None and key_padding_mask.dtype == paddle.uint8:
warnings.warn(
"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(paddle.bool)
if bias_k is not None and bias_v is not None: # False
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
        k = paddle.concat([k, bias_k.tile([1, bsz, 1])])
        v = paddle.concat([v, bias_v.tile([1, bsz, 1])])
else:
assert bias_k is None
assert bias_v is None
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
if static_k is None: # True
k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
else:
assert (
static_k.shape[0] == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
assert (
static_k.shape[2] == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None: # True
v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
else:
assert (
static_v.shape[0] == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
assert (
static_v.shape[2] == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
v = static_v
src_len = k.shape[1]
if not training:
dropout_p = 0.0
attn_output, attn_output_weights = _scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p
)
attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
attn_output = linear(
attn_output, out_proj_weight, out_proj_bias, is_transpose=False
)
if need_weights:
attn_output_weights = attn_output_weights.reshape(
[bsz, num_heads, tgt_len, src_len]
)
return attn_output, attn_output_weights.sum(axis=1) / num_heads
else:
return attn_output, None
class MyMultiheadAttention(nn.Layer):
"""
Custom implementation of a multi-head attention layer.
Attributes:
__constants__ (list): List of constant attributes.
bias_k (Optional[paddle.Tensor]): Optional tensor for key bias.
bias_v (Optional[paddle.Tensor]): Optional tensor for value bias.
Args:
embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
batch_first (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default is False.
device (optional): The device on which the layer's parameters should be initialized. Default is None.
dtype (optional): The data type for the parameters. Default is None.
is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
"""
__constants__ = ["batch_first"]
bias_k: Optional[paddle.Tensor]
bias_v: Optional[paddle.Tensor]
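    # Illustrative usage (hypothetical shapes; inputs are sequence-first because
    # batch_first defaults to False and forward() does not transpose):
    #   attn = MyMultiheadAttention(embed_dim=1024, num_heads=16)
    #   out, weights = attn(x, x, x)   # x: [L, N, 1024] -> out: [L, N, 1024]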
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
is_export=False,
) -> None:
super(MyMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
self.is_export = is_export
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if self._qkv_same_embed_dim is False:
pass
else:
if dtype is None:
dtype = paddle.float32
self.in_proj_weight = paddle.create_parameter(
(3 * embed_dim, embed_dim), dtype
)
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
if bias:
self.in_proj_bias = paddle.create_parameter((3 * embed_dim,), dtype)
zeros_(self.in_proj_bias)
else:
self.in_proj_bias = None
self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
if add_bias_kv:
pass
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
zeros_(self.in_proj_bias)
zeros_(self.out_proj.bias)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
key_padding_mask: Optional[paddle.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
attn_output, attn_output_weights = multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
is_export=self.is_export,
)
return attn_output, attn_output_weights
class LogitsProcessorList(list):
"""
A list of logits processors that can be applied sequentially.
Methods:
__call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
"""
def __call__(self, input_ids, scores, **kwargs):
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} for "
f"{processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
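# Sketch (hypothetical values): processors in a LogitsProcessorList are applied in
# order, each mapping (input_ids, scores) -> scores, e.g.
#   processors = LogitsProcessorList([ForcedEOSTokenLogitsProcessor(1537, 2)])
#   scores = processors(input_ids, scores)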
class ForcedEOSTokenLogitsProcessor(object):
"""
A processor that forces the generation of an end-of-sequence (EOS) token
at a specified position in the sequence.
This is typically used in language generation tasks to ensure that the
generated sequence ends properly when it reaches a certain length.
Args:
max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
"""
def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
self.max_length = max_length
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
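    # Behaviour: once the generated prefix reaches max_length - 1 tokens, every
    # logit is set to -inf except the EOS id(s), which are set to 0, so the next
    # greedy/sampled token is forced to be EOS.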
def __call__(self, input_ids, scores):
cur_len = input_ids.shape[-1]
scores_processed = scores
if cur_len == self.max_length - 1:
scores_processed = paddle.full_like(scores, -math.inf)
scores_processed[:, self.eos_token_id] = 0
return scores_processed
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
loss = None
logits = None
past_key_values = None
hidden_states = None
attentions = None
cross_attentions = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@dataclass
class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
"""
logits = None
counting = None
past_key_values = None
hidden_states = None
attentions = None
cross_attentions = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class CustomMBartDecoder(MBartDecoder):
"""
A custom MBartDecoder that includes additional processing layers.
This class extends the MBartDecoder by adding a customizable neural network
component called `counting_context_weight`, which applies a series of linear
transformations followed by ReLU activations. This can be used to modify or
enhance the decoder's behavior for specific tasks.
Args:
config: The configuration object containing model parameters.
"""
def __init__(self, config):
super().__init__(config)
hidden_size = config.d_model
self.is_export = config.is_export
self.counting_context_weight = nn.Sequential(
nn.Linear(config.vocab_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, config.d_model),
)
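    # Rough data flow when a count prediction is supplied (see forward below):
    #   count_context = self.counting_context_weight(count_pred)        # [B, d_model]
    #   hidden_states = inputs_embeds + positions + 0.5 * count_context.unsqueeze(1)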
def forward(
self,
input_ids=None,
attention_mask=None,
count_pred=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
        self.is_export = not self.training
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input = input_ids
input_shape = input.shape
input_ids = input_ids.reshape([-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if self._use_flash_attention_2:
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
if self.is_export:
attention_mask = _prepare_4d_causal_attention_mask_export(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
is_export=self.is_export,
).cast(paddle.float32)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
is_export=self.is_export,
)
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = (
encoder_attention_mask if 0 in encoder_attention_mask else None
)
else:
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
hidden_states = inputs_embeds + positions
# TODO: add counting context weight to hidden_states
if count_pred is not None:
count_context_weight = self.counting_context_weight(count_pred)
hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
if self.gradient_checkpointing and self.training:
if use_cache:
                print(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = (
() if (output_attentions and encoder_hidden_states is not None) else None
)
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip(
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
):
if attn_mask is not None:
                if attn_mask.shape[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.shape[0]}."
                    )
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = paddle.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if self.is_export:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
else:
if use_cache:
next_decoder_cache += (
layer_outputs[3 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.is_export:
next_cache = next_decoder_cache
else:
next_cache = next_decoder_cache if use_cache else None
if not self.is_export:
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class SelfAttentionBlock(nn.Layer):
"""
A self-attention block that implements multi-head self-attention
followed by a feed-forward network, typically used in transformer architectures.
Args:
embed_size (int): The size of the embedding vector.
num_heads (int): The number of attention heads.
is_export (bool): Flag indicating whether to configure the layer for export.
"""
def __init__(self, embed_size, num_heads, is_export):
super(SelfAttentionBlock, self).__init__()
self.self_attention = MyMultiheadAttention(
embed_dim=embed_size, num_heads=num_heads, is_export=is_export
)
self.norm = nn.LayerNorm(embed_size)
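    # Post-norm residual block: output = LayerNorm(SelfAttention(x) + x).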
def forward(self, x):
attn_output, _ = self.self_attention(x, x, x)
x = self.norm(attn_output + x)
return x
class SeqCountingDecoder(nn.Layer):
"""
    A custom sequence counting decoder that incorporates multi-head attention layers
    and feed-forward networks to process sequences, e.g. for LaTeX symbol counting.
Args:
in_features (int): The number of input features.
out_features (int): The number of output features.
num_heads (int): The number of attention heads. Defaults to 8.
num_layers (int): The number of attention layers. Defaults to 4.
is_export (bool): Flag indicating whether to configure the layer for export.
"""
def __init__(
self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
):
super(SeqCountingDecoder, self).__init__()
self.attention_blocks = nn.LayerList(
[
SelfAttentionBlock(
embed_size=in_features, num_heads=num_heads, is_export=is_export
)
for i in range(num_layers)
]
)
self.fc1 = nn.Linear(in_features, in_features // 2)
self.relu = nn.ReLU()
self.global_avg_pool = nn.AdaptiveAvgPool1D(1)
self.fc2 = nn.Linear(in_features // 2, out_features)
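    # Shape sketch (illustrative), for x of shape [B, L, in_features]:
    #   self-attention blocks (shape preserved)
    #   -> fc1 + ReLU                               -> [B, L, in_features // 2]
    #   -> transpose + AdaptiveAvgPool1D + squeeze  -> [B, in_features // 2]
    #   -> fc2                                      -> [B, out_features] (count logits)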
def forward(self, x):
for block in self.attention_blocks:
x = block(x)
x = self.fc1(x)
x = self.relu(x)
x = x.transpose([0, 2, 1])
x = self.global_avg_pool(x)
x = x.squeeze(-1)
x = self.fc2(x)
return x
class CustomMBartForCausalLM(MBartForCausalLM):
"""
Custom MBart model for causal language modeling with a custom decoder.
This class extends the MBartForCausalLM by replacing its decoder with a
custom decoder, allowing for additional flexibility and features in the
decoding process.
Args:
config: The configuration object containing model parameters.
length_aware (bool): A flag to enable or configure length-aware mechanisms.
"""
def __init__(self, config, length_aware=True):
super().__init__(config)
self.model.decoder = CustomMBartDecoder(config)
self.counting_decoder = SeqCountingDecoder(
config.d_model, config.vocab_size, is_export=config.is_export
)
self.length_aware = length_aware
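    # When length_aware is True, the SeqCountingDecoder predicts symbol counts from
    # the encoder hidden states, and the prediction is passed to the decoder as
    # extra context via CustomMBartDecoder.forward(count_pred=...).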
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
count_gt=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if self.length_aware:
count_pred = self.counting_decoder(encoder_hidden_states)
else:
count_pred = None
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
count_pred=count_pred,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
return CausalLMOutputWithCrossAttentionsAndCounting(
logits=logits,
counting=count_pred,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class UniMERNetHead(nn.Layer):
"""Implementation of UniMERNetHead decoder.
Args:
max_new_tokens (int): Maximum number of new tokens to generate.
decoder_start_token_id (int): ID of the token that starts the decoding.
temperature (float): Sampling temperature for generation.
do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
top_p (float): Top-p (nucleus) sampling parameter.
in_channels (int): Number of input channels/features.
encoder_hidden_size (int): Hidden size of the encoder.
decoder_hidden_size (int): Hidden size of the decoder.
decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
decoder_layers (int): Number of layers in the decoder.
is_export (bool): Flag indicating if the model is being prepared for export.
length_aware (bool): Flag to enable length-aware mechanisms.
"""
def __init__(
self,
max_new_tokens=1536,
decoder_start_token_id=0,
temperature=0.2,
do_sample=False,
top_p=0.95,
in_channels=1024,
encoder_hidden_size=1024,
decoder_hidden_size=1024,
decoder_ffn_dim=4096,
decoder_layers=8,
is_export=False,
length_aware=True,
):
super().__init__()
mbart_config_dict = {
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": True,
"add_final_layer_norm": True,
"attention_dropout": 0.0,
"bos_token_id": 0,
"classifier_dropout": 0.0,
"d_model": decoder_hidden_size,
"decoder_attention_heads": 16,
"decoder_ffn_dim": decoder_ffn_dim,
"decoder_layerdrop": 0.0,
"decoder_layers": decoder_layers,
"dropout": 0.1,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"init_std": 0.02,
"is_decoder": True,
"is_encoder_decoder": False,
"output_hidden_states": False,
"max_position_embeddings": max_new_tokens,
"model_type": "mbart",
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": True,
"tie_word_embeddings": False,
"transformers_version": "4.40.0",
"use_cache": True,
"use_return_dict": True,
"vocab_size": 50000,
"_attn_implementation": "eager",
"hidden_size": decoder_hidden_size,
"is_export": is_export,
}
self.max_new_tokens = max_new_tokens
self.decoder_start_token_id = decoder_start_token_id
self.temperature = temperature
self.do_sample = do_sample
self.top_p = top_p
self.max_seq_len = max_new_tokens
self.config_decoder = MBartConfig(**mbart_config_dict)
self.encoder_hidden_size = encoder_hidden_size
self.is_export = self.config_decoder.is_export
self.decoder = CustomMBartForCausalLM(
self.config_decoder, length_aware=length_aware
)
if self.config_decoder.hidden_size != self.encoder_hidden_size:
self.enc_to_dec_proj = nn.Linear(
self.encoder_hidden_size, self.config_decoder.hidden_size
)
generation_config = {
"max_length": 1537,
"forced_eos_token_id": 2,
}
self.eos_token_id = generation_config["forced_eos_token_id"]
self.pad_token_id = self.config_decoder.pad_token_id
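        # Force EOS once max_length is reached so generation always terminates.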
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(
ForcedEOSTokenLogitsProcessor(
generation_config["max_length"],
generation_config["forced_eos_token_id"],
)
)
def _get_decoder_start_token_id(
self, decoder_start_token_id=None, bos_token_id=None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = (
bos_token_id
if bos_token_id is not None
else self.generation_config.bos_token_id
)
if decoder_start_token_id is not None:
return decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
def _prepare_decoder_input_ids_for_generation(
self,
batch_size,
model_kwargs,
decoder_start_token_id=None,
bos_token_id=None,
):
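        # Reuse caller-provided decoder_input_ids (or input_ids) when available;
        # otherwise start every sequence with the decoder start token.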
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
decoder_start_token_id = self._get_decoder_start_token_id(
decoder_start_token_id, bos_token_id
)
if isinstance(decoder_start_token_id, list):
if len(decoder_start_token_id) != batch_size:
raise ValueError(
                    f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
)
decoder_input_ids_start = paddle.to_tensor(
decoder_start_token_id,
dtype=paddle.int64,
)
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
else:
decoder_input_ids_start = (
paddle.ones(
(batch_size, 1),
dtype=paddle.int64,
)
* decoder_start_token_id
)
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
elif (
            self.config_decoder.model_type == "vision-encoder-decoder"
and "donut" in self.name_or_path.lower()
):
pass
        elif self.config_decoder.model_type in ["whisper"]:
pass
elif (
isinstance(decoder_start_token_id, int)
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
) or (
isinstance(decoder_start_token_id, paddle.Tensor)
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
):
decoder_input_ids = paddle.concat(
[decoder_input_ids_start, decoder_input_ids], axis=-1
)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
            decoder_attention_mask = paddle.concat(
                [
                    paddle.ones_like(decoder_attention_mask)[:, :1],
                    decoder_attention_mask,
                ],
                axis=-1,
            )
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
def prepare_inputs_for_generation_mbart(
self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
**kwargs,
):
if attention_mask is None:
attention_mask = paddle.ones(input_ids.shape)
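        # With a populated KV cache, only feed the decoder the tokens it has not
        # processed yet (normally just the most recent one).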
if past_key_values:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
}
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
decoder_inputs = self.prepare_inputs_for_generation_mbart(
input_ids, past_key_values=past_key_values
)
decoder_attention_mask = (
decoder_inputs["attention_mask"]
if "attention_mask" in decoder_inputs
else None
)
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def prepare_inputs_for_generation_export(
self,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
input_dict = {
"decoder_attention_mask": None,
"use_cache": use_cache,
}
return input_dict
def _extract_past_from_model_output(
self, outputs: ModelOutput, standardize_cache_format: bool = False
):
past_key_values = None
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past_key_values = outputs.past_buckets_states
return past_key_values
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = paddle.concat(
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1
)
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones(
                            [attention_mask.shape[0], 1], dtype=attention_mask.dtype
                        ),
                    ],
                    axis=-1,
                )
else:
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = paddle.concat(
                    [
                        decoder_attention_mask,
                        paddle.ones(
                            [decoder_attention_mask.shape[0], 1],
                            dtype=decoder_attention_mask.dtype,
                        ),
                    ],
                    axis=-1,
                )
if (
"cache_position" in model_kwargs
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
return model_kwargs
def stopping_criteria(self, input_ids):
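        # A sequence is finished once its latest token equals the EOS id; the export
        # branch uses a plain equality check (likely for paddle.jit export compatibility).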
if self.is_export:
return input_ids[:, -1] == paddle.to_tensor([self.eos_token_id])
is_done = paddle.isin(input_ids[:, -1], paddle.to_tensor([self.eos_token_id]))
return is_done
def generate_single_iter(
self,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
encoder_hidden_states = encoder_outputs[0]
if self.config_decoder.hidden_size != self.encoder_hidden_size:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
kwargs_decoder = {}
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
return Seq2SeqLMOutput(
loss=None,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@paddle.no_grad()
def generate(
self,
model_kwargs,
):
"""
Generate sequences using the UniMERNetHead for inference tasks.
Args:
model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
- encoder_outputs: Outputs from the encoder.
- use_cache: Boolean flag to indicate if caching should be used.
- output_attentions: Boolean flag for outputting attention scores.
- output_hidden_states: Boolean flag for outputting hidden states.
Returns:
A tensor containing the generated sequences.
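        Example:
            A minimal call sketch; ``head`` and ``encoder_outputs`` are assumed
            placeholders (the latter must expose ``last_hidden_state``), mirroring
            how ``forward`` builds ``model_kwargs``:
                token_ids = head.generate(
                    {
                        "encoder_outputs": encoder_outputs,
                        "use_cache": True,
                        "output_attentions": False,
                        "output_hidden_states": False,
                    }
                )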
"""
batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
generation_config = {
"decoder_start_token_id": 0,
"bos_token_id": 0,
}
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config["decoder_start_token_id"],
bos_token_id=generation_config["bos_token_id"],
)
        model_kwargs["use_cache"] = True
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = paddle.arange(cur_len)
pad_token_id = self.pad_token_id
eos_token_id = [self.eos_token_id]
eos_token = self.eos_token_id
        unfinished_sequences = paddle.ones([batch_size], dtype=paddle.int64)
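        # Greedy decoding: take the argmax token at each step, emit padding for
        # sequences that have already finished, and stop once every sequence hits EOS.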
for idx in range(self.max_seq_len):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self.generate_single_iter(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = paddle.concat([input_ids, next_tokens[:, None]], axis=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config_decoder.is_encoder_decoder,
)
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
input_ids
).cast(paddle.int64)
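            # Early exit once every sequence in the batch contains at least one EOS.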
if (
eos_token is not None
and (
paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
>= 1
).all()
):
break
return input_ids
@paddle.no_grad()
def generate_export(
self,
encoder_outputs,
model_kwargs,
):
batch_size = encoder_outputs["last_hidden_state"].shape[0]
generation_config = {
"decoder_start_token_id": 0,
"bos_token_id": 0,
}
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config["decoder_start_token_id"],
bos_token_id=generation_config["bos_token_id"],
)
input_ids = input_ids.reshape([-1, 1])
decoder_input_ids = input_ids
        model_kwargs["use_cache"] = True
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
cache_position = paddle.arange(cur_len)
pad_token_id = self.pad_token_id
eos_token_id = [self.eos_token_id]
eos_token = self.eos_token_id
unfinished_sequences = paddle.ones([batch_size], dtype=paddle.int64)
i_idx = paddle.full([], 0)
past_key_values = []
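        # Pre-allocate an empty KV cache for each of the 8 decoder layers: four
        # zero-length [batch, heads, 0, head_dim] tensors per layer (self- and
        # cross-attention keys/values), marked dynamic-shape so they can grow
        # during exported inference.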
for i in range(8):
init_arr = paddle.zeros([batch_size, 16, 0, 64])
paddle.jit.api.set_dynamic_shape(init_arr, [-1, -1, -1, -1])
cache = (init_arr, init_arr, init_arr, init_arr)
past_key_values.append(cache)
idx = 0
while i_idx < paddle.to_tensor(self.max_seq_len):
model_inputs = self.prepare_inputs_for_generation_export(
past_key_values=past_key_values, **model_kwargs
)
            decoder_attention_mask = paddle.ones(input_ids.shape)
paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
outputs = self.generate_single_iter(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = paddle.concat([input_ids, next_tokens.unsqueeze(1)], axis=-1)
past_length = past_key_values[0][0].shape[2]
decoder_input_ids = next_tokens.unsqueeze(1)
past_key_values = outputs.past_key_values
cache_position = cache_position[-1:] + 1
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
input_ids
).cast(paddle.int64)
if (
eos_token is not None
and (
paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
>= 1
).all()
):
break
i_idx += 1
return input_ids
    def forward_train(
self,
encoder_outputs,
decoder_input_ids,
decoder_attention_mask,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
"""
        Training forward pass for the UniMERNetHead.
Args:
encoder_outputs: Outputs from the encoder, used as input to the decoder.
decoder_input_ids: Input IDs for the decoder.
decoder_attention_mask: Attention mask for the decoder inputs.
past_key_values: Cached key/values for faster decoding.
decoder_inputs_embeds: Optional embeddings for the decoder inputs.
labels: Target labels for calculating loss.
use_cache: Whether to use cache during decoding.
output_attentions: Whether to return attention scores.
output_hidden_states: Whether to return hidden states.
return_dict: Whether to return a dictionary of outputs.
**kwargs: Additional keyword arguments.
Returns:
            logits: Raw, unnormalized token predictions from the decoder.
            count_pred: Symbol-count prediction from the length-aware counting decoder, or None when it is disabled.
            masked_labels: Decoder input IDs with padding positions replaced by -100 so the loss can ignore them.
"""
labels = decoder_input_ids * 1
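        # Replace padding positions with -100, the index the downstream
        # cross-entropy loss is expected to ignore.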
labels = labels.masked_fill_(labels == self.pad_token_id, -100)
input_decoder_input_ids = decoder_input_ids[:, :-1]
input_decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_hidden_states = encoder_outputs[0]
if self.config_decoder.hidden_size != self.encoder_hidden_size:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
kwargs_decoder = {}
decoder_outputs = self.decoder(
input_ids=input_decoder_input_ids,
attention_mask=input_decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
logits = decoder_outputs.logits
count_pred = decoder_outputs.counting
return logits, count_pred, labels
def forward(self, inputs, targets=None):
"""
Forward pass for the UniMERNetHead, handling both training and inference.
Args:
inputs: The input data, which can vary based on training or inference.
targets: The target labels, used only during training.
Returns:
            During inference: Returns the predicted token ID sequence (decoded to LaTeX text downstream).
During training: Returns logits, predicted counts, and masked labels.
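        Example:
            Sketch of both call paths; ``encoder_outputs``, ``tgt_seq`` and ``mask``
            are assumed placeholders for the encoder output and target IDs/mask:
                token_ids = head(encoder_outputs)  # eval: autoregressive decoding
                logits, count_pred, labels = head([encoder_outputs, tgt_seq, mask])  # train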
"""
        self.is_export = not self.training
if not self.training:
encoder_outputs = inputs
if self.is_export:
model_kwargs = {
"output_attentions": False,
"output_hidden_states": False,
"use_cache": True,
}
word_pred = self.generate_export(encoder_outputs, model_kwargs)
else:
model_kwargs = {
"output_attentions": False,
"output_hidden_states": False,
"use_cache": True,
"encoder_outputs": encoder_outputs,
}
word_pred = self.generate(model_kwargs)
return word_pred
encoder_outputs, tgt_seq, mask = inputs
        logits, count_pred, masked_labels = self.forward_train(
encoder_outputs, tgt_seq, mask
)
return logits, count_pred, masked_labels