import torch
import torch.utils.checkpoint
from torch import nn, einsum
from einops import rearrange
from einops_exts import rearrange_many
from typing import List, Optional, Tuple, Union
from maskrcnn_benchmark.utils.torch_dropout import Dropout1d
import random
from collections import OrderedDict
from transformers.models.bert.modeling_bert import BertModel, BertEncoder, BertEmbeddings, \
    BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, \
    logger, \
    add_start_docstrings_to_model_forward, add_code_sample_docstrings, \
    BERT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC


@torch.no_grad()
def get_index_with_padding_batch(a, padding_value=None):
    '''
    Given an attention mask that contains only 0s and 1s, return a tensor holding the indices of the
    non-zero elements. Each row of the output is padded with `padding_value` to the same length.
    Inputs:
        a - (B, M, N)
    Outputs:
        torch.tensor - (B, M, K), where K is the maximum number of non-zero elements along the N dim of a.
    Note: padding_value == N, i.e. a zero vector is concatenated at the end of the vision query
    to serve as the padding token.
    '''
    if padding_value is None:
        padding_value = a.shape[-1]
    else:
        assert padding_value == a.shape[-1]
    # Get the indices of the non-zero elements, writing the padding value everywhere else.
    non_zero = (a != 0)
    max_length = non_zero.sum(-1).max()
    indices = torch.where(non_zero,
                          torch.arange(a.shape[-1], dtype=torch.long, device=a.device),
                          torch.tensor(padding_value, dtype=torch.long, device=a.device))
    # Move the valid indices to the beginning of each row, then slice them out.
    padded_indices = indices.topk(k=max_length, dim=-1, largest=False).values
    return padded_indices[:, :, :max_length]
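

# --------------------------------------------------------------------------------------------------
# Illustrative usage sketch (documentation only; not used by the model). It assumes a toy 0/1 mask
# and shows the padded-index layout produced above; the expected values in the comments were worked
# out by hand for this particular mask.
def _example_get_index_with_padding_batch():
    mask = torch.tensor([[[1, 0, 1, 0],
                          [0, 0, 0, 1]]])           # (B, M, N) = (1, 2, 4)
    idx = get_index_with_padding_batch(mask)        # padding_value defaults to N == 4
    # idx == tensor([[[0, 2],
    #                 [3, 4]]])
    # Row 0 keeps the positions of its two non-zero entries; row 1 has only one and is padded with
    # N == 4, i.e. the extra all-zero token that the caller appends to the vision sequence.
    return idx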


def easy_gather(x, indices):
    # x: (B, N, C); indices: (B, N_new). Gather rows of x per batch element according to indices.
    B, N, C = x.shape
    N_new = indices.shape[1]
    offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
    indices = indices + offset
    out = x.flatten(0, 1)[indices.view(-1)].view(B, N_new, C)
    return out


# gated cross attention

def exists(val):
    # A value "exists" only if it is not None and has a positive length (empty tensors/lists do not count).
    if val is None:
        return False
    return len(val) > 0


def FeedForward(dim, mult=4, out_dim=None):
    inner_dim = int(dim * mult)
    if out_dim is None:
        out_dim = dim
    return nn.Sequential(
        OrderedDict([
            ('norm', nn.LayerNorm(dim)),
            ('linear1', nn.Linear(dim, inner_dim, bias=False)),
            ('gelu', nn.GELU()),
            ('linear2', nn.Linear(inner_dim, out_dim, bias=False))
        ])
    )


class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        input_dim,
        output_dim=None,
        dim_head=64,
        heads=8,
        norm_kv=False,
        share_kv=False,
        cfg=None,
        spase_forward=False,
    ):
        super().__init__()
        self.spase_forward = spase_forward
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.share_kv = share_kv
        inner_dim = dim_head * heads
        if output_dim is None:
            output_dim = input_dim

        self.norm = nn.LayerNorm(input_dim)
        self.norm_kv = None
        if norm_kv:
            self.norm_kv = nn.LayerNorm(input_dim)

        self.to_q = nn.Linear(input_dim, inner_dim, bias=False)
        if share_kv:
            self.to_kv = nn.Linear(input_dim, inner_dim, bias=False)
        else:
            self.to_kv = nn.Linear(input_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, output_dim, bias=False)

    @classmethod
    def _construct_sparse_inputs(cls, x, vision, attention_mask):
        '''
        Make each text token attend to only a fixed (typically small) number of vision query tokens.
        Inputs:
            x - (batch, text, dim)
            vision - (batch, vision, dim)
            attention_mask - (batch, vision, text)
        Outputs:
            x - (batch * text, 1, dim)
            vision - (batch * text, num_support_per_class, dim), e.g. num_support_per_class = 5
            attention_mask: masks the padding tokens - (batch * text, num_support_per_class, 1)
        '''
        B, V, C = vision.shape  # batch, vision, dim
        vision = torch.cat([vision, vision.new_zeros(B, 1, C)], dim=1)  # (B, V+1, C)
        padding_index = V
        index = get_index_with_padding_batch(attention_mask.transpose(2, 1), padding_value=padding_index)
        B, T, S = index.shape  # batch, text, num_queries
        vision = easy_gather(vision, index.flatten(1, 2)).reshape(B, T, S, C)
        x = x[:, :, None, ...]
        new_mask = (index[:, :, None, ...] != padding_index)  # (batch, text, 1, num_queries)
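        # Shape walk-through at this point (descriptive comment; the concrete numbers are only an example):
        # with, say, V = 6 vision tokens and S selected queries per text token, `index` is (B, T, S) and
        # points into the vision sequence extended with one all-zero token at position V. The gathered
        # `vision` is (B, T, S, C), `x` is (B, T, 1, C), and `new_mask` is (B, T, 1, S) with False exactly
        # where an entry of `index` is the padding position V.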
        new_mask = new_mask.transpose(-2, -1)  # (batch, text, num_queries, 1)
        return x.flatten(0, 1), vision.flatten(0, 1), new_mask.flatten(0, 1)

    def forward(
        self,
        x,                    # (batch, text, dim)
        vision,               # (batch, vision, dim)
        attention_mask=None,  # (batch, vision, text)
    ):
        if self.spase_forward:
            batch_size = x.shape[0]
            x, vision, attention_mask = self._construct_sparse_inputs(x, vision, attention_mask)
        vision = vision.to(x.dtype)

        b, v, d = vision.shape
        h = self.heads

        x = self.norm(x)
        if self.norm_kv:
            vision = self.norm_kv(vision)

        q = self.to_q(x)
        # vision = rearrange(vision, 'b s v d -> b (s v) d')
        if self.share_kv:
            k = v = self.to_kv(vision)
        else:
            k, v = self.to_kv(vision).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=h)

        q = q * self.scale
        sim = einsum('... i d, ... j d -> ... i j', q, k)  # (batch, heads, sequence, vision)

        if exists(attention_mask):
            sim = rearrange(sim, 'b h t v -> b h v t')
            mask = sim.new_zeros(attention_mask.shape)  # (b, v, t)
            mask[attention_mask == 0] = -1e4  # finite value so half precision stays safe
            mask = mask[:, None, ...]  # (b, 1, v, t)
            sim = sim + mask
            sim = rearrange(sim, 'b h v t -> b h t v')

        attn = sim.softmax(dim=-1)

        if exists(attention_mask):
            # make sure fully ignored tokens get all-zero attention
            attn = rearrange(attn, 'b h t v -> b v t h')
            attn = attn * attention_mask[..., None]
            # attn[attention_mask==0] = 0
            attn = rearrange(attn, 'b v t h -> b h t v')

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        if self.spase_forward:
            assert out.shape[1] == 1
            out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)
        out = self.to_out(out)

        # # self-update: update tokens whose attention mask is all zero with themselves, for option 1!
        # if exists(attention_mask):
        #     update_mask = (attention_mask.sum(1) == 0)  # (b, t)
        #     # assert out[update_mask].sum() == 0
        #     out = x * update_mask[..., None] + out
        return out


class GatedCrossAttentionBlock(nn.Module):
    '''
    For each target category, one RoI feature is extracted on each scale, i.e. (batch, scales, latents, dim_v),
    where latents always equals the number of shots (k). "latents" denotes the total length of all vision tokens
    at each scale. If the attention mask from every vision token v to a text token t is False, the original
    text embedding is returned unchanged.
    '''
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        share_kv=False,
        cfg=None,
        enable_ffn=True
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(input_dim=dim, dim_head=dim_head, heads=heads, share_kv=share_kv,
                                         cfg=cfg, norm_kv=True, spase_forward=True)

        if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
            if cfg.VISION_QUERY.CONDITION_GATE:
                if cfg.VISION_QUERY.NONLINEAR_GATE:
                    if cfg.VISION_QUERY.NO_CAT:
                        self.attn_gate = FeedForward(dim=dim, mult=0.5, out_dim=1)
                        torch.nn.init.constant_(self.attn_gate.linear2.weight, 0)
                    else:
                        self.attn_gate = FeedForward(dim=dim * 2, mult=0.5, out_dim=1)
                        torch.nn.init.constant_(self.attn_gate.linear2.weight, 0)
                else:
                    self.attn_gate = nn.Linear(dim, 1, bias=False)
                    torch.nn.init.constant_(self.attn_gate.weight, 0)
            else:
                self.attn_gate = nn.Parameter(torch.tensor([0.]))

        # if cfg.VISION_QUERY.TEXT_DROPOUT > 0.:
        #     self.mask_token = nn.Parameter(torch.randn(dim))

        self.enable_ffn = enable_ffn
        if enable_ffn:
            self.ff = FeedForward(dim, mult=ff_mult)
            if cfg.VISION_QUERY.FIX_ATTN_GATE == -1.0:
                self.ff_gate = nn.Parameter(torch.tensor([0.]))

        if cfg.VISION_QUERY.ADD_ADAPT_LAYER:
            self.adaptor = FeedForward(dim, mult=2)

        # self.text_dropout = Dropout1d(p=cfg.VISION_QUERY.TEXT_DROPOUT)
        self.cfg = cfg
        self.attn_gate_value = 0.
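
    # Gate configuration summary (descriptive comment):
    #   * FIX_ATTN_GATE != -1.0 -> a fixed constant gate is used for both the attention and FFN branches.
    #   * Otherwise, CONDITION_GATE predicts the gate from the attended features (a small FeedForward when
    #     NONLINEAR_GATE, a single Linear otherwise; NO_CAT controls whether the original text features are
    #     concatenated to the predictor input), while without CONDITION_GATE the gate is one learnable scalar.
    #   All learnable gates are zero-initialized, so tanh(gate) == 0 and, when FIX_ATTN_GATE == -1.0, the
    #   block starts out as an identity mapping on the text stream; see forward() below.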

    def forward(
        self,
        x,                                     # text tensor - (batch, text, dim_t)
        vision,                                # vision query tensor - (batch, vision, dim_v)
        attention_mask=None,                   # boolean tensor indicating masks of media - (batch, vision, text)
        batched_positive_label_position=None,  # batch: {label: (positions)}
    ):
        # assert exists(attention_mask)

        # option1: (1-a)*x1 + a*x2, a \in (0,1)
        # option2: x1 + a*x2, a \in (-1,1)
        ## if option2, text drop may be conducted here. Not tested yet.
        ## if option1, text drop may be conducted in MaskedCrossAttention

        # # Only mask text with vision query
        # # Only mask text with positive categories
        # if self.cfg.VISION_QUERY.TEXT_DROPOUT > 0. and self.training:
        #     mask = x.new_zeros(x.shape[:2], dtype=torch.bool)  # (batch, text)
        #     pure_text_mask = attention_mask.sum(1)  # (batch, text)
        #     for i, pos_label_position in enumerate(batched_positive_label_position):
        #         pos_label_position = pos_label_position.to(torch.bool)
        #         for position in pos_label_position:
        #             text_with_vision_query = (pure_text_mask[i, position].sum() != 0)
        #             if (random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT) and text_with_vision_query:
        #                 mask[i, position] = True
        # if self.training:
        #     dropped_x = x.clone()
        #     dropped_x[mask] = self.mask_token
        # else:
        #     dropped_x = x

        if self.cfg.VISION_QUERY.ADD_ADAPT_LAYER:
            vision = self.adaptor(vision) + vision

        dropped_x = x
        supported_x = self.attn(x, vision, attention_mask=attention_mask)
        # dropped_x = self.text_dropout(x)
        if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
            attn_gate = self.cfg.VISION_QUERY.FIX_ATTN_GATE
        else:
            if self.cfg.VISION_QUERY.CONDITION_GATE:
                if self.cfg.VISION_QUERY.NO_CAT or not self.cfg.VISION_QUERY.NONLINEAR_GATE:
                    attn_gate = self.attn_gate(supported_x).tanh()
                else:
                    attn_gate = self.attn_gate(torch.cat([supported_x, dropped_x], dim=-1)).tanh()
            else:
                attn_gate = self.attn_gate.tanh()
        if self.cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
            with torch.no_grad():
                self.attn_gate_value = attn_gate.mean().item()
        x = supported_x * attn_gate + dropped_x

        if self.enable_ffn:
            if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
                x = self.ff(x) * self.cfg.VISION_QUERY.FIX_ATTN_GATE + x
            else:
                x = self.ff(x) * self.ff_gate.tanh() + x
        return x
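

# --------------------------------------------------------------------------------------------------
# Minimal sketch (documentation only; not used by the model). It checks that, with the zero-initialized
# gates, a freshly constructed GatedCrossAttentionBlock leaves the text features unchanged. The cfg
# below is a hypothetical SimpleNamespace stand-in for the project's config; only the VISION_QUERY
# flags that this block actually reads are set.
def _sketch_gated_block_identity_at_init():
    from types import SimpleNamespace
    cfg = SimpleNamespace(VISION_QUERY=SimpleNamespace(
        FIX_ATTN_GATE=-1.0,           # use learnable gates
        CONDITION_GATE=False,         # gate is a single learnable scalar
        ADD_ADAPT_LAYER=False,
        RETURN_ATTN_GATE_VALUE=False,
    ))
    block = GatedCrossAttentionBlock(dim=64, dim_head=16, heads=4, cfg=cfg)
    x = torch.randn(2, 7, 64)                    # (batch, text, dim)
    vision = torch.randn(2, 5, 64)               # (batch, vision, dim)
    attention_mask = torch.ones(2, 5, 7).long()  # (batch, vision, text)
    out = block(x, vision, attention_mask=attention_mask)
    assert torch.allclose(out, x)                # tanh(0) == 0, so the block is an identity at init
    return out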


class PreSelectBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        out_dim=None,
        dim_head=32,
        heads=8,
        ff_mult=4,
        share_kv=False,
        cfg=None,
    ):
        super().__init__()
        self.image_condition = MaskedCrossAttention(input_dim=dim, output_dim=out_dim, dim_head=dim_head,
                                                    heads=heads, norm_kv=True, share_kv=share_kv, cfg=cfg,
                                                    spase_forward=False)
        self.ff = FeedForward(out_dim, mult=ff_mult)
        if dim != out_dim:
            self.res_mapping = nn.Linear(in_features=dim, out_features=out_dim, bias=False)
        else:
            self.res_mapping = nn.Identity()

    def forward(
        self,
        x,
    ):
        vision = x['vision']  # vision query tensor - (batch, vision, dim_v)
        image = x['image']    # query images - (batch, image, dim_v)
        # b, s, v, d = vision.shape
        # vision = rearrange(vision, 'b s v d -> b (s v) d')
        vision = self.image_condition(vision, image) + self.res_mapping(vision)
        vision = self.ff(vision) + vision
        # vision = rearrange(vision, 'b (s v) d -> b s v d', s=s)
        return {'vision': vision, 'image': image}


class PreSelectModule(nn.Module):
    def __init__(
        self,
        *,
        dim,
        out_dim,
        dim_head=32,
        heads=8,
        ff_mult=4,
        num_layers=2,
        share_kv=False,
        cfg=None
    ):
        super().__init__()
        layers = [PreSelectBlock(dim=dim, out_dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult,
                                 share_kv=share_kv, cfg=cfg) for _ in range(num_layers - 1)]
        layers.append(PreSelectBlock(dim=dim, out_dim=out_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult,
                                     share_kv=share_kv, cfg=cfg))
        self.layers = nn.Sequential(*layers)
        self.scale = cfg.VISION_QUERY.VISION_SCALE
        self.augment_image_with_query = cfg.VISION_QUERY.AUGMENT_IMAGE_WITH_QUERY
        if self.augment_image_with_query:
            assert len(self.layers) > 1

    def forward(
        self,
        vision,  # vision query tensor - (batch, vision, dim_v); the scale dim is already flattened into the token dim
        image,   # query image features - (batch, image, dim_v)
    ):
        vision = vision * self.scale
        image = image * self.scale
        if self.augment_image_with_query:
            # First condition the image features on the queries, then swap the roles back.
            x = self.layers[0]({'vision': image, 'image': vision})
            x = {'vision': x['image'], 'image': x['vision']}
            for layer in self.layers[1:]:
                x = layer(x)
            return x
        else:
            x = {'vision': vision, 'image': image}
            return self.layers(x)
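

# --------------------------------------------------------------------------------------------------
# Shape sketch for PreSelectModule (documentation only; not used by the model). The cfg is again a
# hypothetical SimpleNamespace stand-in; in the project these values come from cfg.VISION_QUERY.
def _sketch_pre_select_shapes():
    from types import SimpleNamespace
    cfg = SimpleNamespace(VISION_QUERY=SimpleNamespace(
        VISION_SCALE=1.0,
        AUGMENT_IMAGE_WITH_QUERY=False,
    ))
    module = PreSelectModule(dim=96, out_dim=128, dim_head=32, heads=8, num_layers=2, cfg=cfg)
    vision = torch.randn(2, 5, 96)   # k-shot vision queries  - (batch, vision, dim_v)
    image = torch.randn(2, 49, 96)   # query image features   - (batch, image, dim_v)
    out = module(vision, image)
    # The queries are conditioned on the image and projected to the text width:
    # out['vision'].shape == (2, 5, 128); out['image'] is returned as-is (only scaled by VISION_SCALE).
    return out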


class QVBertEmbeddings(BertEmbeddings):
    def __init__(self, config, cfg):
        super().__init__(config)
        self.cfg = cfg
        if (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.) and (cfg.VISION_QUERY.NEW_MASK_TOKEN):
            # "qv_layer" is added to the name only to make parameter freezing easier
            self.mask_tok_qv_layer = nn.Parameter(torch.randn(config.hidden_size))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
        batched_pos_category_map=None,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in the constructor where it is all zeros, which usually
        # occurs when it is auto-generated; the registered buffer helps users when tracing the model without passing
        # token_type_ids, and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.) and (batched_pos_category_map is not None) \
                and (self.cfg.VISION_QUERY.NEW_MASK_TOKEN) and (self.training):
            # NOTE: this branch is currently disabled; everything below the raise is unreachable.
            raise NotImplementedError
            mask_tok_qv_layer = self.mask_tok_qv_layer.to(inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.clone()
            for i, pos_label_position in enumerate(batched_pos_category_map):
                pos_label_position = pos_label_position.to(torch.bool)
                for position in pos_label_position:
                    if random.random() < self.cfg.VISION_QUERY.TEXT_DROPOUT:
                        inputs_embeds[i, position] = mask_tok_qv_layer

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class QVBertEncoder(BertEncoder):
    '''
    Add a qv_layer (gated cross-attention into the vision queries) at every BERT layer whose index
    is >= start_qv_layer_index.
    '''
    def __init__(self,
                 config,
                 dim,
                 dim_head=64,
                 heads=8,
                 ff_mult=4,
                 start_qv_layer_index=6,  # which layer to start fusing vision
                 share_kv=False,
                 cfg=None,
                 ):
        super().__init__(config=config)
        self.start_qv_layer_index = start_qv_layer_index
        num_hidden_layers = config.num_hidden_layers
        assert start_qv_layer_index < num_hidden_layers
        num_qv_layers = num_hidden_layers - start_qv_layer_index
        self.qv_layer = nn.ModuleList([GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads,
                                                                ff_mult=ff_mult, share_kv=share_kv, cfg=cfg)
                                       for _ in range(num_qv_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        vision: Optional[torch.Tensor] = None,
        vision_attention_mask: Optional[torch.Tensor] = None,
        batched_pos_category_map=None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if i >= self.start_qv_layer_index and exists(vision):
                qv_index = i - self.start_qv_layer_index
                hidden_states = self.qv_layer[qv_index](hidden_states, vision, vision_attention_mask,
                                                        batched_pos_category_map)

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
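

# Reading guide for QVBertEncoder.forward (descriptive comment): for every BERT layer i with
# i >= start_qv_layer_index, the hidden states first pass through qv_layer[i - start_qv_layer_index],
# i.e. the gated cross-attention into the (pre-selected) vision queries, and only then through the
# regular BERT layer. With the default start_qv_layer_index = 6 and a 12-layer encoder, layers 6..11
# each own one GatedCrossAttentionBlock. When gradient checkpointing is enabled, only the BERT layer
# call is checkpointed; the qv_layer update runs outside the checkpointed function.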


class QVBertModel(BertModel):
    def __init__(self, config, dim_t, dim_v,
                 dim_head_t=64,
                 dim_head_v=32,
                 heads=8,
                 ff_mult=4,
                 num_pre_select_layers=2,
                 share_kv=False,
                 cfg=None,
                 **kwargs):
        super().__init__(config=config, **kwargs)
        self.cfg = cfg
        self.embeddings = QVBertEmbeddings(config, cfg)
        self.encoder = QVBertEncoder(config=config, dim=dim_t, dim_head=dim_head_t, heads=heads,
                                     ff_mult=ff_mult, share_kv=share_kv, cfg=cfg)
        self.pre_select = PreSelectModule(dim=dim_v, out_dim=dim_t, dim_head=dim_head_v, heads=heads,
                                          ff_mult=ff_mult, num_layers=num_pre_select_layers,
                                          share_kv=share_kv, cfg=cfg)

    def get_gate_value(self):
        attn_gates = []
        ff_gates = []
        for blk in self.encoder.qv_layer:
            if self.cfg.VISION_QUERY.FIX_ATTN_GATE != -1.0:
                attn_gates.append(torch.tensor([self.cfg.VISION_QUERY.FIX_ATTN_GATE],
                                               device=self.embeddings.word_embeddings.weight.device))
                ff_gates.append(torch.tensor([self.cfg.VISION_QUERY.FIX_ATTN_GATE],
                                             device=self.embeddings.word_embeddings.weight.device))
            else:
                if not self.cfg.VISION_QUERY.CONDITION_GATE:
                    attn_gates.append(blk.attn_gate)
                else:
                    if self.cfg.VISION_QUERY.RETURN_ATTN_GATE_VALUE:
                        attn_gates.append(blk.attn_gate_value)
                    else:
                        pass
                ff_gates.append(blk.ff_gate)
        return {'attn_gates': attn_gates, 'ffn_gates': ff_gates}

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        vision: Optional[torch.Tensor] = None,   # (batch, vision, dim)
        images: Optional[torch.Tensor] = None,   # (batch, image, dim)
        vision_attention_mask: Optional[torch.Tensor] = None,
        batched_pos_category_map=None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            batched_pos_category_map=batched_pos_category_map,
        )

        augmented_vision = None
        if exists(images) and exists(vision):
            vision = self.pre_select(vision, images)['vision']
            augmented_vision = vision

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            vision=vision,
            vision_attention_mask=vision_attention_mask,
            batched_pos_category_map=batched_pos_category_map,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        out = BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
        # if self.cfg.VISION_QUERY.GATE_REGULARIZATION:
        out['vision_query_gates'] = self.get_gate_value()
        if self.cfg.VISION_QUERY.QUERY_FUSION:
            out['augmented_vision'] = augmented_vision
            out['vision_attention_mask'] = vision_attention_mask
        return out
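

# --------------------------------------------------------------------------------------------------
# End-to-end usage sketch (documentation only; not used by the model). It wires up a small QVBertModel
# with a hypothetical SimpleNamespace cfg carrying only the VISION_QUERY flags this file reads; in the
# project, the cfg and the vision queries come from the detector pipeline.
def _sketch_qvbert_forward():
    from types import SimpleNamespace
    from transformers import BertConfig

    cfg = SimpleNamespace(VISION_QUERY=SimpleNamespace(
        TEXT_DROPOUT=0.0,
        NEW_MASK_TOKEN=False,
        FIX_ATTN_GATE=-1.0,
        CONDITION_GATE=False,
        ADD_ADAPT_LAYER=False,
        RETURN_ATTN_GATE_VALUE=False,
        VISION_SCALE=1.0,
        AUGMENT_IMAGE_WITH_QUERY=False,
        QUERY_FUSION=False,
    ))
    config = BertConfig(vocab_size=1000, hidden_size=128, num_hidden_layers=8,
                        num_attention_heads=4, intermediate_size=256)
    model = QVBertModel(config, dim_t=128, dim_v=96, dim_head_t=16, dim_head_v=16, heads=4, cfg=cfg)
    model.eval()

    input_ids = torch.randint(0, 1000, (2, 16))             # (batch, text)
    vision = torch.randn(2, 5, 96)                          # k-shot vision queries (batch, vision, dim_v)
    images = torch.randn(2, 49, 96)                         # query image features  (batch, image, dim_v)
    vision_attention_mask = torch.ones(2, 5, 16).long()     # (batch, vision, text)

    with torch.no_grad():
        out = model(input_ids=input_ids, vision=vision, images=images,
                    vision_attention_mask=vision_attention_mask)
    # out.last_hidden_state: (2, 16, 128); out['vision_query_gates'] holds the per-layer gate parameters.
    return out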