# MQ-Det/maskrcnn_benchmark/modeling/language_backbone/modeling_bert_new.py


import torch
import torch.utils.checkpoint
from torch import nn, einsum
from einops import rearrange
from einops_exts import rearrange_many
from typing import List, Optional, Tuple, Union
from maskrcnn_benchmark.utils.torch_dropout import Dropout1d
import random
from collections import OrderedDict
from transformers.models.bert.modeling_bert import BertModel, BertEncoder, BertEmbeddings,\
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,\
logger, \
add_start_docstrings_to_model_forward, add_code_sample_docstrings, \
BERT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC
# import torch.nn.utils.rnn as rnn_utils
# def get_index_with_padding_batch(a, padding_value=0):
# '''
# Given an attention mask, which only contains 0 and 1, return a tensor that contains the index of non-zero elements. Pad each row of output tensor with given padding_value to the same length.
# Inputs:
# a - (B, M, N)
# Outputs:
# torch.tensor - (B, M, K) , K is the max length of non-zero elements in the N-dim of a.
# '''
# # Compute the indices of non-zero elements
# indices = [torch.nonzero(row)[:, 0] for row in a.reshape(-1, a.shape[-1])]
# # Pad sequences and reshape back to the original shape
# padded_indices = rnn_utils.pad_sequence(indices, batch_first=True, padding_value=padding_value)
# padded_indices = padded_indices.view(a.shape[0], a.shape[1], -1)
# return padded_indices
@torch.no_grad()
def get_index_with_padding_batch(a, padding_value=None):
    '''
    Given an attention mask containing only 0s and 1s, return a tensor holding the indices of its non-zero elements. Each row of the output is padded with padding_value to a common length.
    Inputs:
        a - (B, M, N)
    Outputs:
        torch.tensor - (B, M, K), where K is the maximum number of non-zero elements along the N-dim of a.
    Note:
        padding_value == N, i.e., the caller concatenates a zero vector at the end of the vision queries to serve as the padding token.
    '''
if padding_value is None:
padding_value = a.shape[-1]
else:
assert padding_value == a.shape[-1]
    # Keep the original index at non-zero positions and write padding_value elsewhere.
    non_zero = (a != 0)
    max_length = non_zero.sum(-1).max()
    indices = torch.where(non_zero, torch.arange(a.shape[-1], dtype=torch.long, device=a.device), torch.tensor(padding_value, dtype=torch.long, device=a.device))
    # Move the valid indices to the front of each row: topk with largest=False keeps the max_length smallest entries, so padding indices only fill rows with fewer valid elements.
    padded_indices = indices.topk(k=max_length, dim=-1, largest=False).values
    return padded_indices
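# Worked example (illustrative, not from the original source): for
#   a = [[[1, 0, 1],
#         [0, 0, 0]]]                      # (B=1, M=2, N=3)
# the per-row non-zero counts are [2, 0], so max_length = 2 and padding_value = 3, and
# the function returns [[[0, 2], [3, 3]]]; index 3 (== N) refers to the all-zero
# vision token that the caller appends as a padding slot.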
# def get_index_with_padding_batch(a, padding_value=0):
# # TODO: more efficient implement
# '''
# Given an attention mask, which only contains 0 and 1, return a tensor that contains the index of non-zero elements. Pad each row of output tensor with given padding_value to the same length.
# Inputs:
# a - (B, M, N)
# Outputs:
# torch.tensor - (B, M, K) , K is the max length of non-zero elements in the N-dim of a.
# '''
# B, M, N = a.shape
# index_list = []
# max_length = 0
# for i in range(B):
# row_indices = []
# for j in range(M):
# row_index = torch.nonzero(a[i, j]).squeeze().tolist()
# row_indices.append(row_index)
# if len(row_index) > max_length:
# max_length = len(row_index)
# index_list.append(row_indices)
# for i in range(len(index_list)):
# for j in range(len(index_list[i])):
# diff = max_length - len(index_list[i][j])
# index_list[i][j] += [padding_value] * diff
# diff = M - len(index_list[i])
# index_list[i] += [[padding_value] * max_length] * diff
# return torch.tensor(index_list, device=a.device)
def easy_gather(x, indices):
    # x: (B, N, C); indices: (B, N_new) with values in [0, N). Gathers rows of x along dim 1, independently per batch element.
B, N, C = x.shape
N_new = indices.shape[1]
offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
indices = indices + offset
out = x.flatten(0,1)[indices.view(-1)].view(B, N_new, C)
return out
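# Worked example (illustrative): with B=2, N=3, C=1,
#   x = [[[10], [11], [12]], [[20], [21], [22]]]  and  indices = [[2, 0], [1, 1]],
# the per-batch offsets (0 and 3) map the indices to rows [2, 0] and [4, 4] of the
# flattened (B*N, C) tensor, giving out = [[[12], [10]], [[21], [21]]] of shape (B, N_new=2, C).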
# gated cross attention
def exists(val):
    # True only for non-None values with non-zero length (e.g., a non-empty tensor or list).
    return (val is not None) and (len(val) > 0)
def FeedForward(dim, mult = 4, out_dim = None):
inner_dim = int(dim * mult)
if out_dim is None:
out_dim = dim
return nn.Sequential(
OrderedDict([
('norm', nn.LayerNorm(dim)),
('linear1', nn.Linear(dim, inner_dim, bias = False)),
('gelu', nn.GELU()),
('linear2', nn.Linear(inner_dim, out_dim, bias = False))
])
)
class MaskedCrossAttention(nn.Module):
def __init__(
self,
*,
input_dim,
output_dim = None,
dim_head = 64,
heads = 8,
norm_kv = False,
share_kv=False,
cfg=None,
spase_forward=False,
):
super().__init__()
self.spase_forward=spase_forward
self.scale = dim_head ** -0.5
self.heads = heads
self.share_kv=share_kv
inner_dim = dim_head * heads
if output_dim is None:
output_dim = input_dim
self.norm = nn.LayerNorm(input_dim)
self.norm_kv = None
if norm_kv:
self.norm_kv = nn.LayerNorm(input_dim)
self.to_q = nn.Linear(input_dim, inner_dim, bias = False)
if share_kv:
self.to_kv = nn.Linear(input_dim, inner_dim, bias = False)
else:
self.to_kv = nn.Linear(input_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, output_dim, bias = False)
@classmethod
def _construct_sparse_inputs(cls, x, vision, attention_mask):
        '''
        Make each text token attend to only a fixed (typically small) number of query vision tokens.
        Inputs:
            x - (batch, text, dim)
            vision - (batch, vision, dim)
            attention_mask - (batch, vision, text)
        Outputs:
            x - (batch * text, 1, dim)
            vision - (batch * text, num_support_per_class, dim), e.g., num_support_per_class = 5
            attention_mask - (batch * text, num_support_per_class, 1); False marks padding tokens, following the (batch, vision, text) convention with vision = num_support_per_class and text = 1
        '''
B, V, C = vision.shape # batch, vision, dim
vision=torch.cat([vision, vision.new_zeros(B, 1, C)], dim=1) # B, N+1, C
padding_index=V
index = get_index_with_padding_batch(attention_mask.transpose(2,1), padding_value=padding_index)
B, T, S = index.shape # batch, text, num_querys
vision=easy_gather(vision, index.flatten(1,2)).reshape(B, T, S, C)
x = x[:,:,None,...]
        new_mask=(index[:,:,None,...] != padding_index) # (batch, text, 1, num_support_per_class)
        new_mask=new_mask.transpose(-2,-1) # (batch, text, num_support_per_class, 1)
return x.flatten(0,1), vision.flatten(0,1), new_mask.flatten(0,1)
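    # Shape sketch (illustrative), with B = batch, T = text tokens, V = vision queries and
    # S = the maximum number of queries any text token attends to:
    #   x:              (B, T, C)  ->  (B*T, 1, C)
    #   vision:         (B, V, C)  ->  padded to (B, V+1, C), gathered to (B*T, S, C)
    #   attention_mask: (B, V, T)  ->  (B*T, S, 1), False where the gathered slot is the
    #                                  appended padding token (index V)
    # so each text token cross-attends only to its own S candidate vision queries.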
def forward(
self,
x, # (batch, text, dim)
vision, # (batch, vision, dim)
attention_mask = None, # (batch, vision, text)
):
if self.spase_forward:
batch_size = x.shape[0]
x, vision, attention_mask = self._construct_sparse_inputs(x, vision, attention_mask)
vision = vision.to(x.dtype)
b, v, d = vision.shape
h = self.heads
x = self.norm(x)
        if self.norm_kv is not None:
vision = self.norm_kv(vision)
q = self.to_q(x)
# vision = rearrange(vision, 'b s v d -> b (s v) d')
if self.share_kv:
k = v = self.to_kv(vision)
else:
k, v = self.to_kv(vision).chunk(2, dim = -1)
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
q = q * self.scale
sim = einsum('... i d, ... j d -> ... i j', q, k) # (batch, heads, sequence, vision)
if exists(attention_mask):
sim=rearrange(sim, 'b h t v -> b h v t')
mask = sim.new_zeros(attention_mask.shape) # (b, v, t)
mask[attention_mask==0] = -1e4 # for half
mask=mask[:, None, ...] # (b, 1, v, t)
sim = sim + mask
sim=rearrange(sim, 'b h v t -> b h t v')
attn = sim.softmax(dim = -1)
if exists(attention_mask):
attn=rearrange(attn, 'b h t v -> b v t h')
            attn = attn * attention_mask[..., None] # ensure text tokens whose vision queries are all masked get an all-zero attention output
# attn[attention_mask==0] = 0
attn=rearrange(attn, 'b v t h -> b h t v')
out = einsum('... i j, ... j d -> ... i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
if self.spase_forward:
assert out.shape[1]==1
out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)
out = self.to_out(out)
# # self-update: update those with all zero attention masks by themselves, for option 1!
# if exists(attention_mask):
# update_mask = (attention_mask.sum(1)==0) # (b,t)
# # assert out[update_mask].sum()==0
# out = x * update_mask[..., None] + out
return out
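# Attention sketch (illustrative): per head this module computes (after the sparse rewrite
# above, when spase_forward is set)
#   attn = softmax( (x W_q)(vision W_k)^T / sqrt(dim_head) + bias )
#   out  = to_out( (attn * mask) @ (vision W_v) )
# where bias is -1e4 at positions removed by attention_mask, and the post-softmax
# multiplication by the mask forces an all-zero output for text tokens that cannot see
# any vision query, so the gated block below can fall back to the original text embedding.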
class GatedCrossAttentionBlock(nn.Module):
    '''
    For each target category, one RoI feature is extracted at each scale, i.e., (batch, scales, latents, dim_v), where latents always equals the number of shots (k).
    "latents" denotes the total number of vision tokens at each scale.
    If a text token's attention mask to all vision tokens is False, its original text embedding is returned unchanged.
    '''
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
share_kv=False,
cfg=None,
enable_ffn = True
):
super().__init__()
self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spase_forward=True)
if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
if cfg.VISION_QUERY.CONDITION_GATE:
if cfg.VISION_QUERY.NONLINEAR_GATE:
if cfg.VISION_QUERY.NO_CAT:
self.attn_gate = FeedForward(dim=dim, mult=0.5, out_dim = 1)
torch.nn.init.constant_(self.attn_gate.linear2.weight, 0)
else:
self.attn_gate = FeedForward(dim=dim*2, mult=0.5, out_dim = 1)
torch.nn.init.constant_(self.attn_gate.linear2.weight, 0)
else:
self.attn_gate = nn.Linear(dim, 1, bias=False)
torch.nn.init.constant_(self.attn_gate.weight, 0)
else:
self.attn_gate = nn.Parameter(torch.tensor([0.]))
# if cfg.VISION_QUERY.TEXT_DROPOUT > 0.:
# self.mask_token = nn.Parameter(torch.randn(dim))
self.enable_ffn = enable_ffn
if enable_ffn:
self.ff = FeedForward(dim, mult = ff_mult)
if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
self.ff_gate = nn.Parameter(torch.tensor([0.]))
if cfg.VISION_QUERY.ADD_ADAPT_LAYER:
self.adaptor = FeedForward(dim, mult = 2)
# self.text_dropout=Dropout1d(p=cfg.VISION_QUERY.TEXT_DROPOUT)
self.cfg=cfg
self.attn_gate_value = 0.
def forward(
self,
x, # text tensor - (batch, text, dim_t)
vision, # vision query tensor - (batch, vision, dim_v)
attention_mask = None, # boolean tensor indicating masks of media - (batch, vision, text)
batched_positive_label_position = None, # batch: {label: (positions)}
):
# assert exists(attention_mask)
## do not drop pure text or padding
# dropped_x = x
# if exists(attention_mask):
# dropped_x = self.text_dropout(x)
# drooped_mask = (dropped_x.sum(-1)==0) # (b,t)
# update_mask = (attention_mask.sum(1)==0) # (b,t)
# mask = drooped_mask * update_mask
# dropped_x = dropped_x + x * mask
# # do not drop pure text or padding, for option 2!
# dropped_x = self.text_dropout(x)
# drooped_mask = (dropped_x.sum(-1)==0) # (b,t)
# update_mask = (attention_mask.sum(1)==0) # (b,t)
# mask = drooped_mask * update_mask
# dropped_x = dropped_x + x * mask
# option1: (1-a)*x1 + a*x2, a \in (0,1)
# option2: x1 + a*x2, a \in (-1,1)
        ## if option2, text dropout may be applied here. Not tested yet.
        ## if option1, text dropout may be applied in MaskedCrossAttention.
# # Only mask text with vision query
# # Only mask text with positive categories
# if self.cfg.VISION_QUERY.TEXT_DROPOUT > 0. and self.training:
# mask=x.new_zeros(x.shape[:2], dtype=torch.bool) # (batch, text)
# pure_text_mask=attention_mask.sum(1) # (batch, text)
# for i, pos_label_position in enumerate(batched_positive_label_position):
# pos_label_position=pos_label_position.to(torch.bool)
# for position in pos_label_position:
# text_with_vision_query = (pure_text_mask[i, position].sum()!=0)
# if (random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT) and text_with_vision_query:
# mask[i, position] = True
# if self.training:
# dropped_x = x.clone()
# dropped_x[mask] = self.mask_token
# else:
# dropped_x = x
if self.cfg.VISION_QUERY.ADD_ADAPT_LAYER:
vision = self.adaptor(vision) + vision
dropped_x = x
supported_x = self.attn(x, vision, attention_mask = attention_mask)
# dropped_x = self.text_dropout(x)
if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
attn_gate = self.cfg.VISION_QUERY.FIX_ATTN_GATE
else:
if self.cfg.VISION_QUERY.CONDITION_GATE:
if self.cfg.VISION_QUERY.NO_CAT or not (self.cfg.VISION_QUERY.NONLINEAR_GATE):
attn_gate = self.attn_gate(supported_x).tanh()
else:
attn_gate = self.attn_gate(torch.cat([supported_x, dropped_x], dim = -1)).tanh()
else:
attn_gate = self.attn_gate.tanh()
if self.cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
with torch.no_grad():
self.attn_gate_value = attn_gate.mean().item()
x = supported_x * attn_gate + dropped_x
if self.enable_ffn:
if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
x = self.ff(x) * self.cfg.VISION_QUERY.FIX_ATTN_GATE + x
else:
x = self.ff(x) * self.ff_gate.tanh() + x
return x
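# Gating sketch (illustrative): with g_attn = tanh(attn_gate(.)) and g_ff = tanh(ff_gate)
# (or both equal to VISION_QUERY.FIX_ATTN_GATE when it is set), the block computes
#   x <- x + g_attn * MaskedCrossAttention(x, vision)
#   x <- x + g_ff * FeedForward(x)
# The learnable gates are initialised to zero, so the block starts as an identity mapping
# over the text features and only gradually mixes in vision-query information during training.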
class PreSelectBlock(nn.Module):
def __init__(
self,
*,
dim,
out_dim = None,
dim_head = 32,
heads = 8,
ff_mult = 4,
share_kv=False,
cfg=None,
):
super().__init__()
self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spase_forward=False)
self.ff = FeedForward(out_dim, mult = ff_mult)
if dim != out_dim:
self.res_mapping = nn.Linear(in_features=dim, out_features=out_dim, bias=False)
else:
self.res_mapping = nn.Identity()
def forward(
self,
x,
):
vision=x['vision'] # vision query tensor - (batch, vision, dim_v)
image=x['image'] # query images - (batch, image, dim_v)
# b, s, v, d = vision.shape
# vision = rearrange(vision, 'b s v d -> b (s v) d')
vision = self.image_condition(vision, image) + self.res_mapping(vision)
vision = self.ff(vision) + vision
# vision = rearrange(vision, 'b (s v) d -> b s v d', s=s)
return {'vision': vision, 'image': image}
class PreSelectModule(nn.Module):
def __init__(
self,
*,
dim,
out_dim,
dim_head = 32,
heads = 8,
ff_mult = 4,
num_layers = 2,
share_kv=False,
cfg=None
):
super().__init__()
layers = [PreSelectBlock(dim=dim, out_dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg) for _ in range(num_layers-1)]
layers.append(PreSelectBlock(dim=dim, out_dim=out_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg))
self.layers = nn.Sequential(*layers)
self.scale=cfg.VISION_QUERY.VISION_SCALE
self.augment_image_with_query = cfg.VISION_QUERY.AUGMENT_IMAGE_WITH_QUERY
if self.augment_image_with_query:
assert len(self.layers) > 1
def forward(
self,
vision, # vision query tensor - (batch, scales, vision, dim_v)
image, # query images - (batch, scales, image, dim_v)
):
vision = vision * self.scale
image = image * self.scale
if self.augment_image_with_query:
x = self.layers[0]({'vision': image, 'image': vision})
x = {'vision': x['image'], 'image': x['vision']}
for layer in self.layers[1:]:
x = layer(x)
return x
else:
x = {'vision': vision, 'image': image}
return self.layers(x)
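# Flow sketch (illustrative): each PreSelectBlock conditions the stored vision queries on
# the current query-image features,
#   vision <- CrossAttn(vision, image) + res_mapping(vision);  vision <- FFN(vision) + vision
# With AUGMENT_IMAGE_WITH_QUERY enabled, the first block instead updates the image features
# by attending to the vision queries, and the remaining blocks then refine the original
# vision queries against these query-augmented image features.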
class QVBertEmbeddings(BertEmbeddings):
def __init__(self, config, cfg):
super().__init__(config)
self.cfg=cfg
if (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.) and (cfg.VISION_QUERY.NEW_MASK_TOKEN):
            self.mask_tok_qv_layer = nn.Parameter(torch.randn(config.hidden_size)) # "qv_layer" is in the name only to make parameter freezing easier
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
batched_pos_category_map=None,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        # Setting the token_type_ids to the registered buffer in the constructor where it is all zeros, which usually occurs
        # when it is auto-generated. The registered buffer helps users when tracing the model without passing token_type_ids, and solves
        # issue #5664.
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.) and (batched_pos_category_map is not None) and (self.cfg.VISION_QUERY.NEW_MASK_TOKEN) and (self.training):
raise NotImplementedError
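            # NOTE: this branch is currently disabled by the raise above, so the
            # mask-token replacement code below is unreachable.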
mask_tok_qv_layer = self.mask_tok_qv_layer.to(inputs_embeds.dtype)
# inputs_embeds_ = []
# for emb, pos_label_position in zip(inputs_embeds, batched_pos_category_map):
# pos_label_position=pos_label_position.to(torch.bool)
# for position in pos_label_position:
# if (random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT):
# emb=torch.scatter(emb, dim=0, index=position.nonzero()[0][..., None], src=mask_tok_qv_layer[None, ...])
# inputs_embeds_.append(emb)
# inputs_embeds = torch.stack(inputs_embeds_)
inputs_embeds = inputs_embeds.clone()
for i, pos_label_position in enumerate(batched_pos_category_map):
pos_label_position=pos_label_position.to(torch.bool)
for position in pos_label_position:
if (random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT):
inputs_embeds[i, position] = mask_tok_qv_layer
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class QVBertEncoder(BertEncoder):
    '''
    Adds a gated cross-attention qv_layer in front of every BERT layer whose index is >= start_qv_layer_index.
    '''
def __init__(self,
config,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
start_qv_layer_index = 6, # which layer to start fusing vision
share_kv=False,
cfg=None,
):
super().__init__(config=config)
self.start_qv_layer_index = start_qv_layer_index
num_hidden_layers = config.num_hidden_layers
assert start_qv_layer_index < num_hidden_layers
num_qv_layers = num_hidden_layers - start_qv_layer_index
self.qv_layer = nn.ModuleList([GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg)
for _ in range(num_qv_layers)])
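    # Example (illustrative): with a bert-base config (num_hidden_layers = 12) and
    # start_qv_layer_index = 6, six GatedCrossAttentionBlocks are created and, when
    # vision queries are provided, qv_layer[i - 6] runs immediately before BERT layer i for i in 6..11.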
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
vision: Optional[torch.Tensor] = None,
vision_attention_mask: Optional[torch.Tensor] = None,
batched_pos_category_map=None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if i >= self.start_qv_layer_index and exists(vision):
qv_index = i - self.start_qv_layer_index
hidden_states = self.qv_layer[qv_index](hidden_states, vision, vision_attention_mask, batched_pos_category_map)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class QVBertModel(BertModel):
def __init__(self,
config,
dim_t,
dim_v,
dim_head_t = 64,
dim_head_v = 32,
heads = 8,
ff_mult = 4,
num_pre_select_layers = 2,
share_kv = False,
cfg=None,
**kwargs):
super().__init__(config=config, **kwargs)
self.cfg=cfg
self.embeddings = QVBertEmbeddings(config, cfg)
self.encoder = QVBertEncoder(config=config, dim=dim_t, dim_head=dim_head_t, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg)
self.pre_select = PreSelectModule(dim=dim_v, out_dim=dim_t, dim_head=dim_head_v, heads=heads,ff_mult=ff_mult, num_layers=num_pre_select_layers, share_kv=share_kv, cfg=cfg)
# if cfg.VISION_QUERY.NEW_MASK_TOKEN:
# self.mask_tok_qv_layer = nn.Parameter(torch.randn(config.hidden_size)) # add qv_layer to name only for easier paramter freezing
def get_gate_value(self):
attn_gates=[]
ff_gates=[]
for blk in self.encoder.qv_layer:
# try:
if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
attn_gates.append(torch.tensor([self.cfg.VISION_QUERY.FIX_ATTN_GATE], device=self.embeddings.word_embeddings.weight.device))
ff_gates.append(torch.tensor([self.cfg.VISION_QUERY.FIX_ATTN_GATE], device=self.embeddings.word_embeddings.weight.device))
else:
if not self.cfg.VISION_QUERY.CONDITION_GATE:
attn_gates.append(blk.attn_gate)
else:
if self.cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
attn_gates.append(blk.attn_gate_value)
else:
pass
# except:
# pass
ff_gates.append(blk.ff_gate)
return {'attn_gates': attn_gates, 'ffn_gates': ff_gates}
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
vision: Optional[torch.Tensor] = None, # (batch, vision, dim)
images: Optional[torch.Tensor] = None, # (batch, image, dim)
vision_attention_mask: Optional[torch.Tensor] = None,
batched_pos_category_map = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
batched_pos_category_map = batched_pos_category_map,
)
# if self.cfg.VISION_QUERY.TEXT_DROPOUT > 0. and batched_pos_category_map is not None and self.training:
# if self.cfg.VISION_QUERY.NEW_MASK_TOKEN:
# # embedding_output = embedding_output.clone()
# mask_tok_qv_layer = self.mask_tok_qv_layer.to(embedding_output.dtype)
# for i, pos_label_position in enumerate(batched_pos_category_map):
# pos_label_position=pos_label_position.to(torch.bool)
# for position in pos_label_position:
# if (random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT):
# embedding_output[i, position] = mask_tok_qv_layer
augmented_vision = None
if (exists(images) and exists(vision)):
vision = self.pre_select(vision, images)['vision']
augmented_vision = vision
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
vision=vision,
vision_attention_mask=vision_attention_mask,
batched_pos_category_map=batched_pos_category_map,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
out=BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
# if self.cfg.VISION_QUERY.GATE_REGULARIZATION:
out['vision_query_gates'] = self.get_gate_value()
if self.cfg.VISION_QUERY.QUERY_FUSION:
out['augmented_vision'] = augmented_vision
out['vision_attention_mask'] = vision_attention_mask
return out
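# Usage sketch (illustrative; the dimensions and variable names below are assumptions for
# the example, not values taken from the MQ-Det configs):
#   model = QVBertModel(config, dim_t=768, dim_v=256, cfg=cfg)
#   out = model(
#       input_ids=input_ids,                      # (batch, text)
#       attention_mask=text_mask,                 # (batch, text)
#       vision=vision_queries,                    # (batch, vision, dim_v)
#       images=query_image_features,              # (batch, image, dim_v)
#       vision_attention_mask=vision_text_mask,   # (batch, vision, text)
#   )
#   text_features = out.last_hidden_state         # (batch, text, hidden_size)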