PaddleOCR/ppocr/modeling/heads/rec_ppformulanet_head.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import re
import numpy as np
import inspect
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import CrossEntropyLoss
from paddle import Tensor
from collections import OrderedDict
from typing import Optional, Tuple, Union, List, Dict, Any
from dataclasses import dataclass, fields, is_dataclass
from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModelOutput
from ppocr.modeling.heads.rec_unimernet_head import (
MBartForCausalLM,
MBartDecoder,
MBartConfig,
ModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
zeros_,
ones_,
kaiming_normal_,
trunc_normal_,
xavier_uniform_,
CausalLMOutputWithCrossAttentions,
LogitsProcessorList,
ForcedEOSTokenLogitsProcessor,
UniMERNetHead,
)
@dataclass
class AttentionMaskConverter:
"""
A class to convert attention masks based on specific configurations.
This class is designed to handle the conversion of attention masks with options for causal masking
and sliding window attention, which are commonly used in transformer models.
Attributes:
is_causal (bool): Flag indicating whether the attention mask should enforce causal masking,
which ensures each position can only attend to previous positions.
sliding_window (int, optional): Size of the sliding window for local attention. If set,
attention is restricted to a local window of this size.
"""
is_causal: bool
    sliding_window: Optional[int]
def __init__(self, is_causal: bool, sliding_window=None):
self.is_causal = is_causal
self.sliding_window = sliding_window
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError(
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
)
@staticmethod
def _make_causal_mask(
input_ids_shape,
dtype,
past_key_values_length=0,
sliding_window=None,
is_export=False,
):
"""
        Make a causal mask used for uni-directional (decoder) self-attention.
"""
bsz, tgt_len = input_ids_shape
if is_export:
mask = paddle.full(
(tgt_len, tgt_len), paddle.finfo(dtype).min, dtype="float64"
)
mask_cond = paddle.arange(mask.shape[-1])
mask.masked_fill_(
mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
)
else:
mask = paddle.full((tgt_len, tgt_len), paddle.finfo(dtype).min)
mask_cond = paddle.arange(mask.shape[-1])
mask.masked_fill_(
mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
)
mask = mask.cast(dtype)
if past_key_values_length > 0:
mask = paddle.concat(
                [paddle.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
axis=-1,
)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window - 1
context_mask = paddle.tril(
paddle.ones_like(mask, dtype=paddle.bool), diagonal=diagonal
)
mask.masked_fill_(context_mask, paddle.finfo(dtype).min)
return mask[None, None, :, :].expand(
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
)
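    # Illustrative sketch of the mask built above, for ``tgt_len=3`` and no cached
    # keys/values (0 = may attend, ``min`` = ``paddle.finfo(dtype).min``):
    #
    #     [[0, min, min],
    #      [0,   0, min],
    #      [0,   0,   0]]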
@staticmethod
def _make_causal_mask_parallel(
input_ids_shape,
dtype,
past_key_values_length=0,
sliding_window=None,
parallel_step=1,
is_export=False,
):
"""
        Make a block-causal mask for parallel-step decoding: each block of
        `parallel_step` tokens can attend to every token up to the end of its own block.
"""
bsz, tgt_len = input_ids_shape
mask = paddle.full((tgt_len, tgt_len), paddle.finfo(dtype).min)
mask_cond = paddle.arange(mask.shape[-1])
mask_cond_parallel = paddle.arange(mask.shape[-1])
mask_parallel = paddle.arange(0, tgt_len, step=parallel_step).reshape([1, -1])
mask_parallel = paddle.repeat_interleave(mask_parallel, parallel_step, 1)[
:, :tgt_len
]
mask.masked_fill_(
mask_cond < (mask_parallel + parallel_step).reshape([mask.shape[-1], 1]), 0
)
mask = mask.cast(dtype)
if past_key_values_length > 0:
mask = paddle.concat(
[paddle.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
axis=-1,
)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window - 1
context_mask = paddle.tril(
paddle.ones_like(mask, dtype=paddle.bool), diagonal=diagonal
)
mask.masked_fill_(context_mask, paddle.finfo(dtype).min)
return mask[None, None, :, :].expand(
[bsz, 1, tgt_len, tgt_len + past_key_values_length]
)
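    # Illustrative sketch for ``tgt_len=4`` and ``parallel_step=2``: the mask is
    # block-causal, i.e. every row in a block of ``parallel_step`` tokens may attend
    # to all tokens up to the end of its own block (0 = may attend,
    # ``min`` = ``paddle.finfo(dtype).min``):
    #
    #     [[0, 0, min, min],
    #      [0, 0, min, min],
    #      [0, 0,   0,   0],
    #      [0, 0,   0,   0]]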
def to_4d(
self,
attention_mask_2d,
query_length,
dtype,
key_value_length,
use_parallel=False,
parallel_step=3,
is_export=False,
):
"""
        Converts a 2D attention mask to a 4D attention mask by expanding it to
        (bsz, head_dim=1, query_length, key_value_length) and adding a large negative
        bias to not-attended positions. If the converter is causal, a causal mask is merged in as well.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
causal_4d_mask = None
if use_parallel:
step = parallel_step
else:
step = 1
if (
input_shape[-1] > step or self.sliding_window is not None
) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
if use_parallel:
causal_4d_mask = self._make_causal_mask_parallel(
input_shape,
dtype,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
parallel_step=parallel_step,
is_export=is_export,
)
else:
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
is_export=is_export,
)
elif self.sliding_window is not None:
raise NotImplementedError(
"Sliding window is currently only implemented for causal masking"
)
expanded_attn_mask = self._expand_mask(
attention_mask_2d, dtype, tgt_len=input_shape[-1]
)
if causal_4d_mask is not None:
expanded_attn_mask = causal_4d_mask.masked_fill_(
expanded_attn_mask.cast(paddle.bool), paddle.finfo(dtype).min
)
expanded_4d_mask = expanded_attn_mask
return expanded_4d_mask
def to_4d_export(
self,
attention_mask_2d,
query_length,
dtype,
key_value_length,
use_parallel=False,
parallel_step=3,
is_export=False,
):
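        """
        Export-time counterpart of `to_4d`: it only expands the 2D padding mask to a
        4D additive mask (no explicit causal mask is built here), which keeps the
        computation friendly to dynamic shapes in the static-graph export path.
        """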
input_shape = (attention_mask_2d.shape[0], query_length)
expanded_attn_mask = self._expand_mask_export(
attention_mask_2d, dtype, tgt_len=input_shape[-1]
)
expanded_4d_mask = expanded_attn_mask
return expanded_4d_mask
def _expand_mask(self, mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.shape
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill_(
inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
)
def _expand_mask_export(self, mask, dtype, tgt_len=None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = paddle.shape(mask)
expanded_mask = (
mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
)
paddle.jit.api.set_dynamic_shape(expanded_mask, [-1, -1, -1, -1])
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill_(
inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
)
def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
    return AttentionMaskConverter(is_causal=False)._expand_mask(
        mask=mask, dtype=dtype, tgt_len=tgt_len
    )
def _prepare_4d_causal_attention_mask(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window=None,
use_parallel=False,
parallel_step=3,
is_export=False,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
attention_mask (`paddle.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
input_shape (`tuple(int)` or `list(int)` or `paddle.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`paddle.Tensor`):
The embedded inputs as a paddle Tensor.
past_key_values_length (`int`):
The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
        use_parallel (`bool`, *optional*):
            Whether to build the block-causal mask used for parallel-step decoding.
        parallel_step (`int`, *optional*):
            Number of tokens decoded per step when `use_parallel` is enabled.
        is_export (`bool`, *optional*):
            Whether the mask is being built for static-graph export.
"""
attn_mask_converter = AttentionMaskConverter(
is_causal=True, sliding_window=sliding_window
)
key_value_length = input_shape[-1] + past_key_values_length
# 4d mask is passed through the layers
if attention_mask is not None and len(attention_mask.shape) == 2:
attention_mask = attn_mask_converter.to_4d(
attention_mask,
input_shape[-1],
key_value_length=key_value_length,
dtype=inputs_embeds.dtype,
use_parallel=use_parallel,
parallel_step=parallel_step,
is_export=is_export,
)
elif attention_mask is not None and len(attention_mask.shape) == 4:
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
if tuple(attention_mask.shape) != expected_shape:
raise ValueError(
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
)
else:
# if the 4D mask has correct shape - invert it and fill with negative infinity
inverted_mask = 1.0 - attention_mask
attention_mask = inverted_mask.masked_fill_(
                inverted_mask.cast(paddle.bool), paddle.finfo(inputs_embeds.dtype).min
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0],
input_shape[-1],
key_value_length,
dtype=inputs_embeds.dtype,
)
return attention_mask
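# Usage sketch for `_prepare_4d_causal_attention_mask` (the shapes below are
# illustrative assumptions, not values required by this module):
#
#     embeds = paddle.zeros([2, 5, 1024], dtype="float32")
#     mask_2d = paddle.ones([2, 5], dtype="int64")
#     mask_4d = _prepare_4d_causal_attention_mask(
#         mask_2d, (2, 5), embeds, past_key_values_length=0
#     )
#     # mask_4d.shape == [2, 1, 5, 5]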
def _prepare_4d_causal_attention_mask_export(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window=None,
use_parallel=False,
parallel_step=3,
is_export=False,
):
"""
Prepare a 4D causal attention mask for export.
This function prepares a 4-dimensional causal attention mask, which is used to ensure that each position in the
sequence can only attend to previous positions. It is specifically designed to handle scenarios where the model
is being exported, potentially with additional options like sliding window or parallel processing.
Args:
attention_mask: The initial attention mask, typically used to avoid attending to padding tokens.
input_shape: Shape of the input tensor, usually in the form (batch_size, sequence_length).
inputs_embeds: Embeddings of the input sequence, used to derive certain dimensions if needed.
past_key_values_length: Length of past key values, used in contexts like transformer decoders with caching.
sliding_window: Optional parameter. If provided, specifies the size of a sliding window for local attention.
use_parallel: Flag indicating whether to use parallel processing for attention computation.
parallel_step: Number of steps to use in parallel processing, relevant if `use_parallel` is True.
is_export: Flag indicating whether the attention mask is being prepared for model export.
Returns:
A 4D causal attention mask suitable for use in transformer models, ensuring correct causal masking.
"""
attn_mask_converter = AttentionMaskConverter(
is_causal=True, sliding_window=sliding_window
)
key_value_length = input_shape[-1] + past_key_values_length
shape = attention_mask.shape
len_shape = len(shape)
attention_mask = attn_mask_converter.to_4d_export(
attention_mask,
input_shape[-1],
key_value_length=key_value_length,
dtype=inputs_embeds.dtype,
use_parallel=use_parallel,
parallel_step=parallel_step,
is_export=is_export,
)
return attention_mask
class CustomMBartDecoder(MBartDecoder):
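    """
    MBart decoder variant used by PPFormulaNet: it builds the (optionally
    block-parallel) causal attention masks defined above and adds an
    export-friendly code path controlled by `config.is_export`.
    """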
def __init__(self, config):
super().__init__(config)
hidden_size = config.d_model
self.is_export = config.is_export
self.config_decoder = config
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input = input_ids
input_shape = input.shape
input_ids = input_ids.reshape([-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = inputs_embeds.shape[:-1]
input = inputs_embeds[:, :, -1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
# 4d mask is passed through the layers
if self.is_export:
attention_mask = _prepare_4d_causal_attention_mask_export(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
use_parallel=self.config_decoder.use_parallel,
parallel_step=self.config_decoder.parallel_step,
is_export=self.is_export,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
use_parallel=self.config_decoder.use_parallel,
parallel_step=self.config_decoder.parallel_step,
is_export=self.is_export,
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self._use_flash_attention_2:
encoder_attention_mask = (
encoder_attention_mask if 0 in encoder_attention_mask else None
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# embed positions
positions = self.embed_positions(input, past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
if self.gradient_checkpointing and self.training:
if use_cache:
print(
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = (
() if (output_attentions and encoder_hidden_states is not None) else None
)
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip(
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
):
if attn_mask is not None:
                if attn_mask.shape[0] != len(self.layers):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {attn_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = paddle.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx]
if cross_attn_head_mask is not None
else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if self.is_export:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
else:
if use_cache:
next_decoder_cache += (
layer_outputs[3 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.is_export:
next_cache = next_decoder_cache
else:
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class CustomMBartForCausalLM(MBartForCausalLM):
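    """
    Causal LM wrapper that swaps the stock MBart decoder for `CustomMBartDecoder`,
    making parallel-step decoding and the export-specific masking available.
    """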
def __init__(self, config):
super().__init__(config)
# Modify the decoder within MBartDecoderWrapper
self.model.decoder = CustomMBartDecoder(config)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
return CausalLMOutputWithCrossAttentions(
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class PPFormulaNet_Head(UniMERNetHead):
"""
PPFormulaNet_Head
Args:
max_new_tokens (int): Maximum number of new tokens to generate. Default is 1536.
decoder_start_token_id (int): Start token ID for the decoder. Default is 0.
temperature (float): Temperature parameter for controlling randomness in sampling. Default is 0.2.
do_sample (bool): Flag to determine whether to use sampling for generation. Default is False.
top_p (float): Top-p (nucleus) sampling parameter for controlling diversity. Default is 0.95.
in_channels (int): Number of input channels for the model. Default is 1024.
decoder_layers (int): Number of layers in the decoder. Default is 8.
encoder_hidden_size (int): Size of the hidden layer in the encoder. Default is 1024.
decoder_ffn_dim (int): Dimension of the feed-forward network in the decoder. Default is 4096.
decoder_hidden_size (int): Size of the hidden layer in the decoder. Default is 1024.
is_export (bool): Flag indicating whether the model is to be exported. Default is False.
length_aware (bool): Flag to determine if the model should be aware of input sequence length. Default is True.
use_parallel (bool): Flag to enable or disable parallel processing. Default is False.
parallel_step (int): Number of steps to use in parallel processing. Default is 3.
"""
def __init__(
self,
max_new_tokens=1536,
decoder_start_token_id=0,
temperature=0.2,
do_sample=False,
top_p=0.95,
in_channels=1024,
decoder_layers=8,
encoder_hidden_size=1024,
decoder_ffn_dim=4096,
decoder_hidden_size=1024,
is_export=False,
length_aware=True,
use_parallel=False,
parallel_step=3,
):
super().__init__()
mbart_config_dict = {
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": True,
"add_final_layer_norm": True,
"attention_dropout": 0.0,
"bos_token_id": 0,
"classifier_dropout": 0.0,
"d_model": decoder_hidden_size,
"decoder_attention_heads": 16,
"decoder_ffn_dim": decoder_ffn_dim,
"decoder_layerdrop": 0.0,
"decoder_layers": decoder_layers,
"dropout": 0.1,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"init_std": 0.02,
"is_decoder": True,
"is_encoder_decoder": False,
"output_hidden_states": False,
"max_position_embeddings": (
max_new_tokens + parallel_step if use_parallel else max_new_tokens
),
"model_type": "mbart",
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": True,
"tie_word_embeddings": False,
"transformers_version": "4.40.0",
"use_cache": True,
"use_return_dict": True,
"vocab_size": 50000,
"_attn_implementation": "eager",
"hidden_size": decoder_hidden_size,
"use_parallel": use_parallel,
"parallel_step": int(parallel_step),
"is_export": is_export,
}
self.decoder_start_token_id = decoder_start_token_id
self.temperature = temperature
self.do_sample = do_sample
self.top_p = top_p
self.is_export = is_export
self.max_seq_len = max_new_tokens
self.config_decoder = MBartConfig(**mbart_config_dict)
self.encoder_hidden_size = encoder_hidden_size
self.decoder = CustomMBartForCausalLM(self.config_decoder)
if self.config_decoder.hidden_size != self.encoder_hidden_size:
self.enc_to_dec_proj = nn.Linear(
self.encoder_hidden_size, self.config_decoder.hidden_size
)
generation_config = {
"max_length": 1537,
"forced_eos_token_id": 2,
}
self.eos_token_id = generation_config["forced_eos_token_id"]
self.pad_token_id = self.config_decoder.pad_token_id
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(
ForcedEOSTokenLogitsProcessor(
generation_config["max_length"],
generation_config["forced_eos_token_id"],
)
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
decoder_inputs = self.prepare_inputs_for_generation_mbart(
input_ids, past_key_values=past_key_values
)
decoder_attention_mask = (
decoder_inputs["attention_mask"]
if "attention_mask" in decoder_inputs
else None
)
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def _extract_past_from_model_output(
self, outputs: ModelOutput, standardize_cache_format: bool = False
):
past_key_values = None
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past_key_values = outputs.past_buckets_states
return past_key_values
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = paddle.concat(
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1
)
if not is_encoder_decoder:
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = paddle.concat(
[
attention_mask,
                        paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype),
],
axis=-1,
)
else:
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = paddle.concat(
[
decoder_attention_mask,
                        paddle.ones(
                            [decoder_attention_mask.shape[0], 1],
                            dtype=decoder_attention_mask.dtype,
                        ),
],
                    axis=-1,
)
if (
"cache_position" in model_kwargs
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
return model_kwargs
def stopping_criteria(self, input_ids):
if self.is_export:
return input_ids[:, -1] == paddle.to_tensor([self.eos_token_id])
is_done = paddle.isin(input_ids[:, -1], paddle.to_tensor([self.eos_token_id]))
return is_done
def stopping_criteria_parallel(self, input_ids):
parallel_step = self.config_decoder.parallel_step
if self.is_export:
is_done_list = []
for i in range(parallel_step, 0, -1):
cur_is_done = input_ids[:, -i] == paddle.to_tensor([self.eos_token_id])
is_done_list.append(cur_is_done)
is_done_list = paddle.to_tensor(is_done_list).transpose([1, 0])
return is_done_list
else:
is_done = paddle.isin(
input_ids[:, -parallel_step:],
paddle.to_tensor([self.eos_token_id]).reshape([1, 1]),
)
return paddle.to_tensor(is_done)
def generate_single_iter(
self,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
encoder_hidden_states = encoder_outputs[0]
if self.config_decoder.hidden_size != self.encoder_hidden_size:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
kwargs_decoder = {}
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
return Seq2SeqLMOutput(
loss=None,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def _prepare_decoder_input_ids_for_generation(
self,
batch_size,
model_kwargs,
decoder_start_token_id=None,
bos_token_id=None,
):
        # 1. Check whether the user has defined `decoder_input_ids` manually. For convenience, we also allow the
        # user to pass it under `input_ids` if the encoder does not use it as the main input.
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id = self._get_decoder_start_token_id(
decoder_start_token_id, bos_token_id
)
if isinstance(decoder_start_token_id, list):
if len(decoder_start_token_id) != batch_size:
raise ValueError(
f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}"
)
decoder_input_ids_start = paddle.to_tensor(
decoder_start_token_id,
dtype=paddle.int64,
)
            decoder_input_ids_start = decoder_input_ids_start.reshape([-1, 1])
else:
use_parallel = self.config_decoder.use_parallel
parallel_step = self.config_decoder.parallel_step
if use_parallel:
decoder_input_ids_start = (
paddle.ones(
(batch_size, parallel_step),
dtype=paddle.int64,
)
* decoder_start_token_id
)
else:
decoder_input_ids_start = (
paddle.ones(
(batch_size, 1),
dtype=paddle.int64,
)
* decoder_start_token_id
)
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif (
self.config.model_type == "vision-encoder-decoder"
and "donut" in self.name_or_path.lower()
):
pass
elif self.config.model_type in ["whisper"]:
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (
isinstance(decoder_start_token_id, int)
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
) or (
isinstance(decoder_start_token_id, paddle.Tensor)
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
):
decoder_input_ids = paddle.concat(
[decoder_input_ids_start, decoder_input_ids], axis=-1
)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                decoder_attention_mask = paddle.concat(
(
paddle.ones_like(decoder_attention_mask)[:, :1],
decoder_attention_mask,
),
                    axis=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
@paddle.no_grad()
def generate_export(
self,
encoder_outputs,
model_kwargs,
):
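        """
        Static-graph friendly variant of `generate` used for export: tensor shapes
        are marked dynamic via `paddle.jit.api.set_dynamic_shape` and the per-layer
        key/value caches are pre-allocated as zero-length tensors.
        """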
use_parallel = self.config_decoder.use_parallel
parallel_step = self.config_decoder.parallel_step
batch_size = encoder_outputs["last_hidden_state"].shape[0]
generation_config = {
"decoder_start_token_id": 0,
"bos_token_id": 0,
}
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config["decoder_start_token_id"],
bos_token_id=generation_config["bos_token_id"],
)
if not use_parallel:
input_ids = input_ids.reshape([-1, 1])
decoder_input_ids = input_ids
model_kwargs["key use_cache"] = True
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
cache_position = paddle.arange(cur_len)
pad_token_id = self.pad_token_id
eos_token_id = [self.eos_token_id]
eos_token = self.eos_token_id
if use_parallel:
unfinished_sequences = paddle.ones(
[batch_size, parallel_step], dtype=paddle.int64
)
            parallel_length = math.ceil(self.max_seq_len / parallel_step)
else:
unfinished_sequences = paddle.ones(batch_size, dtype=paddle.int64)
parallel_length = self.max_seq_len
i_idx = paddle.full([], 0)
past_key_values = []
decoder_attention_heads = self.config_decoder.decoder_attention_heads
decoder_attention_heads_dim = int(
self.config_decoder.d_model / decoder_attention_heads
)
for i in range(self.config_decoder.decoder_layers):
init_arr = paddle.zeros(
[batch_size, decoder_attention_heads, 0, decoder_attention_heads_dim]
)
paddle.jit.api.set_dynamic_shape(init_arr, [-1, -1, -1, -1])
cache = (init_arr, init_arr, init_arr, init_arr)
past_key_values.append(cache)
while i_idx < paddle.to_tensor(parallel_length):
model_inputs = self.prepare_inputs_for_generation_export(
past_key_values=past_key_values, **model_kwargs
)
decoder_attention_mask = paddle.ones(paddle.shape(input_ids))
paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
outputs = self.generate_single_iter(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if use_parallel:
next_token_logits = outputs.logits[:, -parallel_step:, :]
else:
next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
if eos_token_id is not None:
# False
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
if use_parallel:
input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
decoder_input_ids = next_tokens
else:
input_ids = paddle.concat(
[input_ids, next_tokens.unsqueeze(1)], axis=-1
)
decoder_input_ids = next_tokens.unsqueeze(1)
past_length = past_key_values[0][0].shape[2]
past_key_values = outputs.past_key_values
cache_position = cache_position[-1:] + 1
if use_parallel:
unfinished_sequences = (
unfinished_sequences
& ~self.stopping_criteria_parallel(input_ids).cast(paddle.int64)
)
else:
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
input_ids
).cast(paddle.int64)
if (
eos_token is not None
and (
paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
>= 1
).all()
):
break
i_idx += 1
# break
return input_ids
@paddle.no_grad()
def generate(
self,
encoder_outputs,
model_kwargs,
):
"""
Generate sequences from the model without computing gradients.
This method is used to generate sequences from the model based on the given encoder outputs.
It does not compute gradients, making it suitable for inference.
Args:
encoder_outputs: The outputs from the encoder, typically including hidden states necessary for generation.
model_kwargs: Additional keyword arguments that may include parameters such as maximum length,
temperature, top-k/top-p sampling parameters, and other generation-specific settings.
Returns:
Generated sequences based on the encoder outputs and specified generation parameters.
"""
use_parallel = self.config_decoder.use_parallel
parallel_step = self.config_decoder.parallel_step
batch_size = encoder_outputs["last_hidden_state"].shape[0]
generation_config = {
"decoder_start_token_id": 0,
"bos_token_id": 0,
}
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config["decoder_start_token_id"],
bos_token_id=generation_config["bos_token_id"],
)
decoder_input_ids = input_ids
model_kwargs["key use_cache"] = True
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
model_kwargs["cache_position"] = paddle.arange(cur_len)
pad_token_id = self.pad_token_id
eos_token_id = [self.eos_token_id]
eos_token = self.eos_token_id
if use_parallel:
unfinished_sequences = paddle.ones(
[batch_size, parallel_step], dtype=paddle.int64
)
            parallel_length = math.ceil(self.max_seq_len / parallel_step)
else:
unfinished_sequences = paddle.ones(batch_size, dtype=paddle.int64)
parallel_length = self.max_seq_len
past_key_values = []
for idx in range(parallel_length):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self.generate_single_iter(
**model_inputs,
encoder_outputs=encoder_outputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
if use_parallel:
next_token_logits = outputs.logits[:, :, :]
else:
next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
if eos_token_id is not None:
# False
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
if use_parallel:
input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
else:
input_ids = paddle.concat([input_ids, next_tokens[:, None]], axis=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config_decoder.is_encoder_decoder,
)
if use_parallel:
unfinished_sequences = (
unfinished_sequences
& ~self.stopping_criteria_parallel(input_ids).cast(paddle.int64)
)
else:
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
input_ids
).cast(paddle.int64)
if (
eos_token is not None
and (
paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
>= 1
).all()
):
break
return input_ids
    def forward_train(
self,
encoder_outputs,
decoder_input_ids,
decoder_attention_mask,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
"""
Forward pass for training the model.
Args:
encoder_outputs: The outputs from the encoder, typically including hidden states.
decoder_input_ids: Input IDs for the decoder.
decoder_attention_mask: Attention mask for the decoder inputs to avoid attending to padding tokens.
past_key_values: Previously computed key and value states for the decoder, used for fast generation.
decoder_inputs_embeds: Optional embeddings for decoder inputs, used instead of decoder_input_ids if provided.
labels: Labels for computing the training loss.
use_cache: Whether to use a cache of past key values for faster generation.
output_attentions: Whether to output attention weights.
output_hidden_states: Whether to output hidden states of all layers.
return_dict: Whether to return the output as a dictionary.
**kwargs: Additional keyword arguments.
Returns:
            A tuple `(logits, labels)`, where padded positions in `labels` are set to -100.
"""
if self.config_decoder.use_parallel:
batch = decoder_input_ids.shape[0]
add_sos_token = self.config_decoder.parallel_step - 1
start_token = paddle.zeros([batch, add_sos_token]).cast(paddle.int64)
start_mask = paddle.ones([batch, add_sos_token]).cast(paddle.int64)
decoder_input_ids = paddle.concat([start_token, decoder_input_ids], axis=1)
decoder_attention_mask = paddle.concat(
[start_mask, decoder_attention_mask], axis=1
)
labels = decoder_input_ids * 1
labels = labels.masked_fill_(labels == self.pad_token_id, -100)
if self.config_decoder.use_parallel:
input_decoder_input_ids = decoder_input_ids[
:, : -self.config_decoder.parallel_step
]
input_decoder_attention_mask = decoder_attention_mask[
:, : -self.config_decoder.parallel_step
]
else:
input_decoder_input_ids = decoder_input_ids[:, :-1]
input_decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_hidden_states = encoder_outputs[0]
kwargs_decoder = {}
if self.config_decoder.hidden_size != self.encoder_hidden_size:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
decoder_outputs = self.decoder(
input_ids=input_decoder_input_ids,
attention_mask=input_decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
inputs_embeds=None,
output_attentions=False,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
logits = decoder_outputs.logits
return logits, labels
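    # Shape sketch for forward_train (illustrative, non-parallel case): for a batch
    # of B target sequences padded to length T,
    #     logits: [B, T - 1, vocab_size]  (computed from decoder_input_ids[:, :-1])
    #     labels: [B, T]                  (copy of decoder_input_ids with pad -> -100)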
    # Unified forward: generation (export or eager) in eval mode, the training branch otherwise.
def forward(self, inputs, targets=None):
if not self.training:
encoder_outputs = inputs
model_kwargs = {
"output_attentions": False,
"output_hidden_states": False,
"use_cache": True,
}
if self.is_export:
word_pred = self.generate_export(encoder_outputs, model_kwargs)
else:
word_pred = self.generate(encoder_outputs, model_kwargs)
return word_pred
encoder_outputs, tgt_seq, mask = inputs
        logits, masked_labels = self.forward_train(encoder_outputs, tgt_seq, mask)
return logits, masked_labels
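# End-to-end usage sketch (illustrative only; the sequence length 196 and hidden
# sizes below are assumptions, not values prescribed by this module):
#
#     head = PPFormulaNet_Head(in_channels=1024, encoder_hidden_size=1024)
#     head.eval()  # switches forward() to the generation branch
#     feats = DonutSwinModelOutput(last_hidden_state=paddle.randn([1, 196, 1024]))
#     token_ids = head(feats)  # autoregressive decoding, shape [1, <= max_new_tokens]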