mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* [Feat] Migrate blip caption to mmpretrain. (#50) * Migrate blip caption to mmpretrain * minor fix * support train * [Feature] Support OFA caption task. (#51) * [Feature] Support OFA caption task. * Remove duplicated files. * [Feature] Support OFA vqa task. (#58) * [Feature] Support OFA vqa task. * Fix lint. * [Feat] Add BLIP retrieval to mmpretrain. (#55) * init * minor fix for train * fix according to comments * refactor * Update Blip retrieval. (#62) * [Feature] Support OFA visual grounding task. (#59) * [Feature] Support OFA visual grounding task. * minor add TODO --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Add flamingos coco caption and vqa. (#60) * first init * init flamingo coco * add vqa * minor fix * remove unnecessary modules * Update config * Use `ApplyToList`. --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 coco retrieval (#53) * [Feature]: Add blip2 retriever * [Feature]: Add blip2 all modules * [Feature]: Refine model * [Feature]: x1 * [Feature]: Runnable coco ret * [Feature]: Runnable version * [Feature]: Fix lint * [Fix]: Fix lint * [Feature]: Use 364 img size * [Feature]: Refactor blip2 * [Fix]: Fix lint * refactor files * minor fix * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Remove * fix blip caption inputs (#68) * [Feat] Add BLIP NLVR support. 
(#67) * first init * init flamingo coco * add vqa * add nlvr * refactor nlvr * minor fix * minor fix * Update dataset --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 Caption (#70) * [Feature]: Add language model * [Feature]: blip2 caption forward * [Feature]: Reproduce the results * [Feature]: Refactor caption * refine config --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Migrate BLIP VQA to mmpretrain (#69) * reformat * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * refactor code --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Update RefCOCO dataset * [Fix] fix lint * [Feature] Implement inference APIs for multi-modal tasks. (#65) * [Feature] Implement inference APIs for multi-modal tasks. * [Project] Add gradio demo. * [Improve] Update requirements * Update flamingo * Update blip * Add NLVR inferencer * Update flamingo * Update hugging face model register * Update ofa vqa * Update BLIP-vqa (#71) * Update blip-vqa docstring (#72) * Refine flamingo docstring (#73) * [Feature]: BLIP2 VQA (#61) * [Feature]: VQA forward * [Feature]: Reproduce accuracy * [Fix]: Fix lint * [Fix]: Add blank line * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feature]: BLIP2 docstring (#74) * [Feature]: Add caption docstring * [Feature]: Add docstring to blip2 vqa * [Feature]: Add docstring to retrieval * Update BLIP-2 metafile and README (#75) * [Feature]: Add readme and docstring * Update blip2 results --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature] BLIP Visual Grounding on MMPretrain Branch (#66) * blip grounding merge with mmpretrain * remove commit * blip grounding test and inference api * refcoco dataset * refcoco dataset refine config * rebasing * gitignore * rebasing * minor edit * minor edit * Update blip-vqa docstring (#72) * rebasing * Revert "minor edit" This reverts 
commit 639cec757c215e654625ed0979319e60f0be9044. * blip grounding final * precommit * refine config * refine config * Update blip visual grounding --------- Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: mzr1996 <mzr1996@163.com> * Update visual grounding metric * Update OFA docstring, README and metafiles. (#76) * [Docs] Update installation docs and gradio demo docs. (#77) * Update OFA name * Update Visual Grounding Visualizer * Integrate accelerate support * Fix imports. * Fix timm backbone * Update imports * Update README * Update circle ci * Update flamingo config * Add gradio demo README * [Feature]: Add scienceqa (#1571) * [Feature]: Add scienceqa * [Feature]: Change param name * Update docs * Update video --------- Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com> Co-authored-by: yingfhu <yingfhu@gmail.com> Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: Rongjie Li <limo97@163.com>
1321 lines
51 KiB
Python
1321 lines
51 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
# flake8: noqa
|
|
|
|
import math
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor, device
|
|
|
|
try:
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
CausalLMOutputWithCrossAttentions)
|
|
from transformers.modeling_utils import (PreTrainedModel,
|
|
apply_chunking_to_forward,
|
|
find_pruneable_heads_and_indices,
|
|
prune_linear_layer)
|
|
from transformers.models.bert.configuration_bert import BertConfig
|
|
except:
|
|
ACT2FN = None
|
|
BaseModelOutputWithPastAndCrossAttentions = None
|
|
BaseModelOutputWithPoolingAndCrossAttentions = None
|
|
CausalLMOutputWithCrossAttentions = None
|
|
PreTrainedModel = None
|
|
apply_chunking_to_forward = None
|
|
find_pruneable_heads_and_indices = None
|
|
prune_linear_layer = None
|
|
BertConfig = None
|
|
|
|
from mmpretrain.registry import MODELS
|
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and (optionally) token
    type embeddings.

    Args:
        config: BERT-style config. Reads ``vocab_size``, ``hidden_size``,
            ``pad_token_id``, ``max_position_embeddings``,
            ``layer_norm_eps`` and ``hidden_dropout_prob``; optionally
            ``add_type_embeddings`` (treated as ``False`` when absent),
            ``type_vocab_size`` (only when ``add_type_embeddings`` is true)
            and ``position_embedding_type`` (defaults to ``'absolute'``).
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)

        # Token-type (segment) embeddings are optional. A config that does
        # not define the flag behaves as if it were False.
        if getattr(config, 'add_type_embeddings', False):
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                      config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and
        # exported when serialized
        self.register_buffer(
            'position_ids',
            torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')

        self.config = config

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        """Embed the inputs and add position / token-type embeddings.

        Args:
            input_ids (Tensor, optional): Token ids. Its second dimension is
                taken as the sequence length. Either this or
                ``inputs_embeds`` must be given.
            token_type_ids (Tensor, optional): Segment ids. Requires the
                module to have been built with ``add_type_embeddings``.
            position_ids (Tensor, optional): Explicit position ids. By
                default the slice ``[past_key_values_length,
                past_key_values_length + seq_len)`` of the stored
                ``position_ids`` buffer is used.
            inputs_embeds (Tensor, optional): Precomputed word embeddings,
                used instead of looking up ``input_ids``. Never modified in
                place.
            past_key_values_length (int): Position offset of the first
                token, e.g. when decoding with cached keys/values.

        Returns:
            Tensor: Embeddings after LayerNorm and dropout.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                given, or if ``token_type_ids`` is given but the module has
                no token-type embeddings.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                'Either `input_ids` or `inputs_embeds` must be provided.')

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length:
                                             seq_length +
                                             past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if token_type_ids is not None:
            if not hasattr(self, 'token_type_embeddings'):
                raise ValueError(
                    '`token_type_ids` were given but this embedding layer was '
                    'built without token type embeddings '
                    '(config.add_type_embeddings is not set).')
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = inputs_embeds + token_type_embeddings
        else:
            embeddings = inputs_embeds

        if self.position_embedding_type == 'absolute':
            position_embeddings = self.position_embeddings(position_ids)
            # Out-of-place add. ``embeddings`` may alias the caller's
            # ``inputs_embeds``, which must not be mutated (and an in-place
            # add on a grad-requiring leaf tensor raises).
            embeddings = embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence into one vector via its first token.

    The first position's hidden state is passed through a square linear
    layer followed by ``tanh``.

    Args:
        config: BERT-style config; only ``hidden_size`` is read.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
|
|
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """Abstract base class that plugs BERT-style weight initialization into
    the ``PreTrainedModel`` download / loading interface."""

    config_class = BertConfig
    base_model_prefix = 'bert'
    # The ``position_ids`` buffer may be absent from checkpoints.
    _keys_to_ignore_on_load_missing = [r'position_ids']

    def _init_weights(self, module):
        """Initialize the weights of one submodule.

        ``nn.Linear`` / ``nn.Embedding`` weights are drawn from
        ``N(0, config.initializer_range)`` and ``nn.Linear`` biases are
        zeroed. ``nn.LayerNorm`` is reset to weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses
            # truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        config: BERT-style config. Reads ``hidden_size``,
            ``num_attention_heads`` and ``attention_probs_dropout_prob``;
            ``encoder_width`` when ``is_cross_attention`` is truthy;
            optionally ``position_embedding_type`` (default ``'absolute'``)
            and, for the relative position types, ``max_position_embeddings``.
        is_cross_attention (bool): If truthy, the key/value projections take
            ``config.encoder_width``-dim inputs (encoder features) instead of
            ``config.hidden_size``-dim hidden states.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        # Configs that define ``embedding_size`` skip the divisibility check.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, 'embedding_size'):
            raise ValueError(
                'The hidden size (%d) is not a multiple of the number of attention '
                'heads (%d)' %
                (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               'position_embedding_type',
                                               'absolute')
        if self.position_embedding_type in ('relative_key',
                                            'relative_key_query'):
            self.max_position_embeddings = config.max_position_embeddings
            # One learned vector per signed query-key distance in
            # [-(max_position_embeddings - 1), max_position_embeddings - 1].
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        # When True and running cross-attention, forward() stores the
        # attention map and registers a hook that stores its gradient.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move heads to dim 1.

        ``(..., seq, all_head_size)`` becomes
        ``(batch, heads, seq, head_size)``; the permute assumes a 3-D input.
        """
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Compute attention for ``hidden_states``.

        Args:
            hidden_states (Tensor): Query-side inputs, projected by
                ``self.query``.
            attention_mask (Tensor, optional): Additive mask added to the
                scores (ignored in favour of ``encoder_attention_mask`` for
                cross-attention).
            head_mask (Tensor, optional): Multiplied into the attention
                probabilities.
            encoder_hidden_states (Tensor, optional): If given, keys/values
                are computed from it (cross-attention).
            encoder_attention_mask (Tensor, optional): Additive mask used for
                cross-attention.
            past_key_value (tuple, optional): Cached ``(key, value)`` from
                previous self-attention steps; new keys/values are appended
                along the sequence dim.
            output_attentions (bool): Also return the attention
                probabilities.

        Returns:
            tuple: ``(context_layer, [attention_probs,] past_key_value)``
            where ``past_key_value`` is the ``(key, value)`` pair including
            any cached prefix.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Incremental self-attention: append new keys/values to cache.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Raw scores: (batch, heads, query_len, key_len).
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if self.position_embedding_type in ('relative_key',
                                            'relative_key_query'):
            query_length = query_layer.size(2)
            key_length = key_layer.size(2)
            # With a self-attention cache, the queries are the LAST
            # ``query_length`` positions of the key sequence. For
            # cross-attention and uncached self-attention they start at 0
            # (uncached: key_length == query_length, so the offset is 0).
            offset = 0 if is_cross_attention else key_length - query_length
            position_ids_l = torch.arange(
                offset,
                offset + query_length,
                dtype=torch.long,
                device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(
                key_length, dtype=torch.long,
                device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == 'relative_key':
                relative_position_scores = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            else:  # 'relative_key_query'
                relative_position_scores_query = torch.einsum(
                    'bhld,lrd->bhlr', query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    'bhrd,lrd->bhlr', key_layer, positional_embedding)
                attention_scores = (
                    attention_scores + relative_position_scores_query +
                    relative_position_scores_key)

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive (precomputed for all layers by the model).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = attention_scores.softmax(dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads back: (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = ((context_layer, attention_probs) if output_attentions else
                   (context_layer, ))

        outputs = outputs + (past_key_value, )
        return outputs
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Project attention output, apply dropout, add residual, LayerNorm.

    Args:
        config: BERT-style config. Reads ``hidden_size``, ``layer_norm_eps``
            and ``hidden_dropout_prob``; ``hidden_act`` only when ``merge``.
        twin (bool): Build two projections (``dense0`` / ``dense1``) for two
            attention branches instead of a single ``dense``.
        merge (bool): For twin inputs, fuse the two projections with a
            linear layer over their concatenation instead of averaging them.
    """

    def __init__(self, config, twin=False, merge=False):
        super().__init__()
        width = config.hidden_size
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if not twin:
            self.dense = nn.Linear(width, width)
        else:
            self.dense0 = nn.Linear(width, width)
            self.dense1 = nn.Linear(width, width)
        self.merge = bool(merge)
        if self.merge:
            # NOTE(review): ``act`` is created but never applied in forward().
            self.act = ACT2FN[config.hidden_act]
            self.merge_layer = nn.Linear(width * 2, width)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(proj(hidden_states)) + input_tensor)``.

        ``hidden_states`` is either a tensor or, for the twin variant, a
        list of two tensors.
        """
        if type(hidden_states) is list:
            first = self.dense0(hidden_states[0])
            second = self.dense1(hidden_states[1])
            if self.merge:
                projected = self.merge_layer(
                    torch.cat([first, second], dim=-1))
            else:
                projected = (first + second) / 2
        else:
            projected = self.dense(hidden_states)
        return self.LayerNorm(self.dropout(projected) + input_tensor)
|
|
|
|
|
|
class BertAttention(nn.Module):
    """Attention sublayer: (self|cross)-attention plus output projection.

    For cross-attention with a truthy ``config.nlvr``, two independent
    attention branches (``self0`` / ``self1``) are built. They attend to two
    encoder inputs, and a twin :class:`BertSelfOutput` combines them
    (merged by a linear layer from ``layer_num >= 6``, averaged otherwise).

    Args:
        config: BERT-style config.
        is_cross_attention (bool): Build a cross-attention sublayer.
        layer_num (int): Index of the owning layer (used for the NLVR merge
            decision).
    """

    def __init__(self, config, is_cross_attention=False, layer_num=-1):
        super().__init__()
        is_nlvr = is_cross_attention and getattr(config, 'nlvr', False)
        if not is_nlvr:
            self.self = BertSelfAttention(config, is_cross_attention)
        else:
            self.self0 = BertSelfAttention(config, is_nlvr)
            self.self1 = BertSelfAttention(config, is_nlvr)
        self.output = BertSelfOutput(
            config,
            twin=is_nlvr,
            merge=(is_nlvr and layer_num >= 6),
        )
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the single-branch attention."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the q/k/v output features and the projection input features.
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the pruned layers.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = (
            attn.attention_head_size * attn.num_attention_heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and the output projection.

        Returns:
            tuple: ``(attention_output, *extra)`` where ``extra`` is what the
            attention module returned after its context tensor (optional
            attention probabilities, then the key/value pair). In the
            two-branch case only branch 0's extras are returned.
        """
        if type(encoder_hidden_states) is not list:
            self_outputs = self.self(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
            attention_output = self.output(self_outputs[0], hidden_states)
            return (attention_output, ) + self_outputs[1:]

        # Two encoder inputs: one attention branch per input.
        branch0 = self.self0(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states[0],
            encoder_attention_mask[0],
            past_key_value,
            output_attentions,
        )
        branch1 = self.self1(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states[1],
            encoder_attention_mask[1],
            past_key_value,
            output_attentions,
        )
        attention_output = self.output([branch0[0], branch1[0]],
                                       hidden_states)
        return (attention_output, ) + branch0[1:]
|
|
|
|
|
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed
    by the configured activation.

    Args:
        config: BERT-style config. Reads ``hidden_size``,
            ``intermediate_size`` and ``hidden_act`` (an ``ACT2FN`` key or a
            callable).
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act,
                                                             str) else act

    def forward(self, hidden_states):
        """Return ``act(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
|
class BertOutput(nn.Module):
    """Feed-forward contraction back to ``hidden_size`` with dropout,
    residual connection and LayerNorm.

    Args:
        config: BERT-style config. Reads ``intermediate_size``,
            ``hidden_size``, ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class BertLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, and
    a (chunked) feed-forward block.

    Args:
        config: BERT-style config. Reads ``chunk_size_feed_forward``,
            ``add_cross_attention`` and, if present, ``fusion_layer``
            (otherwise ``num_hidden_layers``).
        layer_num (int): Index of this layer in the encoder.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        # ALBEF/ALPRO configs define ``fusion_layer``; BLIP configs don't,
        # and then fall back to ``num_hidden_layers``.
        try:
            fusion_layer = config.fusion_layer
            # NOTE(review): this ALBEF-style gate is evaluated but its value
            # is unused; cross-attention below is built whenever
            # ``config.add_cross_attention`` is set, for every layer.
            _gated_cross_attention = (
                fusion_layer <= layer_num and config.add_cross_attention)
        except AttributeError:
            fusion_layer = config.num_hidden_layers
        self.fusion_layer = fusion_layer

        if config.add_cross_attention:
            self.crossattention = BertAttention(
                config,
                is_cross_attention=config.add_cross_attention,
                layer_num=layer_num,
            )
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        """Run the layer.

        Cross-attention is applied only when ``mode`` is ``'multimodal'`` or
        ``'fusion'`` and the layer has a ``crossattention`` module.

        Returns:
            tuple: ``(layer_output, [self_attn_probs,] [cross_attn_probs,]
            present_key_value)``. The probability tensors are present only
            when ``output_attentions`` is True (cross only when
            cross-attention ran).
        """
        # Only the first two cached entries belong to self-attention.
        cached_self_kv = (None if past_key_value is None else
                          past_key_value[:2])
        self_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_outputs[0]
        extras = self_outputs[1:-1]
        present_key_value = self_outputs[-1]

        # TODO line 482 in albef/models/xbert.py
        run_cross = (
            mode in ('multimodal', 'fusion')
            and hasattr(self, 'crossattention'))
        if run_cross:
            assert (
                encoder_hidden_states is not None
            ), 'encoder_hidden_states must be given for cross-attention layers'
            cross_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_outputs[0]
            # Append cross-attention weights if they were requested.
            extras = extras + cross_outputs[1:-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output,
        )
        return (layer_output, ) + extras + (present_key_value, )

    def feed_forward_chunk(self, attention_output):
        """Apply the feed-forward block (with residual) to one chunk."""
        return self.output(
            self.intermediate(attention_output), attention_output)
|
|
|
|
|
|
class BertEncoder(nn.Module):
    """Stack of :class:`BertLayer` modules with ALBEF/BLIP mode selection.

    Args:
        config: BERT-style config. Reads ``num_hidden_layers``,
            ``add_cross_attention`` and, if present, ``fusion_layer`` (the
            split between text-only and fusion layers; defaults to
            ``num_hidden_layers``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def _layer_range(self, mode):
        """Return ``(start_layer, output_layer)`` for ``mode``.

        Raises:
            ValueError: If ``mode`` is not ``'text'``, ``'fusion'`` or
                ``'multimodal'``.
        """
        try:
            # ALBEF
            fusion_layer = self.config.fusion_layer
        except AttributeError:
            # BLIP
            fusion_layer = self.config.num_hidden_layers

        if mode == 'text':
            return 0, fusion_layer
        if mode == 'fusion':
            return fusion_layer, self.config.num_hidden_layers
        if mode == 'multimodal':
            return 0, self.config.num_hidden_layers
        raise ValueError(
            "mode must be one of 'text', 'fusion' or 'multimodal', "
            f'got {mode!r}')

    @staticmethod
    def _make_checkpoint_forward(module, past_key_value, output_attentions,
                                 mode):
        """Bind the non-tensor layer arguments for gradient checkpointing.

        The values are bound at call time. Looking them up later (when the
        checkpoint recomputes the forward pass during backward) would see
        the loop's final values instead.
        """

        def custom_forward(*inputs):
            return module(
                *inputs, past_key_value, output_attentions, mode=mode)

        return custom_forward

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        """Run the layers selected by ``mode``.

        ``'text'`` runs layers ``[0, fusion_layer)``, ``'fusion'`` runs
        ``[fusion_layer, num_hidden_layers)`` and ``'multimodal'`` runs all
        layers.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when
            ``return_dict`` is True, else a tuple of the non-``None`` values
            among (last hidden state, decoder cache, all hidden states, self
            attentions, cross attentions).

        Raises:
            ValueError: For an unknown ``mode``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (() if output_attentions
                                and self.config.add_cross_attention else None)

        next_decoder_cache = () if use_cache else None

        start_layer, output_layer = self._layer_range(mode)

        for i in range(start_layer, output_layer):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[
                i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                from torch.utils.checkpoint import checkpoint

                if use_cache:
                    import warnings
                    warnings.warn(
                        '`use_cache=True` is incompatible with gradient '
                        'checkpointing. Setting `use_cache=False`...')
                    use_cache = False

                # ``checkpoint`` only forwards positional inputs, so the
                # remaining arguments (including ``mode``) are bound into
                # the closure rather than passed as keyword arguments.
                layer_outputs = checkpoint(
                    self._make_checkpoint_forward(layer_module,
                                                  past_key_value,
                                                  output_attentions, mode),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1], )
            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    layer_outputs[1], )
                # A layer that ran cross-attention returns
                # (output, self_attn, cross_attn, present_key_value).
                if all_cross_attentions is not None and len(
                        layer_outputs) > 3:
                    all_cross_attentions = all_cross_attentions + (
                        layer_outputs[2], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM
    vocabulary projection.

    Args:
        config: BERT-style config. Reads ``hidden_size``, ``hidden_act``
            (an ``ACT2FN`` key or a callable) and ``layer_norm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Language-modeling head: transform the hidden states, then project
    them onto the vocabulary.

    Args:
        config: BERT-style config. Reads ``hidden_size`` and ``vocab_size``
            here (plus what :class:`BertPredictionHeadTransform` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # Vocabulary projection without its own bias. A separately
        # registered ``bias`` parameter is attached to it below so the two
        # stay linked (e.g. when token embeddings are resized via
        # ``resize_token_embeddings``).
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return unnormalized vocabulary scores per position."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Masked-LM head that wraps a single :class:`BertLMPredictionHead`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Map ``sequence_output`` to vocabulary scores."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
@MODELS.register_module()
|
|
class BertModel(BertPreTrainedModel):
|
|
"""The model can behave as an encoder (with only self-attention) as well as
|
|
a decoder, in which case a layer of cross-attention is added between the
|
|
self-attention layers, following the architecture described in `Attention
|
|
is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani,
|
|
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez,
    Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder, the model needs the :obj:`is_decoder` argument
    (passed to ``forward``) and :obj:`add_cross_attention` (in the config)
    set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as
    an input to the forward pass.
|
|
"""
|
|
|
|
def __init__(self, config, add_pooling_layer=True):
|
|
if not isinstance(config, BertConfig):
|
|
config = BertConfig.from_dict(config)
|
|
|
|
super().__init__(config)
|
|
self.config = config
|
|
|
|
self.embeddings = BertEmbeddings(config)
|
|
|
|
self.encoder = BertEncoder(config)
|
|
|
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
|
|
|
self.init_weights()
|
|
|
|
def get_input_embeddings(self):
|
|
return self.embeddings.word_embeddings
|
|
|
|
def set_input_embeddings(self, value):
|
|
self.embeddings.word_embeddings = value
|
|
|
|
def _prune_heads(self, heads_to_prune):
|
|
"""Prunes heads of the model.
|
|
|
|
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
|
class PreTrainedModel
|
|
"""
|
|
for layer, heads in heads_to_prune.items():
|
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
|
|
|
def get_extended_attention_mask(
|
|
self,
|
|
attention_mask: Tensor,
|
|
input_shape: Tuple[int],
|
|
device: device,
|
|
is_decoder: bool,
|
|
) -> Tensor:
|
|
"""Makes broadcastable attention and causal masks so that future and
|
|
masked tokens are ignored.
|
|
|
|
Arguments:
|
|
attention_mask (:obj:`torch.Tensor`):
|
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
|
input_shape (:obj:`Tuple[int]`):
|
|
The shape of the input to the model.
|
|
device: (:obj:`torch.device`):
|
|
The device of the input to the model.
|
|
|
|
Returns:
|
|
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
|
"""
|
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
|
if attention_mask.dim() == 3:
|
|
extended_attention_mask = attention_mask[:, None, :, :]
|
|
elif attention_mask.dim() == 2:
|
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
|
if is_decoder:
|
|
batch_size, seq_length = input_shape
|
|
|
|
seq_ids = torch.arange(seq_length, device=device)
|
|
causal_mask = (
|
|
seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <=
|
|
seq_ids[None, :, None])
|
|
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
|
# causal and attention masks must have same type with pytorch version < 1.3
|
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
|
|
|
if causal_mask.shape[1] < attention_mask.shape[1]:
|
|
prefix_seq_len = attention_mask.shape[
|
|
1] - causal_mask.shape[1]
|
|
causal_mask = torch.cat(
|
|
[
|
|
torch.ones(
|
|
(batch_size, seq_length, prefix_seq_len),
|
|
device=device,
|
|
dtype=causal_mask.dtype,
|
|
),
|
|
causal_mask,
|
|
],
|
|
axis=-1,
|
|
)
|
|
|
|
extended_attention_mask = (
|
|
causal_mask[:, None, :, :] *
|
|
attention_mask[:, None, None, :])
|
|
else:
|
|
extended_attention_mask = attention_mask[:, None, None, :]
|
|
else:
|
|
raise ValueError(
|
|
'Wrong shape for input_ids (shape {}) or attention_mask (shape {})'
|
|
.format(input_shape, attention_mask.shape))
|
|
|
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
# masked positions, this operation will create a tensor which is 0.0 for
|
|
# positions we want to attend and -10000.0 for masked positions.
|
|
# Since we are adding it to the raw scores before the softmax, this is
|
|
# effectively the same as removing these entirely.
|
|
extended_attention_mask = extended_attention_mask.to(
|
|
dtype=self.dtype) # fp16 compatibility
|
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
return extended_attention_mask
|
|
|
|
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    is_decoder=False,
    mode='multimodal',
):
    """Run the BERT stack on text tokens, optionally cross-attending to
    features from another modality.

    The batch shape and device are taken from the first available source,
    checked in this order: ``input_ids``, ``inputs_embeds``,
    ``encoder_embeds``. Passing both ``input_ids`` and ``inputs_embeds`` is
    an error.

    Args:
        input_ids (torch.Tensor, optional): Token ids of shape
            ``(batch_size, seq_length)``.
        attention_mask (torch.Tensor, optional): Self-attention mask
            (1 = attend, 0 = masked). Defaults to all ones covering the
            current tokens plus any cached ones.
        token_type_ids, position_ids, inputs_embeds (optional): Forwarded to
            ``self.embeddings``.
        head_mask (optional): Converted with ``self.get_head_mask`` and
            forwarded to the encoder.
        encoder_embeds (torch.Tensor, optional): Pre-computed embeddings fed
            directly to the encoder, bypassing ``self.embeddings``.
        encoder_hidden_states (torch.Tensor | list[torch.Tensor], optional):
            Features used by cross-attention, each ``(batch, length, dim)``.
            For a list, the first element sizes the default mask.
        encoder_attention_mask (torch.Tensor | list[torch.Tensor], optional):
            Cross-attention mask(s), 1 for tokens that are not masked and 0
            for masked ones. Defaults to all ones when
            ``encoder_hidden_states`` is given.
        past_key_values (tuple, optional): Cached key/value states. The
            cached length is read from ``past_key_values[0][0].shape[2]``.
        use_cache (bool, optional): Only honoured when ``is_decoder`` is
            true (then it falls back to ``config.use_cache``). It is forced
            to ``False`` otherwise.
        output_attentions, output_hidden_states, return_dict (bool,
            optional): Fall back to the matching ``self.config`` values.
        is_decoder (bool): Build a causal self-attention mask and allow
            caching.
        mode (str): Forwarded unchanged to ``self.encoder``.

    Returns:
        ``BaseModelOutputWithPoolingAndCrossAttentions`` when ``return_dict``
        is truthy. Otherwise the tuple
        ``(sequence_output, pooled_output, *encoder_outputs[1:])``.
        ``pooled_output`` is ``None`` when the model has no pooler.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Key/value caching is only meaningful for auto-regressive decoding.
    if not is_decoder:
        use_cache = False
    elif use_cache is None:
        use_cache = cfg.use_cache

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            'You cannot specify both input_ids and inputs_embeds at the same time'
        )
    if input_ids is not None:
        input_shape = input_ids.size()
        device = input_ids.device
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        device = inputs_embeds.device
    elif encoder_embeds is not None:
        input_shape = encoder_embeds.size()[:-1]
        device = encoder_embeds.device
    else:
        raise ValueError(
            'You have to specify either input_ids or inputs_embeds or encoder_embeds'
        )
    batch_size, seq_length = input_shape

    # Number of positions already held in the key/value cache.
    past_length = (0 if past_key_values is None else
                   past_key_values[0][0].shape[2])

    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length + past_length),
                                    device=device)

    # Broadcast the (possibly causal) self-attention mask to all heads.
    extended_attention_mask = self.get_extended_attention_mask(
        attention_mask, input_shape, device, is_decoder)

    if encoder_hidden_states is None:
        encoder_extended_attention_mask = None
    else:
        # For a list of feature tensors, the first one defines the shape.
        reference_states = (
            encoder_hidden_states[0]
            if type(encoder_hidden_states) is list else encoder_hidden_states)
        encoder_batch_size, encoder_sequence_length, _ = \
            reference_states.size()

        if type(encoder_attention_mask) is list:
            encoder_extended_attention_mask = [
                self.invert_attention_mask(each_mask)
                for each_mask in encoder_attention_mask
            ]
        else:
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    (encoder_batch_size, encoder_sequence_length),
                    device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask)

    # 1.0 in head_mask keeps a head; expanded per layer by get_head_mask.
    head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

    if encoder_embeds is not None:
        embedding_output = encoder_embeds
    else:
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_length,
        )

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        mode=mode,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = (None if self.pooler is None else
                     self.pooler(sequence_output))

    if return_dict:
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
    return (sequence_output, pooled_output) + encoder_outputs[1:]
|
|
|
|
|
|
class BaseEncoder(nn.Module):
    """Common parent for primitive encoders such as ViT or TimeSformer.

    Subclasses implement :meth:`forward_features`. The :attr:`device`
    helper reports where the module's parameters live.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        """Extract features from ``samples``; must be overridden."""
        raise NotImplementedError

    @property
    def device(self):
        """Device of the first registered parameter.

        Raises:
            IndexError: If the module has no parameters.
        """
        first_param = tuple(self.parameters())[0]
        return first_param.device
|
|
|
|
|
|
@MODELS.register_module()
class XBertEncoder(BertModel, BaseEncoder):
    """BERT text encoder that can also fuse visual features.

    Args:
        med_config (dict): Configuration converted with
            ``BertConfig.from_dict``.
        from_pretrained (bool): Not used by this constructor.
    """

    def __init__(self, med_config, from_pretrained=False):
        super().__init__(
            config=BertConfig.from_dict(med_config), add_pooling_layer=False)

    def forward_automask(self, tokenized_text, visual_embeds, **kwargs):
        """Encode text while cross-attending to every visual token.

        Args:
            tokenized_text: Tokenizer output exposing ``input_ids`` and
                ``attention_mask``.
            visual_embeds (torch.Tensor): Visual features; every position
                except the last dimension is marked as attendable.
            **kwargs: Ignored.

        Returns:
            The ``return_dict=True`` output of ``BertModel.forward``.
        """
        # All-ones cross-attention mask: no visual token is masked out.
        visual_mask = torch.ones(
            visual_embeds.size()[:-1], dtype=torch.long).to(self.device)
        return super().forward(
            tokenized_text.input_ids,
            attention_mask=tokenized_text.attention_mask,
            encoder_hidden_states=visual_embeds,
            encoder_attention_mask=visual_mask,
            return_dict=True,
        )

    def forward_text(self, tokenized_text, **kwargs):
        """Encode text alone (``mode='text'``, no visual input).

        Args:
            tokenized_text: Tokenizer output exposing ``input_ids`` and
                ``attention_mask``.
            **kwargs: Only ``token_type_ids`` is read; others are ignored.

        Returns:
            The ``return_dict=True`` output of ``BertModel.forward``.
        """
        return super().forward(
            tokenized_text.input_ids,
            attention_mask=tokenized_text.attention_mask,
            token_type_ids=kwargs.get('token_type_ids', None),
            return_dict=True,
            mode='text',
        )
|
|
|
|
|
|
@MODELS.register_module()
class Linear(torch.nn.Linear):
    """Alias of :class:`torch.nn.Linear` registered in ``MODELS``.

    It adds no behaviour of its own. Registration presumably exists so that
    configs can build a plain linear layer by type name.
    """
|
|
|
|
|
|
@MODELS.register_module()
class BertLMHeadModel(BertPreTrainedModel):
    """BERT decoder with a language-modelling prediction head.

    Combines a :class:`BertModel` backbone (built without a pooler) with a
    :class:`BertOnlyMLMHead`. It also provides the methods consumed by
    ``self.generate``: ``prepare_inputs_for_generation`` and
    ``_reorder_cache``.

    Args:
        config: BERT configuration shared by the backbone and the head.
    """

    _keys_to_ignore_on_load_unexpected = [r'pooler']
    _keys_to_ignore_on_load_missing = [
        r'position_ids', r'predictions.decoder.bias'
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""Decode ``input_ids`` and optionally compute a next-token loss.

        Args:
            input_ids, attention_mask, position_ids, head_mask,
            inputs_embeds, past_key_values, output_attentions,
            output_hidden_states, is_decoder, mode: Forwarded unchanged to
                :attr:`bert`.
            encoder_hidden_states (torch.FloatTensor, optional): Encoder
                features of shape ``(batch_size, sequence_length,
                hidden_size)`` used by cross-attention.
            encoder_attention_mask (torch.FloatTensor, optional): Mask of
                shape ``(batch_size, sequence_length)`` over the encoder
                features: 1 for tokens that are **not masked**, 0 for tokens
                that are **masked**.
            labels (torch.LongTensor, optional): Targets of shape
                ``(batch_size, sequence_length)``. They are shifted left by
                one, so position ``t`` predicts token ``t + 1``. Entries equal
                to ``-100`` (``CrossEntropyLoss``'s default ``ignore_index``)
                are ignored.
            use_cache (bool, optional): Forced to ``False`` when ``labels``
                are given.
            return_dict (bool, optional): Falls back to
                ``config.use_return_dict``.
            return_logits (bool): If true, skip the loss and return only the
                prediction scores with the last position dropped,
                ``(batch_size, seq_len - 1, vocab_size)``.
            reduction (str): Reduction of the cross-entropy loss. With
                ``'none'``, per-token losses are summed per sample, giving a
                ``(batch_size,)`` tensor.

        Returns:
            Either the trimmed logits (``return_logits=True``), a
            ``CausalLMOutputWithCrossAttentions`` (``return_dict`` truthy),
            or a tuple ``([lm_loss,] prediction_scores, *outputs[2:])``.
        """
        return_dict = (
            return_dict
            if return_dict is not None else self.config.use_return_dict)
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: drop the last score and the first label.
            shifted_prediction_scores = prediction_scores[:, :
                                                          -1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # NOTE: label smoothing is fixed at 0.1 here.
            loss_fct = torch.nn.CrossEntropyLoss(
                reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1))
            if reduction == 'none':
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return ((lm_loss, ) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past=None,
                                      attention_mask=None,
                                      **model_kwargs):
        """Assemble the keyword arguments for one decoding step.

        Args:
            input_ids (torch.Tensor): Tokens generated so far,
                ``(batch_size, cur_len)``.
            past (tuple, optional): Key/value cache from the previous step.
                The cache may instead arrive in ``model_kwargs`` as
                ``past_key_values`` (the name used by newer ``transformers``
                ``generate`` implementations). Both are accepted so the cache
                is not silently dropped.
            attention_mask (torch.Tensor, optional): Defaults to all ones of
                the same shape as ``input_ids``.
            **model_kwargs: May carry ``past_key_values``,
                ``encoder_hidden_states`` and ``encoder_attention_mask``.

        Returns:
            dict: Keyword arguments for :meth:`forward`, always with
            ``is_decoder=True``.
        """
        if past is None:
            past = model_kwargs.get('past_key_values', None)

        # If the model is used as a decoder in an encoder-decoder model,
        # the decoder attention mask is created on the fly.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # With a cache, only the newest token needs to be fed in.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            'input_ids':
            input_ids,
            'attention_mask':
            attention_mask,
            'past_key_values':
            past,
            'encoder_hidden_states':
            model_kwargs.get('encoder_hidden_states', None),
            'encoder_attention_mask':
            model_kwargs.get('encoder_attention_mask', None),
            'is_decoder':
            True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Select cached states along dim 0 according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past
|
|
|
|
|
|
@MODELS.register_module()
class XBertLMHeadDecoder(BertLMHeadModel):
    """This class decouples the decoder forward logic from the VL model.

    In this way, different VL models can share this decoder as long as they
    feed encoder_embeds as required.

    Args:
        med_config (dict): Configuration converted with
            ``BertConfig.from_dict`` and stored as ``self.med_config``.
    """

    def __init__(self, med_config):
        self.med_config = BertConfig.from_dict(med_config)
        super().__init__(config=self.med_config)

    def generate_from_encoder(self,
                              tokenized_prompt,
                              visual_embeds,
                              sep_token_id,
                              pad_token_id,
                              use_nucleus_sampling=False,
                              num_beams=3,
                              max_length=30,
                              min_length=10,
                              top_p=0.9,
                              repetition_penalty=1.0,
                              **kwargs):
        """Generate token ids conditioned on visual features.

        Args:
            tokenized_prompt: Tokenizer output; only its ``input_ids`` are
                used as the generation prefix.
            visual_embeds (torch.Tensor): Features passed to ``generate`` as
                ``encoder_hidden_states``. Every position except the last
                dimension is marked as attendable.
            sep_token_id (int): Passed to ``generate`` as ``eos_token_id``.
            pad_token_id (int): Padding token id.
            use_nucleus_sampling (bool): Use top-p sampling instead of beam
                search.
            num_beams (int): Beam count for beam search. ``visual_embeds``
                is repeated this many times along dim 0 in that mode.
            max_length (int): Maximum generated length.
            min_length (int): Minimum generated length.
            top_p (float): Nucleus threshold; only used when sampling.
            repetition_penalty (float): Only used by beam search. Nucleus
                sampling always uses a fixed penalty of 1.1.
            **kwargs: Accepted but ignored (not forwarded to ``generate``).

        Returns:
            Whatever ``self.generate`` returns.
        """
        if not use_nucleus_sampling:
            # Match the per-beam expansion of the prompt.
            # NOTE(review): some ``transformers`` versions also expand tensor
            # kwargs per beam inside ``generate``; confirm this manual repeat
            # matches the pinned version.
            visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0)

        image_atts = torch.ones(
            visual_embeds.size()[:-1], dtype=torch.long, device=self.device)

        model_kwargs = {
            'encoder_hidden_states': visual_embeds,
            'encoder_attention_mask': image_atts,
        }

        if use_nucleus_sampling:
            # nucleus sampling
            return self.generate(
                input_ids=tokenized_prompt.input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                eos_token_id=sep_token_id,
                pad_token_id=pad_token_id,
                repetition_penalty=1.1,
                **model_kwargs)

        # beam search
        return self.generate(
            input_ids=tokenized_prompt.input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            eos_token_id=sep_token_id,
            pad_token_id=pad_token_id,
            repetition_penalty=repetition_penalty,
            **model_kwargs)
|