# flake8: noqa
# modified from https://github.com/jayleicn/ClipBERT
import math

import torch
from torch import nn
from transformers.modeling_utils import PreTrainedModel

from ..registry import BACKBONES

ACT2FN = {'gelu': nn.GELU(), 'relu': torch.nn.functional.relu}


class BertLayerNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the
        square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type
    embeddings."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=0)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size, padding_idx=0)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file.
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
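

# Illustrative sketch, not part of the original ClipBERT module: shows how
# BertEmbeddings turns a batch of token ids into summed word + position +
# token-type embeddings. The helper name and all config sizes below are
# arbitrary assumptions chosen for a quick shape check, not real settings.
def _bert_embeddings_example():
    from transformers import BertConfig

    # assumed toy sizes; any BertConfig with these fields works
    cfg = BertConfig(vocab_size=100, hidden_size=32, max_position_embeddings=64)
    emb = BertEmbeddings(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 10))  # (B, Lt)
    return emb(input_ids)  # (2, 10, 32): one hidden vector per text token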


class BertPooler(nn.Module):

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertIntermediate(nn.Module):

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertPreTrainedModel(PreTrainedModel):
    from transformers import BertConfig
    config_class = BertConfig
    # load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = 'bert'

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses
            # truncated_normal for initialization.
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
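

# Illustrative sketch, not part of the original module: BertIntermediate and
# BertOutput together form the position-wise feed-forward block of a BERT
# layer, with the residual connection and LayerNorm applied inside BertOutput.
# The helper name and the toy sizes are assumptions for demonstration only.
def _bert_ffn_example():
    from transformers import BertConfig

    cfg = BertConfig(hidden_size=32, intermediate_size=64)  # assumed sizes
    intermediate, output = BertIntermediate(cfg), BertOutput(cfg)
    hidden = torch.randn(2, 10, 32)  # (B, L, hidden_size)
    # expand to intermediate_size, project back, add residual, LayerNorm
    return output(intermediate(hidden), hidden)  # (2, 10, 32)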


class VisualInputEmbedding(nn.Module):
    """Takes input of both image and video (multi-frame)."""

    def __init__(self, config):
        super(VisualInputEmbedding, self).__init__()
        self.config = config

        # sequence embedding
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.row_position_embeddings = nn.Embedding(
            config.max_grid_row_position_embeddings, config.hidden_size)
        self.col_position_embeddings = nn.Embedding(
            config.max_grid_col_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, grid):
        """
        Args:
            grid: (B, n_frm, H, W, C), note that n_frm can be 1

        Returns:
            embeddings: (B, n_frm*H*W, C)
        """
        bsz, _, _, _, hsz = grid.shape

        # temporal mean pooling
        # grid = grid.mean(1)  # (B, H, W, d)
        grid = self.add_2d_positional_embeddings(grid)  # (B, H, W, d)
        # image token sequence
        visual_tokens = grid.view(bsz, -1, hsz)  # (B, H*W, d)

        # Perform random sampling. It is only used in the training phase of
        # pre-training, not in inference or downstream tasks.
        # if hasattr(self.config, "pixel_random_sampling_size") and \
        #         self.config.pixel_random_sampling_size > 0 and self.training:
        #     sampled_indices = get_random_sample_indices(
        #         seq_len=visual_tokens.shape[1],
        #         num_samples=self.config.pixel_random_sampling_size,
        #         device=visual_tokens.device)
        #     visual_tokens = visual_tokens.index_select(
        #         dim=1, index=sampled_indices)  # (B, #samples, d)
        visual_tokens_shape = visual_tokens.shape[:-1]  # (B, H*W)
        device = visual_tokens.device

        # image token type embeddings
        token_type_ids = torch.zeros(
            visual_tokens_shape, dtype=torch.long, device=device)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # embeddings = visual_tokens + position_embeddings + token_type_embeddings
        embeddings = visual_tokens + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings  # (B, H*W, d)

    def add_temporal_postion_embeddings(self, grid):
        """
        Args:
            grid: (B, n_frms, H, W, d)

        Returns:
            (B, n_frms, H, W, d)

        Note:
            This helper expects ``self.temporal_position_embeddings``, which is
            not created in ``__init__``; it is unused in this module.
        """
        n_frms, height, width, hsz = grid.shape[-4:]

        # add temporal position embeddings
        temporal_position_ids = torch.arange(
            n_frms, dtype=torch.long, device=grid.device)  # (n_frms, )
        t_position_embeddings = self.temporal_position_embeddings(
            temporal_position_ids)  # (n_frms, d)
        new_shape = (1, n_frms, 1, 1, hsz)  # (1, n_frms, 1, 1, d)
        grid = grid + t_position_embeddings.view(
            *new_shape)  # broadcast automatically
        return grid

    def add_2d_positional_embeddings(self, grid):
        """
        Args:
            grid: (B, *, H, W, d)

        Returns:
            (B, *, H, W, d)
        """
        height, width, hsz = grid.shape[-3:]

        # add row-wise position embeddings
        row_position_ids = torch.arange(
            height, dtype=torch.long, device=grid.device)  # (H, )
        row_position_embeddings = self.row_position_embeddings(
            row_position_ids)  # (H, d)
        row_shape = (1, ) * (len(grid.shape) - 3) + (height, 1, hsz)  # (1, *1, H, 1, d)
        grid = grid + row_position_embeddings.view(
            *row_shape)  # broadcast automatically

        # add column-wise position embeddings
        col_position_ids = torch.arange(
            width, dtype=torch.long, device=grid.device)  # (W, )
        col_position_embeddings = self.col_position_embeddings(
            col_position_ids)  # (W, d)
        col_shape = (1, ) * (len(grid.shape) - 3) + (1, width, hsz)  # (1, *1, 1, W, d)
        grid = grid + col_position_embeddings.view(
            *col_shape)  # broadcast automatically
        return grid


class BertLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_attention_outputs = self.attention(hidden_states, attention_mask,
                                                head_mask)
        attention_output = self_attention_outputs[0]
        # add self attentions if we output attention weights
        outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output, attention_mask, head_mask,
                encoder_hidden_states, encoder_attention_mask)
            attention_output = cross_attention_outputs[0]
            # add cross attentions if we output attention weights
            outputs = outputs + cross_attention_outputs[1:]

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output, ) + outputs
        return outputs
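

# Illustrative sketch, not part of the original module: VisualInputEmbedding
# adds row/column position embeddings to a 2D feature grid and flattens it
# into a token sequence of length n_frm * H * W. The helper name, grid size
# and config values are assumptions chosen only to exercise the shapes.
def _visual_input_embedding_example():
    from transformers import BertConfig

    cfg = BertConfig(hidden_size=32)
    # grid-size attributes assumed here; they are not standard BertConfig fields
    cfg.max_grid_row_position_embeddings = 7
    cfg.max_grid_col_position_embeddings = 7
    vis_emb = VisualInputEmbedding(cfg)
    grid = torch.randn(2, 1, 7, 7, 32)  # (B, n_frm, H, W, d)
    return vis_emb(grid)  # (2, 1 * 7 * 7, 32): flattened visual token sequence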


class BertEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_outputs = layer_module(hidden_states, attention_mask,
                                         head_mask[i], encoder_hidden_states,
                                         encoder_attention_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        outputs = (hidden_states, )
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states, )
        if self.output_attentions:
            outputs = outputs + (all_attentions, )
        # last-layer hidden state, (all hidden states), (all attentions)
        return outputs
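

# Illustrative sketch, not part of the original module: BertEncoder simply
# stacks BertLayer modules; head_mask must be indexable per layer, so a list
# of None values is passed when no heads are masked. The helper name and the
# toy config sizes are assumptions for demonstration only.
def _bert_encoder_example():
    from transformers import BertConfig

    cfg = BertConfig(
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64)
    encoder = BertEncoder(cfg)
    hidden = torch.randn(2, 10, 32)  # (B, L, hidden_size)
    outputs = encoder(
        hidden,
        attention_mask=None,
        head_mask=[None] * cfg.num_hidden_layers)
    return outputs[0]  # last-layer hidden states, (2, 10, 32)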


class BertSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and \
                not hasattr(config, 'embedding_size'):
            raise ValueError(
                'The hidden size (%d) is not a multiple of the number of '
                'attention heads (%d)' %
                (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key"
        # to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask
            # (precomputed for all layers in the model's forward() function).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer,
                   attention_probs) if self.output_attentions else (
                       context_layer, )
        return outputs


class BertSelfOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_outputs = self.self(hidden_states, attention_mask, head_mask,
                                 encoder_hidden_states, encoder_attention_mask)
        attention_output = self.output(self_outputs[0], hidden_states)
        # add attentions if we output them
        outputs = (attention_output, ) + self_outputs[1:]
        return outputs
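

# Illustrative sketch, not part of the original module: the additive attention
# mask mentioned above follows the usual BERT convention, where a (B, L)
# padding mask of 1s/0s is turned into a (B, 1, 1, L) tensor of 0 /
# large-negative values that is simply added to the raw attention scores.
# The helper name and the -10000.0 constant are assumptions mirroring the
# standard get_extended_attention_mask behaviour, not code from this module.
def _extended_attention_mask_example():
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # (B=1, L=5), 0 = padding
    extended_mask = (1.0 - attention_mask[:, None, None, :].float()) * -10000.0
    return extended_mask  # (1, 1, 1, 5): 0 for real tokens, -10000 for padding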


class ClipBertBaseModel(BertPreTrainedModel):

    def __init__(self, config_text, config_cross):
        super().__init__(config_text)
        self.config_text = config_text
        self.config_cross = config_cross

        self.embeddings = BertEmbeddings(config_text)
        self.encoder_text = BertEncoder(config_text)
        self.encoder_co = BertEncoder(config_cross)
        self.pooler = BertPooler(config_cross)

        self.init_weights()

    def forward(self, text_input_ids, visual_inputs, attention_mask):
        input_shape = text_input_ids.size()
        device = text_input_ids.device

        text_embedding_output = self.embeddings(
            input_ids=text_input_ids)  # (B, Lt, D)
        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                attention_mask, input_shape, device)
        encoder_outputs_text = self.encoder_text(
            text_embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=self.get_head_mask(
                None, self.config_text.num_hidden_layers)  # required input
        )
        sequence_output_text = encoder_outputs_text[0]

        bsz, hsz = visual_inputs.shape[0], visual_inputs.shape[-1]
        # visual_embedding_output = visual_inputs.view(bsz, -1, hsz)
        visual_embedding_output = visual_inputs.reshape(bsz, -1, hsz)
        visual_attention_mask = attention_mask.new_ones(
            visual_embedding_output.shape[:2])
        attention_mask = torch.cat([attention_mask, visual_attention_mask],
                                   dim=-1)  # (B, Lt+Lv)
        embedding_output = torch.cat(
            [sequence_output_text, visual_embedding_output],
            dim=1)  # (B, Lt+Lv, d)

        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                attention_mask, input_shape, device)
        encoder_outputs = self.encoder_co(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=self.get_head_mask(
                None, self.config_cross.num_hidden_layers)  # required input
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (
            sequence_output,
            pooled_output,
        ) + encoder_outputs[1:]
        return outputs


@BACKBONES.register_module()
class ClipBertClassification(BertPreTrainedModel):

    def __init__(self, config_text, config_cross):
        from transformers import BertConfig
        config_text = BertConfig(**config_text)
        config_cross = BertConfig(**config_cross)
        super(ClipBertClassification, self).__init__(config_text)
        self.bert = ClipBertBaseModel(config_text, config_cross)
        self.dropout = nn.Dropout(config_text.hidden_dropout_prob)
        self.classifier = nn.Sequential(
            nn.Linear(config_text.hidden_size, config_text.hidden_size * 2),
            nn.ReLU(True),
            nn.Linear(config_text.hidden_size * 2, config_cross.num_labels))
        self.init_weights()

    def forward(self, text_input_ids, visual_inputs, text_input_mask):
        outputs = self.bert(
            text_input_ids=text_input_ids,
            visual_inputs=visual_inputs,
            attention_mask=text_input_mask,  # (B, Lt), text-only mask
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

    def init_weights(self, pretrained=None):
        if pretrained is None:
            self.apply(self._init_weights)
        else:
            if isinstance(pretrained, str):
                loaded_state_dict = torch.load(pretrained, map_location='cpu')
            else:
                loaded_state_dict = pretrained
            model_keys = set(self.state_dict().keys())
            load_keys = set(loaded_state_dict.keys())

            toload = {}
            mismatched_shape_keys = []
            for k in model_keys:
                # both the text encoder and the cross encoder are initialized
                # from the same pretrained BERT encoder weights
                k_rename = k.replace('encoder_text', 'encoder')
                k_rename = k_rename.replace('encoder_co', 'encoder')
                if k_rename in load_keys:
                    if self.state_dict()[k].shape != loaded_state_dict[
                            k_rename].shape:
                        mismatched_shape_keys.append(k)
                    else:
                        toload[k] = loaded_state_dict[k_rename]
            self.load_state_dict(toload, strict=False)
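

# Illustrative end-to-end sketch, not part of the original module and not the
# settings used by the ClipBERT experiments: builds a tiny
# ClipBertClassification from plain config dicts, as the registry would, and
# runs one forward pass. The helper name, every config value, and the visual
# feature shape are assumptions; visual_inputs can be any (B, *, d) tensor
# because it is reshaped to (B, Lv, d) inside ClipBertBaseModel.
def _clipbert_classification_example():
    config_text = dict(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        max_position_embeddings=64)
    config_cross = dict(
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        num_labels=5)
    model = ClipBertClassification(config_text, config_cross)

    text_input_ids = torch.randint(0, 100, (2, 8))  # (B, Lt)
    text_input_mask = torch.ones(2, 8, dtype=torch.long)  # (B, Lt)
    visual_inputs = torch.randn(2, 4, 4, 32)  # reshaped to (B, 16, 32) inside
    return model(text_input_ids, visual_inputs, text_input_mask)  # (2, 5)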